mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix cuda allocator copy condition (#2800)
This commit is contained in:
@@ -92,7 +92,7 @@ CudaAllocator::CudaAllocator()
|
|||||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||||
size_t free, total;
|
size_t free, total;
|
||||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||||
memory_limit_ = total * 0.95;
|
memory_limit_ = total * 0.9;
|
||||||
max_pool_size_ = memory_limit_;
|
max_pool_size_ = memory_limit_;
|
||||||
|
|
||||||
int device_count = 0;
|
int device_count = 0;
|
||||||
@@ -176,7 +176,7 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
|
|||||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||||
}
|
}
|
||||||
// Copy to managed here if the buffer is not on the right device
|
// Copy to managed here if the buffer is not on the right device
|
||||||
if (buf->device != device) {
|
if (buf->device >= 0 && buf->device != device) {
|
||||||
copy_to_managed(*buf);
|
copy_to_managed(*buf);
|
||||||
}
|
}
|
||||||
return Buffer{buf};
|
return Buffer{buf};
|
||||||
@@ -219,9 +219,9 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
|||||||
scalar_pool_.free(buf);
|
scalar_pool_.free(buf);
|
||||||
} else {
|
} else {
|
||||||
if (buf->device >= 0) {
|
if (buf->device >= 0) {
|
||||||
cudaFreeAsync(buf->data, free_streams_[buf->device]);
|
CHECK_CUDA_ERROR(cudaFreeAsync(buf->data, free_streams_[buf->device]));
|
||||||
} else {
|
} else {
|
||||||
cudaFree(buf->data);
|
CHECK_CUDA_ERROR(cudaFree(buf->data));
|
||||||
}
|
}
|
||||||
delete buf;
|
delete buf;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user