[CUDA] Make CudaEvent work with multi-device (#2614)

* Set current device when creating cuda event

* Separate cuda events by device

* Avoid race condition in pool
Author: Cheng
Date: 2025-09-27 11:27:17 +09:00
Committed by: GitHub
Parent: 7a6adda1e6
Commit: b466dea982
7 changed files with 73 additions and 26 deletions
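
The changes below all revolve around one CUDA runtime rule: cudaEventCreateWithFlags() binds the new event to whichever device is current on the calling thread, and cudaEventRecord() requires the event and the stream to belong to that same device. A minimal standalone sketch of that rule (plain CUDA runtime API, not code from this commit):

#include <cuda_runtime.h>
#include <cstdio>

int main() {
  int n = 0;
  cudaGetDeviceCount(&n);
  for (int d = 0; d < n; ++d) {
    cudaSetDevice(d); // the event below is bound to device d
    cudaEvent_t ev;
    cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
    cudaEventRecord(ev, /* stream */ 0); // default stream of device d
    cudaEventSynchronize(ev);
    cudaEventDestroy(ev);
    std::printf("device %d: event ok\n", d);
  }
  return 0;
}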

View File

@@ -68,8 +68,8 @@ Device::~Device() {
 void Device::make_current() {
   // We need to set/get current CUDA device very frequently, cache it to reduce
-  // actual calls of CUDA APIs. This function assumes single-thread in host.
-  static int current = 0;
+  // actual calls of CUDA APIs.
+  static thread_local int current = 0;
   if (current != device_) {
     CHECK_CUDA_ERROR(cudaSetDevice(device_));
     current = device_;
@@ -196,6 +196,7 @@ CommandEncoder::CommandEncoder(Device& d)
     : device_(d),
       stream_(d),
       graph_(d),
+      worker_(d),
       graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}

 void CommandEncoder::add_completed_handler(std::function<void()> task) {
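
The switch from static to thread_local matters because the CUDA runtime tracks the current device per host thread: a process-wide static would let one thread's cudaSetDevice() call leave another thread's cached value stale. A self-contained sketch of the same caching pattern (hypothetical free function, not this class):

#include <cuda_runtime.h>

void make_current(int device) {
  // CUDA's current device defaults to 0 on every thread.
  static thread_local int current = 0;
  if (current != device) {
    cudaSetDevice(device);
    current = device;
  }
}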

View File

@@ -140,7 +140,7 @@ class Device {
   Device(const Device&) = delete;
   Device& operator=(const Device&) = delete;

-  // Make this device the current cuda device, required by some cuda calls.
+  // Make this device the current cuda device; this method is thread-safe.
   void make_current();

  CommandEncoder& get_command_encoder(Stream s);

View File

@@ -15,9 +15,10 @@ bool is_available() {
 }

 void new_stream(Stream s) {
-  // Force initialization of CUDA by creating an event, so the CUDA runtime and
-  // our CUDA event pool get destroyed last.
-  cu::CudaEvent(cudaEventDefault);
+  // Force initialization of CUDA, so the CUDA runtime is destroyed last.
+  cudaFree(nullptr);
+  // Make sure the CUDA event pool is destroyed after the device and stream.
+  cu::CudaEvent::init_pool();
   // Ensure the static stream objects get created.
   cu::get_command_encoder(s);
 }
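
The ordering trick here relies on C++ guaranteeing that objects with static storage duration are destroyed in reverse order of construction: touching the pool before the static stream objects are created means the pool outlives them. A tiny sketch of that rule (hypothetical types, not from this repo):

#include <cstdio>

struct Named {
  const char* name;
  ~Named() { std::printf("destroying %s\n", name); }
};

void init() {
  static Named pool{"event pool"}; // constructed first => destroyed last
  static Named stream{"stream"};   // constructed second => destroyed first
}

int main() {
  init();
  // At exit this prints "destroying stream", then "destroying event pool".
}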

View File

@@ -22,11 +22,15 @@ namespace cu {
 namespace {

 // Manage cached cudaEvent_t objects.
-struct CudaEventPool {
-  static CudaEventHandle create(int flags) {
-    auto& cache = cache_for(flags);
+class CudaEventPool {
+ public:
+  CudaEventHandle create(Device& d, int flags) {
+    if (!on_creation_thread()) {
+      return CudaEventHandle(d, flags);
+    }
+    auto& cache = cache_for(d, flags);
     if (cache.empty()) {
-      return CudaEventHandle(flags);
+      return CudaEventHandle(d, flags);
     } else {
       CudaEventHandle ret = std::move(cache.back());
       cache.pop_back();
@@ -34,54 +38,89 @@
     }
   }

-  static void release(CudaEventHandle event) {
-    cache_for(event.flags).push_back(std::move(event));
+  void release(CudaEventHandle event) {
+    if (!on_creation_thread()) {
+      // Event will be destroyed directly instead of getting moved to cache.
+      return;
+    }
+    cache_for(event.device, event.flags).push_back(std::move(event));
   }

-  static std::vector<CudaEventHandle>& cache_for(int flags) {
-    static std::map<int, std::vector<CudaEventHandle>> cache;
-    return cache[flags];
+ private:
+  std::vector<CudaEventHandle>& cache_for(Device& d, int flags) {
+    return cache_[d.cuda_device()][flags];
   }
+
+  bool on_creation_thread() {
+    return std::this_thread::get_id() == thread_id_;
+  }
+
+  // The CudaEvent may be created and destroyed on different threads (for
+  // example when waiting on GPU work in a CPU stream). We don't want to make
+  // the cache thread-safe as that adds overhead, so we simply skip the cache
+  // when using events in worker threads.
+  std::thread::id thread_id_{std::this_thread::get_id()};
+
+  // {device: {flags: [events]}}
+  std::map<int, std::map<int, std::vector<CudaEventHandle>>> cache_;
 };

+CudaEventPool& cuda_event_pool() {
+  static CudaEventPool pool;
+  return pool;
+}
+
 } // namespace

-CudaEventHandle::CudaEventHandle(int flags) : flags(flags) {
+CudaEventHandle::CudaEventHandle(Device& d, int flags)
+    : device(d), flags(flags) {
+  device.make_current();
   CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
   assert(handle_ != nullptr);
 }

-CudaEvent::CudaEvent(int flags) : event_(CudaEventPool::create(flags)) {}
+CudaEvent::CudaEvent(Device& d, int flags)
+    : event_(cuda_event_pool().create(d, flags)) {}

 CudaEvent::~CudaEvent() {
-  CudaEventPool::release(std::move(event_));
+  cuda_event_pool().release(std::move(event_));
 }

 void CudaEvent::wait() {
   nvtx3::scoped_range r("cu::CudaEvent::wait");
+  event_.device.make_current();
   cudaEventSynchronize(event_);
 }

 void CudaEvent::wait(cudaStream_t stream) {
+  event_.device.make_current();
   cudaStreamWaitEvent(stream, event_);
 }

 void CudaEvent::record(cudaStream_t stream) {
+  event_.device.make_current();
   cudaEventRecord(event_, stream);
 }

 bool CudaEvent::completed() const {
+  // Note: cudaEventQuery can be safely called from any device.
   return cudaEventQuery(event_) == cudaSuccess;
 }

+// static
+void CudaEvent::init_pool() {
+  cuda_event_pool();
+}
+
 // Wraps CudaEvent with a few features:
 // 1. The class can be copied.
 // 2. Make wait/record work with CPU streams.
 // 3. Add checks for waiting on un-recorded event.
 class CopyableCudaEvent {
  public:
-  CopyableCudaEvent()
+  explicit CopyableCudaEvent(Device& d)
       : event_(std::make_shared<CudaEvent>(
+            d,
             cudaEventDisableTiming | cudaEventBlockingSync)) {}

   void wait() {
@@ -245,7 +284,7 @@ struct EventImpl {
       nvtx3::mark("Using slow AtomicEvent");
       atomic = std::make_unique<cu::AtomicEvent>();
     } else {
-      cuda = std::make_unique<cu::CopyableCudaEvent>();
+      cuda = std::make_unique<cu::CopyableCudaEvent>(cu::device(s.device));
     }
   }
 };
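
The race-condition fix avoids a mutex entirely: the pool remembers the thread it was created on, only that thread touches the cache, and any other thread constructs and destroys events directly. A generic sketch of the same restrict-to-one-thread pattern (hypothetical UnsyncedPool, not the class above):

#include <thread>
#include <utility>
#include <vector>

template <typename T>
class UnsyncedPool {
 public:
  T acquire() {
    // Off the owner thread (or on a cache miss), construct directly.
    if (std::this_thread::get_id() != owner_ || cache_.empty()) {
      return T();
    }
    T ret = std::move(cache_.back());
    cache_.pop_back();
    return ret;
  }

  void release(T obj) {
    // Only the owner thread recycles; elsewhere obj is simply destroyed.
    if (std::this_thread::get_id() == owner_) {
      cache_.push_back(std::move(obj));
    }
  }

 private:
  std::thread::id owner_{std::this_thread::get_id()};
  std::vector<T> cache_;
};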

View File

@@ -13,9 +13,12 @@
 namespace mlx::core::cu {

+class Device;
+
 // RAII-managed move-only wrapper of cudaEvent_t.
 struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
-  CudaEventHandle(int flags);
+  CudaEventHandle(Device& d, int flags);
+  Device& device;
   int flags;
 };
@@ -23,7 +26,7 @@ struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
 // on GPU stream in CPU stream, but can not wait on CPU stream.
 class CudaEvent {
  public:
-  explicit CudaEvent(int flags);
+  CudaEvent(Device& d, int flags);
   ~CudaEvent();

   CudaEvent(CudaEvent&&) = default;
@@ -40,6 +43,9 @@ class CudaEvent {
   // returns true if record() has not been called.
   bool completed() const;

+  // Internal: make sure event pool is initialized.
+  static void init_pool();
+
  private:
   CudaEventHandle event_;
 };

View File

@@ -5,9 +5,9 @@
 namespace mlx::core::cu {

-Worker::Worker()
-    : signal_stream_(device(mlx::core::Device::gpu)),
-      signal_event_(cudaEventDisableTiming | cudaEventBlockingSync),
+Worker::Worker(Device& d)
+    : signal_stream_(d),
+      signal_event_(d, cudaEventDisableTiming | cudaEventBlockingSync),
       worker_(&Worker::thread_fn, this) {}

 Worker::~Worker() {

View File

@@ -15,7 +15,7 @@ namespace mlx::core::cu {
 // Run tasks in worker thread, synchronized with cuda stream.
 class Worker {
  public:
-  Worker();
+  explicit Worker(Device& d);
   ~Worker();

   Worker(const Worker&) = delete;