diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 7d1207cfa..1a696f08d 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -106,6 +106,10 @@ CudaAllocator::CudaAllocator() } Buffer CudaAllocator::malloc_impl(size_t size, cudaStream_t stream) { + if (size == 0) { + return Buffer{new CudaBuffer{nullptr, 0, -1}}; + } + // Find available buffer from cache. auto orig_size = size; std::unique_lock lock(mutex_); @@ -173,6 +177,10 @@ void CudaAllocator::free(Buffer buffer) { if (!buf) { return; } + if (buf->size == 0) { + delete buf; + return; + } std::unique_lock lock(mutex_); active_memory_ -= buf->size; diff --git a/mlx/backend/cuda/fence.cpp b/mlx/backend/cuda/fence.cpp index f11ff7e9e..a50c9111f 100644 --- a/mlx/backend/cuda/fence.cpp +++ b/mlx/backend/cuda/fence.cpp @@ -34,8 +34,8 @@ void Fence::update(Stream s, const array& a, bool cross_device) { cbuf.device = -1; auto& encoder = cu::device(s.device).get_command_encoder(s); encoder.commit(); - CHECK_CUDA_ERROR( - cudaMemcpyAsync(new_data, cbuf.data, cbuf.size, cudaMemcpyDefault)); + CHECK_CUDA_ERROR(cudaMemcpyAsync( + new_data, cbuf.data, cbuf.size, cudaMemcpyDefault, encoder.stream())); CHECK_CUDA_ERROR(cudaFreeAsync(cbuf.data, encoder.stream())); cbuf.data = new_data; }