Mirror of https://github.com/ml-explore/mlx.git

Avoid atomic updates across CPU/GPU in CUDA event (#2231)
commit 85a8beb5e4 (parent 0bb89e9e5f)
@@ -10,7 +10,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
-  ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
@@ -156,7 +156,10 @@ void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
 void SharedEvent::signal(Stream s, uint64_t value) {
   nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
   if (s.device == mlx::core::Device::cpu) {
-    scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
+    // Signal through a GPU stream so the atomic is updated in GPU - updating
+    // the atomic in CPU sometimes does not get GPU notified.
+    static CudaStream stream(device(mlx::core::Device::gpu));
+    scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
   } else {
     auto& encoder = get_command_encoder(s);
     encoder.launch_kernel(
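The comment added in this hunk captures the core of the change: when the signaling stream is a CPU stream, the shared atomic is no longer written by the host thread; the signal is re-routed through a static GPU stream so the store is performed from the device side. A rough standalone sketch of that idea follows (not MLX code: SystemAtomic, signal_kernel, and signal_from_cpu are made-up names, and the atomic is assumed to live in memory visible to both host and device, e.g. a cudaMallocManaged allocation; build with nvcc):

#include <cuda/atomic>
#include <cstdint>

using SystemAtomic = cuda::atomic<uint64_t, cuda::thread_scope_system>;

// One-thread kernel that publishes the new value from the device side.
__global__ void signal_kernel(SystemAtomic* ac, uint64_t value) {
  ac->store(value, cuda::memory_order_seq_cst);
}

// Hypothetical helper: route a CPU-initiated signal through a dedicated GPU
// stream instead of storing to the atomic from the host thread.
void signal_from_cpu(SystemAtomic* ac, uint64_t value) {
  static cudaStream_t stream = [] {
    cudaStream_t s;
    cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking);
    return s;
  }();
  signal_kernel<<<1, 1, 0, stream>>>(ac, value);
  cudaStreamSynchronize(stream);
}

Unlike this blocking sketch, the actual change above does not synchronize in place: it enqueues the work on the MLX scheduler and reuses the existing SharedEvent::signal(stream, value) overload.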
mlx/backend/cuda/fence.cpp (new file)
@@ -0,0 +1,29 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/fence.h"
+#include "mlx/backend/cuda/event.h"
+
+namespace mlx::core {
+
+struct FenceImpl {
+  uint32_t count;
+  cu::SharedEvent event;
+};
+
+Fence::Fence(Stream s) {
+  fence_ = std::shared_ptr<void>(
+      new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
+}
+
+void Fence::wait(Stream s, const array&) {
+  auto* fence = static_cast<FenceImpl*>(fence_.get());
+  fence->event.wait(fence->count);
+}
+
+void Fence::update(Stream s, const array&) {
+  auto* fence = static_cast<FenceImpl*>(fence_.get());
+  fence->count++;
+  fence->event.signal(s, fence->count);
+}
+
+} // namespace mlx::core
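A side note on the idiom in Fence::Fence above: the CUDA backend keeps its state behind a type-erased std::shared_ptr<void> with a custom deleter, so the public Fence class never has to name the CUDA-specific FenceImpl. A minimal self-contained illustration of that pattern (Impl and Handle are made-up names, not MLX types):

#include <cstdint>
#include <cstdio>
#include <memory>

// Concrete state, known only to this translation unit in the real code.
struct Impl {
  uint32_t count{0};
};

class Handle {
 public:
  // Store the impl as shared_ptr<void>; the deleter restores the concrete
  // type before deleting, so destruction stays well defined.
  Handle()
      : impl_(new Impl{}, [](void* p) { delete static_cast<Impl*>(p); }) {}

  void bump() { static_cast<Impl*>(impl_.get())->count++; }
  uint32_t value() const { return static_cast<Impl*>(impl_.get())->count; }

 private:
  std::shared_ptr<void> impl_;
};

int main() {
  Handle h;
  h.bump();
  std::printf("count = %u\n", h.value());  // prints "count = 1"
  return 0;
}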
Deleted file:
@@ -1,70 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/event.h"
-#include "mlx/fence.h"
-#include "mlx/scheduler.h"
-
-#include <nvtx3/nvtx3.hpp>
-
-namespace mlx::core {
-
-namespace {
-
-__host__ __device__ void busy_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
-  while (true) {
-    // In theory the atomic_thread_fence is not needed, but for CUDA 11 without
-    // it the load() may never return new value.
-    cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
-    uint64_t current = ac->load();
-    if (current >= value) {
-      break;
-    }
-  }
-}
-
-__global__ void busy_wait_kernel(cuda::atomic<uint64_t>* ac, uint64_t value) {
-  busy_wait(ac, value);
-}
-
-} // namespace
-
-struct FenceImpl {
-  uint32_t count;
-  cu::SharedEvent event;
-};
-
-Fence::Fence(Stream s) {
-  fence_ = std::shared_ptr<void>(
-      new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
-}
-
-void Fence::wait(Stream s, const array&) {
-  auto* fence = static_cast<FenceImpl*>(fence_.get());
-  // We can't use SharedEvent::wait because it could hang in CUDA 11, see also:
-  // https://github.com/ml-explore/mlx/issues/2137
-  const auto& ac = fence->event.atomic();
-  if (s.device == mlx::core::Device::cpu) {
-    scheduler::enqueue(s, [ac, count = fence->count]() {
-      nvtx3::scoped_range r("Fence::wait()");
-      busy_wait(ac.get(), count);
-    });
-  } else {
-    nvtx3::scoped_range r("Fence::wait(s)");
-    auto& encoder = cu::get_command_encoder(s);
-    encoder.launch_kernel(
-        encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) {
-          busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count);
-        });
-    encoder.add_completed_handler([ac]() {});
-    encoder.end_encoding();
-  }
-}
-
-void Fence::update(Stream s, const array&) {
-  auto* fence = static_cast<FenceImpl*>(fence_.get());
-  fence->count++;
-  fence->event.signal(s, fence->count);
-}
-
-} // namespace mlx::core
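For context on what was removed: the deleted wait path is a host- or device-side polling loop over the event's system-scope atomic, with a seq_cst thread fence that the deleted comment says CUDA 11 needed for the load to eventually observe the new value. A standalone sketch of that pattern (poll_until is a made-up name; building it requires the CUDA toolkit's libcu++ headers):

#include <cuda/atomic>
#include <cstdint>

// Host-side polling loop over a system-scope atomic, mirroring the removed
// busy_wait helper. The seq_cst fence is the CUDA 11 workaround mentioned in
// the deleted comment; without it, load() could keep returning a stale value.
inline void poll_until(cuda::atomic<uint64_t, cuda::thread_scope_system>* ac,
                       uint64_t value) {
  while (true) {
    cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
    if (ac->load() >= value) {
      break;
    }
  }
}

The new fence.cpp drops this custom busy-wait entirely and defers to cu::SharedEvent::wait and cu::SharedEvent::signal.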