[CUDA] synch properly waits for all tasks to finish and clear (#2303)

* cuda synch properly waits for all tasks to finish and clear

* fix copy
Awni Hannun 2025-06-17 12:03:25 -07:00 committed by GitHub
parent b8022c578a
commit cad5c0241c
7 changed files with 27 additions and 8 deletions


@@ -106,7 +106,6 @@ void CudaAllocator::cuda_free(void* buf) {
       return;
     }
   }
   cudaFree(buf);
 }


@@ -63,25 +63,30 @@ void copy_general(
   MLX_SWITCH_BOOL(large, LARGE, {
     using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
     int ndim = shape.size();
+    size_t data_size = 1;
+    for (auto& s : shape)
+      data_size *= s;
     if (ndim <= 3) {
       MLX_SWITCH_1_2_3(ndim, NDIM, {
         auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
-        auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+        auto [num_blocks, block_dims] =
+            get_launch_args(kernel, data_size, shape, out.strides(), large);
         kernel<<<num_blocks, block_dims, 0, stream>>>(
             in_ptr,
             out_ptr,
-            out.size(),
+            data_size,
             const_param<NDIM>(shape),
             const_param<NDIM>(strides_in),
             const_param<NDIM>(strides_out));
       });
     } else { // ndim >= 4
       auto kernel = cu::copy_gg<InType, OutType, IdxT>;
-      auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+      auto [num_blocks, block_dims] =
+          get_launch_args(kernel, data_size, shape, out.strides(), large);
       kernel<<<num_blocks, block_dims, 0, stream>>>(
           in_ptr,
           out_ptr,
-          out.size(),
+          data_size,
           const_param(shape),
           const_param(strides_in),
           const_param(strides_out),

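Note on the "fix copy" half of this commit: the launch geometry and the kernel's bounds argument now come from data_size, the number of elements the copy actually traverses (the product of shape), rather than out.size(), which can disagree with that count when only a sub-region of the destination is written (the kernels take explicit strides, so a copy need not cover all of out). A minimal stand-alone sketch of the element count; the helper name is illustrative and the shape is assumed to be a plain vector of dimension sizes (MLX's real Shape type may differ):

#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

// Product of all dimension sizes; an empty shape is a scalar (1 element).
int64_t copy_element_count(const std::vector<int32_t>& shape) {
  return std::accumulate(
      shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
}

int main() {
  std::printf("%lld\n", (long long)copy_element_count({2, 3, 4}));  // 24
}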

@@ -6,6 +6,7 @@
 #include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
+#include <future>
 
 namespace mlx::core {
@@ -107,6 +108,16 @@ void CommandEncoder::commit() {
   worker_.commit(stream_.last_cuda_stream());
 }
 
+void CommandEncoder::synchronize() {
+  stream().synchronize();
+  auto p = std::make_shared<std::promise<void>>();
+  std::future<void> f = p->get_future();
+  add_completed_handler([p = std::move(p)]() { p->set_value(); });
+  worker_.end_batch();
+  worker_.commit();
+  f.wait();
+}
+
 Device& device(mlx::core::Device device) {
   static std::unordered_map<int, Device> devices;
   auto it = devices.find(device.index);

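This is the core of the fix. Waiting on the CUDA stream only covers device work; completion handlers run later, on the worker thread. So synchronize() first drains the stream, then enqueues a sentinel handler that fulfills a promise, flushes the batch, and blocks on the matching future. Because the worker runs handlers in order, the future becomes ready only after every previously queued task has executed. Below is a self-contained sketch of the same handshake, with a plain std::thread standing in for MLX's Worker; the promise is held through a shared_ptr so the callable stays copyable, as std::function requires:

#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <thread>
#include <vector>

int main() {
  std::vector<std::function<void()>> tasks;
  tasks.push_back([] { std::cout << "completion handler\n"; });

  // Sentinel task: once the future is ready, every task queued before it
  // has already run, because tasks execute in order.
  auto p = std::make_shared<std::promise<void>>();
  std::future<void> f = p->get_future();
  tasks.push_back([p = std::move(p)] { p->set_value(); });

  std::thread worker([&] {
    for (auto& t : tasks) {
      t();
    }
  });

  f.wait();  // blocks until the sentinel has executed
  worker.join();
}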

@@ -123,6 +123,9 @@ class CommandEncoder {
     return has_gpu_work_;
   }
 
+  // Wait until kernels and completion handlers are finished
+  void synchronize();
+
 private:
   Device& device_;
   DeviceStream& stream_;


@@ -62,7 +62,7 @@ void finalize(Stream s) {
 void synchronize(Stream s) {
   nvtx3::scoped_range r("gpu::synchronize");
-  cu::get_stream(s).synchronize();
+  cu::get_command_encoder(s).synchronize();
 }
 
 } // namespace mlx::core::gpu

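This one-line change is the user-facing entry point: the old call returned as soon as the CUDA stream was idle, while the new one goes through CommandEncoder::synchronize() above and therefore also waits for pending completion handlers to run and be destroyed. That is presumably why TestMemory.test_memory_info can come off the CUDA skip list at the end of this diff.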

@@ -80,7 +80,9 @@ void Worker::thread_fn() {
       }
       worker_tasks_.erase(worker_tasks_.begin(), end);
     }
-    for (auto& task : tasks) {
+    // Make sure tasks are cleared before the next wait
+    for (int i = 0; i < tasks.size(); ++i) {
+      auto task = std::move(tasks[i]);
       task();
     }
     worker_event_.wait(batch + 1);

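A note on the loop rewrite: with the old range-for, each callable stayed alive inside tasks until the vector was cleaned up after worker_event_.wait(). Moving every task out of the vector destroys it, and releases whatever it captured, as soon as it has run; that is the "cleared before the next wait" the new comment refers to. A small stand-alone illustration, using a shared_ptr to make the capture's lifetime observable:

#include <functional>
#include <iostream>
#include <memory>
#include <vector>

int main() {
  std::vector<std::function<void()>> tasks;
  auto payload = std::make_shared<int>(42);
  tasks.push_back([payload] { std::cout << "task ran\n"; });
  std::cout << "use_count before: " << payload.use_count() << "\n";  // 2

  for (size_t i = 0; i < tasks.size(); ++i) {
    auto task = std::move(tasks[i]);  // take the callable out of the vector
    task();
  }  // `task` dies at the end of each iteration, dropping its captures

  // The vector still exists, but its moved-from elements no longer own the
  // capture, so only `payload` itself keeps the int alive.
  std::cout << "use_count after: " << payload.use_count() << "\n";  // 1
}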

@@ -6,7 +6,6 @@ cuda_skip = {
     "TestEinsum.test_ellipses",
     "TestEinsum.test_opt_einsum_test_cases",
     "TestLoad.test_load_f8_e4m3",
-    "TestMemory.test_memory_info",
     "TestLayers.test_group_norm",
     "TestLayers.test_pooling",
     "TestLayers.test_quantized_embedding",