Avoid atomic updates across CPU/GPU in CUDA event (#2231)

This commit is contained in:
Cheng
2025-06-04 08:49:06 +09:00
committed by GitHub
parent 0bb89e9e5f
commit 85a8beb5e4
4 changed files with 34 additions and 72 deletions

View File

@@ -156,7 +156,10 @@ void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
void SharedEvent::signal(Stream s, uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
if (s.device == mlx::core::Device::cpu) {
scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
// Signal through a GPU stream so that the atomic is updated on the GPU:
// updating the atomic from the CPU sometimes fails to notify the GPU.
static CudaStream stream(device(mlx::core::Device::gpu));
scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
} else {
auto& encoder = get_command_encoder(s);
encoder.launch_kernel(