diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp
index 00f78fd4f..1d17d7df5 100644
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -106,7 +106,6 @@ void CudaAllocator::cuda_free(void* buf) {
       return;
     }
   }
-  cudaFree(buf);
 }
 
diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu
index 9f50c8a31..2dc08c60a 100644
--- a/mlx/backend/cuda/copy/copy_general.cu
+++ b/mlx/backend/cuda/copy/copy_general.cu
@@ -63,25 +63,30 @@ void copy_general(
     MLX_SWITCH_BOOL(large, LARGE, {
       using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
       int ndim = shape.size();
+      size_t data_size = 1;
+      for (auto& s : shape)
+        data_size *= s;
       if (ndim <= 3) {
         MLX_SWITCH_1_2_3(ndim, NDIM, {
           auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
-          auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+          auto [num_blocks, block_dims] =
+              get_launch_args(kernel, data_size, shape, out.strides(), large);
           kernel<<<num_blocks, block_dims, 0, stream>>>(
               in_ptr,
               out_ptr,
-              out.size(),
+              data_size,
               const_param<NDIM>(shape),
               const_param<NDIM>(strides_in),
               const_param<NDIM>(strides_out));
         });
       } else { // ndim >= 4
         auto kernel = cu::copy_gg<InType, OutType, IdxT>;
-        auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+        auto [num_blocks, block_dims] =
+            get_launch_args(kernel, data_size, shape, out.strides(), large);
         kernel<<<num_blocks, block_dims, 0, stream>>>(
             in_ptr,
             out_ptr,
-            out.size(),
+            data_size,
             const_param(shape),
             const_param(strides_in),
             const_param(strides_out),
diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 8a3d66c8e..fcf7fdf5e 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -6,6 +6,7 @@
 
 #include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
+#include <future>
 
 namespace mlx::core {
 
@@ -107,6 +108,16 @@ void CommandEncoder::commit() {
   worker_.commit(stream_.last_cuda_stream());
 }
 
+void CommandEncoder::synchronize() {
+  stream().synchronize();
+  auto p = std::make_shared<std::promise<void>>();
+  std::future<void> f = p->get_future();
+  add_completed_handler([p = std::move(p)]() { p->set_value(); });
+  worker_.end_batch();
+  worker_.commit();
+  f.wait();
+}
+
 Device& device(mlx::core::Device device) {
   static std::unordered_map<int, Device> devices;
   auto it = devices.find(device.index);
diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h
index 5b2cc0607..744f77f62 100644
--- a/mlx/backend/cuda/device.h
+++ b/mlx/backend/cuda/device.h
@@ -123,6 +123,9 @@ class CommandEncoder {
     return has_gpu_work_;
   }
 
+  // Wait until kernels and completion handlers are finished
+  void synchronize();
+
  private:
   Device& device_;
   DeviceStream& stream_;
diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp
index b309ad60e..21b019cd8 100644
--- a/mlx/backend/cuda/eval.cpp
+++ b/mlx/backend/cuda/eval.cpp
@@ -62,7 +62,7 @@ void finalize(Stream s) {
 
 void synchronize(Stream s) {
   nvtx3::scoped_range r("gpu::synchronize");
-  cu::get_stream(s).synchronize();
+  cu::get_command_encoder(s).synchronize();
 }
 
 } // namespace mlx::core::gpu
diff --git a/mlx/backend/cuda/worker.cpp b/mlx/backend/cuda/worker.cpp
index 64b5c7679..3b35c830b 100644
--- a/mlx/backend/cuda/worker.cpp
+++ b/mlx/backend/cuda/worker.cpp
@@ -80,7 +80,9 @@ void Worker::thread_fn() {
       }
       worker_tasks_.erase(worker_tasks_.begin(), end);
     }
-    for (auto& task : tasks) {
+    // Make sure tasks are cleared before the next wait
+    for (int i = 0; i < tasks.size(); ++i) {
+      auto task = std::move(tasks[i]);
       task();
     }
     worker_event_.wait(batch + 1);
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index 36388c3c5..bcb95dbb7 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -6,7 +6,6 @@ cuda_skip = {
"TestEinsum.test_ellipses", "TestEinsum.test_opt_einsum_test_cases", "TestLoad.test_load_f8_e4m3", - "TestMemory.test_memory_info", "TestLayers.test_group_norm", "TestLayers.test_pooling", "TestLayers.test_quantized_embedding",