[CUDA] Make CudaEvent work with multi-device (#2614)

* Set current device when creating cuda event

* Separate cuda events by device

* Avoid race condition in pool
This commit is contained in:
Cheng
2025-09-27 11:27:17 +09:00
committed by GitHub
parent 7a6adda1e6
commit b466dea982
7 changed files with 73 additions and 26 deletions

View File

@@ -68,8 +68,8 @@ Device::~Device() {
void Device::make_current() { void Device::make_current() {
// We need to set/get current CUDA device very frequently, cache it to reduce // We need to set/get current CUDA device very frequently, cache it to reduce
// actual calls of CUDA APIs. This function assumes single-thread in host. // actual calls of CUDA APIs.
static int current = 0; static thread_local int current = 0;
if (current != device_) { if (current != device_) {
CHECK_CUDA_ERROR(cudaSetDevice(device_)); CHECK_CUDA_ERROR(cudaSetDevice(device_));
current = device_; current = device_;
@@ -196,6 +196,7 @@ CommandEncoder::CommandEncoder(Device& d)
: device_(d), : device_(d),
stream_(d), stream_(d),
graph_(d), graph_(d),
worker_(d),
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {} graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
void CommandEncoder::add_completed_handler(std::function<void()> task) { void CommandEncoder::add_completed_handler(std::function<void()> task) {

View File

@@ -140,7 +140,7 @@ class Device {
Device(const Device&) = delete; Device(const Device&) = delete;
Device& operator=(const Device&) = delete; Device& operator=(const Device&) = delete;
// Make this device the current cuda device, required by some cuda calls. // Make this device the current cuda device, this method is thread-safe.
void make_current(); void make_current();
CommandEncoder& get_command_encoder(Stream s); CommandEncoder& get_command_encoder(Stream s);

View File

@@ -15,9 +15,10 @@ bool is_available() {
} }
void new_stream(Stream s) { void new_stream(Stream s) {
// Force initialization of CUDA by creating an event, so the CUDA runtime and // Force initialization of CUDA, so that the CUDA runtime gets destroyed last.
// our CUDA event pool get destroyed last. cudaFree(nullptr);
cu::CudaEvent(cudaEventDefault); // Make sure the CUDA event pool gets destroyed after device and stream.
cu::CudaEvent::init_pool();
// Ensure the static stream objects get created. // Ensure the static stream objects get created.
cu::get_command_encoder(s); cu::get_command_encoder(s);
} }

View File

@@ -22,11 +22,15 @@ namespace cu {
namespace { namespace {
// Manage cached cudaEvent_t objects. // Manage cached cudaEvent_t objects.
struct CudaEventPool { class CudaEventPool {
static CudaEventHandle create(int flags) { public:
auto& cache = cache_for(flags); CudaEventHandle create(Device& d, int flags) {
if (!on_creation_thread()) {
return CudaEventHandle(d, flags);
}
auto& cache = cache_for(d, flags);
if (cache.empty()) { if (cache.empty()) {
return CudaEventHandle(flags); return CudaEventHandle(d, flags);
} else { } else {
CudaEventHandle ret = std::move(cache.back()); CudaEventHandle ret = std::move(cache.back());
cache.pop_back(); cache.pop_back();
@@ -34,54 +38,89 @@ struct CudaEventPool {
} }
} }
static void release(CudaEventHandle event) { void release(CudaEventHandle event) {
cache_for(event.flags).push_back(std::move(event)); if (!on_creation_thread()) {
// Event will be destroyed directly instead of getting moved to cache.
return;
}
cache_for(event.device, event.flags).push_back(std::move(event));
} }
static std::vector<CudaEventHandle>& cache_for(int flags) { private:
static std::map<int, std::vector<CudaEventHandle>> cache; std::vector<CudaEventHandle>& cache_for(Device& d, int flags) {
return cache[flags]; return cache_[d.cuda_device()][flags];
} }
bool on_creation_thread() {
return std::this_thread::get_id() == thread_id_;
}
// A CudaEvent may be created and destroyed on different threads (for
// example when a CPU stream waits on GPU work). Making the cache thread-safe
// would add overhead, so the cache is simply bypassed on any thread other
// than the one that created the pool.
std::thread::id thread_id_{std::this_thread::get_id()};
// {device: {flags: [events]}}
std::map<int, std::map<int, std::vector<CudaEventHandle>>> cache_;
}; };
CudaEventPool& cuda_event_pool() {
static CudaEventPool pool;
return pool;
}
} // namespace } // namespace
CudaEventHandle::CudaEventHandle(int flags) : flags(flags) { CudaEventHandle::CudaEventHandle(Device& d, int flags)
: device(d), flags(flags) {
device.make_current();
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags)); CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
assert(handle_ != nullptr); assert(handle_ != nullptr);
} }
CudaEvent::CudaEvent(int flags) : event_(CudaEventPool::create(flags)) {} CudaEvent::CudaEvent(Device& d, int flags)
: event_(cuda_event_pool().create(d, flags)) {}
CudaEvent::~CudaEvent() { CudaEvent::~CudaEvent() {
CudaEventPool::release(std::move(event_)); cuda_event_pool().release(std::move(event_));
} }
void CudaEvent::wait() { void CudaEvent::wait() {
nvtx3::scoped_range r("cu::CudaEvent::wait"); nvtx3::scoped_range r("cu::CudaEvent::wait");
event_.device.make_current();
cudaEventSynchronize(event_); cudaEventSynchronize(event_);
} }
void CudaEvent::wait(cudaStream_t stream) { void CudaEvent::wait(cudaStream_t stream) {
event_.device.make_current();
cudaStreamWaitEvent(stream, event_); cudaStreamWaitEvent(stream, event_);
} }
void CudaEvent::record(cudaStream_t stream) { void CudaEvent::record(cudaStream_t stream) {
event_.device.make_current();
cudaEventRecord(event_, stream); cudaEventRecord(event_, stream);
} }
bool CudaEvent::completed() const { bool CudaEvent::completed() const {
// Note: cudaEventQuery can be safely called from any device.
return cudaEventQuery(event_) == cudaSuccess; return cudaEventQuery(event_) == cudaSuccess;
} }
// static
void CudaEvent::init_pool() {
cuda_event_pool();
}
// Wraps CudaEvent with a few features: // Wraps CudaEvent with a few features:
// 1. The class can be copied. // 1. The class can be copied.
// 2. Make wait/record work with CPU streams. // 2. Make wait/record work with CPU streams.
// 3. Add checks for waiting on un-recorded event. // 3. Add checks for waiting on un-recorded event.
class CopyableCudaEvent { class CopyableCudaEvent {
public: public:
CopyableCudaEvent() explicit CopyableCudaEvent(Device& d)
: event_(std::make_shared<CudaEvent>( : event_(std::make_shared<CudaEvent>(
d,
cudaEventDisableTiming | cudaEventBlockingSync)) {} cudaEventDisableTiming | cudaEventBlockingSync)) {}
void wait() { void wait() {
@@ -245,7 +284,7 @@ struct EventImpl {
nvtx3::mark("Using slow AtomicEvent"); nvtx3::mark("Using slow AtomicEvent");
atomic = std::make_unique<cu::AtomicEvent>(); atomic = std::make_unique<cu::AtomicEvent>();
} else { } else {
cuda = std::make_unique<cu::CopyableCudaEvent>(); cuda = std::make_unique<cu::CopyableCudaEvent>(cu::device(s.device));
} }
} }
}; };

View File

@@ -13,9 +13,12 @@
namespace mlx::core::cu { namespace mlx::core::cu {
class Device;
// RAII-managed move-only wrapper of cudaEvent_t. // RAII-managed move-only wrapper of cudaEvent_t.
struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> { struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
CudaEventHandle(int flags); CudaEventHandle(Device& d, int flags);
Device& device;
int flags; int flags;
}; };
@@ -23,7 +26,7 @@ struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
// on GPU stream in CPU stream, but can not wait on CPU stream. // on GPU stream in CPU stream, but can not wait on CPU stream.
class CudaEvent { class CudaEvent {
public: public:
explicit CudaEvent(int flags); CudaEvent(Device& d, int flags);
~CudaEvent(); ~CudaEvent();
CudaEvent(CudaEvent&&) = default; CudaEvent(CudaEvent&&) = default;
@@ -40,6 +43,9 @@ class CudaEvent {
// returns true if record() has not been called. // returns true if record() has not been called.
bool completed() const; bool completed() const;
// Internal: make sure event pool is initialized.
static void init_pool();
private: private:
CudaEventHandle event_; CudaEventHandle event_;
}; };

View File

@@ -5,9 +5,9 @@
namespace mlx::core::cu { namespace mlx::core::cu {
Worker::Worker() Worker::Worker(Device& d)
: signal_stream_(device(mlx::core::Device::gpu)), : signal_stream_(d),
signal_event_(cudaEventDisableTiming | cudaEventBlockingSync), signal_event_(d, cudaEventDisableTiming | cudaEventBlockingSync),
worker_(&Worker::thread_fn, this) {} worker_(&Worker::thread_fn, this) {}
Worker::~Worker() { Worker::~Worker() {

View File

@@ -15,7 +15,7 @@ namespace mlx::core::cu {
// Run tasks in worker thread, synchronized with cuda stream. // Run tasks in worker thread, synchronized with cuda stream.
class Worker { class Worker {
public: public:
Worker(); explicit Worker(Device& d);
~Worker(); ~Worker();
Worker(const Worker&) = delete; Worker(const Worker&) = delete;