diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp
index c9a248104..de3357cf9 100644
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -92,7 +92,7 @@ CudaAllocator::CudaAllocator()
       [this](CudaBuffer* buf) { cuda_free(buf); }) {
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
-  memory_limit_ = total * 0.95;
+  memory_limit_ = total * 0.9;
   max_pool_size_ = memory_limit_;
 
   int device_count = 0;
@@ -176,7 +176,7 @@ CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
     buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
   }
   // Copy to managed here if the buffer is not on the right device
-  if (buf->device != device) {
+  if (buf->device >= 0 && buf->device != device) {
     copy_to_managed(*buf);
   }
   return Buffer{buf};
@@ -219,9 +219,9 @@ void CudaAllocator::cuda_free(CudaBuffer* buf) {
     scalar_pool_.free(buf);
   } else {
     if (buf->device >= 0) {
-      cudaFreeAsync(buf->data, free_streams_[buf->device]);
+      CHECK_CUDA_ERROR(cudaFreeAsync(buf->data, free_streams_[buf->device]));
    } else {
-      cudaFree(buf->data);
+      CHECK_CUDA_ERROR(cudaFree(buf->data));
    }
     delete buf;
   }