diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 2a8ef9963..8c9a40d03 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -10,7 +10,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cu - ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu + ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu index f462720a9..9fc5c641b 100644 --- a/mlx/backend/cuda/event.cu +++ b/mlx/backend/cuda/event.cu @@ -156,7 +156,10 @@ void SharedEvent::signal(cudaStream_t stream, uint64_t value) { void SharedEvent::signal(Stream s, uint64_t value) { nvtx3::scoped_range r("cu::SharedEvent::signal(s)"); if (s.device == mlx::core::Device::cpu) { - scheduler::enqueue(s, [*this, value]() mutable { signal(value); }); + // Signal through a GPU stream so the atomic is updated in GPU - updating + // the atomic in CPU sometimes does not get GPU notified. + static CudaStream stream(device(mlx::core::Device::gpu)); + scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); }); } else { auto& encoder = get_command_encoder(s); encoder.launch_kernel( diff --git a/mlx/backend/cuda/fence.cpp b/mlx/backend/cuda/fence.cpp new file mode 100644 index 000000000..f399c4ebb --- /dev/null +++ b/mlx/backend/cuda/fence.cpp @@ -0,0 +1,29 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/fence.h" +#include "mlx/backend/cuda/event.h" + +namespace mlx::core { + +struct FenceImpl { + uint32_t count; + cu::SharedEvent event; +}; + +Fence::Fence(Stream s) { + fence_ = std::shared_ptr( + new FenceImpl{0}, [](void* ptr) { delete static_cast(ptr); }); +} + +void Fence::wait(Stream s, const array&) { + auto* fence = static_cast(fence_.get()); + fence->event.wait(fence->count); +} + +void Fence::update(Stream s, const array&) { + auto* fence = static_cast(fence_.get()); + fence->count++; + fence->event.signal(s, fence->count); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/fence.cu b/mlx/backend/cuda/fence.cu deleted file mode 100644 index 091b252c1..000000000 --- a/mlx/backend/cuda/fence.cu +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/event.h" -#include "mlx/fence.h" -#include "mlx/scheduler.h" - -#include - -namespace mlx::core { - -namespace { - -__host__ __device__ void busy_wait(cuda::atomic* ac, uint64_t value) { - while (true) { - // In theory the atomic_thread_fence is not needed, but for CUDA 11 without - // it the load() may never return new value. - cuda::atomic_thread_fence(cuda::memory_order_seq_cst); - uint64_t current = ac->load(); - if (current >= value) { - break; - } - } -} - -__global__ void busy_wait_kernel(cuda::atomic* ac, uint64_t value) { - busy_wait(ac, value); -} - -} // namespace - -struct FenceImpl { - uint32_t count; - cu::SharedEvent event; -}; - -Fence::Fence(Stream s) { - fence_ = std::shared_ptr( - new FenceImpl{0}, [](void* ptr) { delete static_cast(ptr); }); -} - -void Fence::wait(Stream s, const array&) { - auto* fence = static_cast(fence_.get()); - // We can't use SharedEvent::wait because it could hang in CUDA 11, see also: - // https://github.com/ml-explore/mlx/issues/2137 - const auto& ac = fence->event.atomic(); - if (s.device == mlx::core::Device::cpu) { - scheduler::enqueue(s, [ac, count = fence->count]() { - nvtx3::scoped_range r("Fence::wait()"); - busy_wait(ac.get(), count); - }); - } else { - nvtx3::scoped_range r("Fence::wait(s)"); - auto& encoder = cu::get_command_encoder(s); - encoder.launch_kernel( - encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) { - busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count); - }); - encoder.add_completed_handler([ac]() {}); - encoder.end_encoding(); - } -} - -void Fence::update(Stream s, const array&) { - auto* fence = static_cast(fence_.get()); - fence->count++; - fence->event.signal(s, fence->count); -} - -} // namespace mlx::core