mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
[CUDA] Make CudaEvent work with multi-device (#2614)
* Set current device when creating cuda event
* Separate cuda events by device
* Avoid race condition in pool
This commit is contained in:
@@ -68,8 +68,8 @@ Device::~Device() {
|
||||
|
||||
void Device::make_current() {
|
||||
// We need to set/get current CUDA device very frequently, cache it to reduce
|
||||
// actual calls of CUDA APIs. This function assumes single-thread in host.
|
||||
static int current = 0;
|
||||
// actual calls of CUDA APIs.
|
||||
static thread_local int current = 0;
|
||||
if (current != device_) {
|
||||
CHECK_CUDA_ERROR(cudaSetDevice(device_));
|
||||
current = device_;
|
||||
@@ -196,6 +196,7 @@ CommandEncoder::CommandEncoder(Device& d)
|
||||
: device_(d),
|
||||
stream_(d),
|
||||
graph_(d),
|
||||
worker_(d),
|
||||
graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
|
||||
|
||||
void CommandEncoder::add_completed_handler(std::function<void()> task) {
|
||||
|
@@ -140,7 +140,7 @@ class Device {
|
||||
Device(const Device&) = delete;
|
||||
Device& operator=(const Device&) = delete;
|
||||
|
||||
// Make this device the current cuda device, required by some cuda calls.
|
||||
// Make this device the current cuda device, this method is thread-safe.
|
||||
void make_current();
|
||||
|
||||
CommandEncoder& get_command_encoder(Stream s);
|
||||
|
@@ -15,9 +15,10 @@ bool is_available() {
|
||||
}
|
||||
|
||||
void new_stream(Stream s) {
|
||||
// Force initialization of CUDA by creating an event, so the CUDA runtime and
|
||||
// our CUDA event pool get destroyed last.
|
||||
cu::CudaEvent(cudaEventDefault);
|
||||
// Force initialization of CUDA, so the CUDA runtime gets destroyed last.
|
||||
cudaFree(nullptr);
|
||||
// Make sure the CUDA event pool gets destroyed after device and stream.
|
||||
cu::CudaEvent::init_pool();
|
||||
// Ensure the static stream objects get created.
|
||||
cu::get_command_encoder(s);
|
||||
}
|
||||
|
@@ -22,11 +22,15 @@ namespace cu {
|
||||
namespace {
|
||||
|
||||
// Manage cached cudaEvent_t objects.
|
||||
struct CudaEventPool {
|
||||
static CudaEventHandle create(int flags) {
|
||||
auto& cache = cache_for(flags);
|
||||
class CudaEventPool {
|
||||
public:
|
||||
CudaEventHandle create(Device& d, int flags) {
|
||||
if (!on_creation_thread()) {
|
||||
return CudaEventHandle(d, flags);
|
||||
}
|
||||
auto& cache = cache_for(d, flags);
|
||||
if (cache.empty()) {
|
||||
return CudaEventHandle(flags);
|
||||
return CudaEventHandle(d, flags);
|
||||
} else {
|
||||
CudaEventHandle ret = std::move(cache.back());
|
||||
cache.pop_back();
|
||||
@@ -34,54 +38,89 @@ struct CudaEventPool {
|
||||
}
|
||||
}
|
||||
|
||||
static void release(CudaEventHandle event) {
|
||||
cache_for(event.flags).push_back(std::move(event));
|
||||
void release(CudaEventHandle event) {
|
||||
if (!on_creation_thread()) {
|
||||
// Event will be destroyed directly instead of getting moved to cache.
|
||||
return;
|
||||
}
|
||||
cache_for(event.device, event.flags).push_back(std::move(event));
|
||||
}
|
||||
|
||||
static std::vector<CudaEventHandle>& cache_for(int flags) {
|
||||
static std::map<int, std::vector<CudaEventHandle>> cache;
|
||||
return cache[flags];
|
||||
private:
|
||||
std::vector<CudaEventHandle>& cache_for(Device& d, int flags) {
|
||||
return cache_[d.cuda_device()][flags];
|
||||
}
|
||||
|
||||
bool on_creation_thread() {
|
||||
return std::this_thread::get_id() == thread_id_;
|
||||
}
|
||||
|
||||
// The CudaEvent may be created and destroyed on different threads (for
|
||||
// example when waiting on GPU work in CPU stream), we don't want to make
|
||||
// the cache thread-safe as it adds overhead, so we just skip cache when
|
||||
// using events in worker threads.
|
||||
std::thread::id thread_id_{std::this_thread::get_id()};
|
||||
|
||||
// {device: {flags: [events]}}
|
||||
std::map<int, std::map<int, std::vector<CudaEventHandle>>> cache_;
|
||||
};
|
||||
|
||||
CudaEventPool& cuda_event_pool() {
|
||||
static CudaEventPool pool;
|
||||
return pool;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
CudaEventHandle::CudaEventHandle(int flags) : flags(flags) {
|
||||
CudaEventHandle::CudaEventHandle(Device& d, int flags)
|
||||
: device(d), flags(flags) {
|
||||
device.make_current();
|
||||
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
|
||||
assert(handle_ != nullptr);
|
||||
}
|
||||
|
||||
CudaEvent::CudaEvent(int flags) : event_(CudaEventPool::create(flags)) {}
|
||||
CudaEvent::CudaEvent(Device& d, int flags)
|
||||
: event_(cuda_event_pool().create(d, flags)) {}
|
||||
|
||||
CudaEvent::~CudaEvent() {
|
||||
CudaEventPool::release(std::move(event_));
|
||||
cuda_event_pool().release(std::move(event_));
|
||||
}
|
||||
|
||||
void CudaEvent::wait() {
|
||||
nvtx3::scoped_range r("cu::CudaEvent::wait");
|
||||
event_.device.make_current();
|
||||
cudaEventSynchronize(event_);
|
||||
}
|
||||
|
||||
void CudaEvent::wait(cudaStream_t stream) {
|
||||
event_.device.make_current();
|
||||
cudaStreamWaitEvent(stream, event_);
|
||||
}
|
||||
|
||||
void CudaEvent::record(cudaStream_t stream) {
|
||||
event_.device.make_current();
|
||||
cudaEventRecord(event_, stream);
|
||||
}
|
||||
|
||||
bool CudaEvent::completed() const {
|
||||
// Note: cudaEventQuery can be safely called from any device.
|
||||
return cudaEventQuery(event_) == cudaSuccess;
|
||||
}
|
||||
|
||||
// static
|
||||
void CudaEvent::init_pool() {
|
||||
cuda_event_pool();
|
||||
}
|
||||
|
||||
// Wraps CudaEvent with a few features:
|
||||
// 1. The class can be copied.
|
||||
// 2. Make wait/record work with CPU streams.
|
||||
// 3. Add checks for waiting on un-recorded event.
|
||||
class CopyableCudaEvent {
|
||||
public:
|
||||
CopyableCudaEvent()
|
||||
explicit CopyableCudaEvent(Device& d)
|
||||
: event_(std::make_shared<CudaEvent>(
|
||||
d,
|
||||
cudaEventDisableTiming | cudaEventBlockingSync)) {}
|
||||
|
||||
void wait() {
|
||||
@@ -245,7 +284,7 @@ struct EventImpl {
|
||||
nvtx3::mark("Using slow AtomicEvent");
|
||||
atomic = std::make_unique<cu::AtomicEvent>();
|
||||
} else {
|
||||
cuda = std::make_unique<cu::CopyableCudaEvent>();
|
||||
cuda = std::make_unique<cu::CopyableCudaEvent>(cu::device(s.device));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@@ -13,9 +13,12 @@
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class Device;
|
||||
|
||||
// RAII-managed move-only wrapper of cudaEvent_t.
|
||||
struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
|
||||
CudaEventHandle(int flags);
|
||||
CudaEventHandle(Device& d, int flags);
|
||||
Device& device;
|
||||
int flags;
|
||||
};
|
||||
|
||||
@@ -23,7 +26,7 @@ struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
|
||||
// on GPU stream in CPU stream, but can not wait on CPU stream.
|
||||
class CudaEvent {
|
||||
public:
|
||||
explicit CudaEvent(int flags);
|
||||
CudaEvent(Device& d, int flags);
|
||||
~CudaEvent();
|
||||
|
||||
CudaEvent(CudaEvent&&) = default;
|
||||
@@ -40,6 +43,9 @@ class CudaEvent {
|
||||
// returns true if record() has not been called.
|
||||
bool completed() const;
|
||||
|
||||
// Internal: make sure event pool is initialized.
|
||||
static void init_pool();
|
||||
|
||||
private:
|
||||
CudaEventHandle event_;
|
||||
};
|
||||
|
@@ -5,9 +5,9 @@
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
Worker::Worker()
|
||||
: signal_stream_(device(mlx::core::Device::gpu)),
|
||||
signal_event_(cudaEventDisableTiming | cudaEventBlockingSync),
|
||||
Worker::Worker(Device& d)
|
||||
: signal_stream_(d),
|
||||
signal_event_(d, cudaEventDisableTiming | cudaEventBlockingSync),
|
||||
worker_(&Worker::thread_fn, this) {}
|
||||
|
||||
Worker::~Worker() {
|
||||
|
@@ -15,7 +15,7 @@ namespace mlx::core::cu {
|
||||
// Run tasks in worker thread, synchronized with cuda stream.
|
||||
class Worker {
|
||||
public:
|
||||
Worker();
|
||||
explicit Worker(Device& d);
|
||||
~Worker();
|
||||
|
||||
Worker(const Worker&) = delete;
|
||||
|
Reference in New Issue
Block a user