[CUDA] Make CudaEvent work with multi-device (#2614)
* Set current device when creating cuda event
* Separate cuda events by device
* Avoid race condition in pool
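Before the diff, here is a minimal standalone sketch (not the MLX code; `EventPool` and `PooledEvent` are hypothetical names) of the pattern the three bullets describe: events are cached per (device, flags), the owning device is made current before an event is created, and the cache is bypassed on any thread other than the one that created the pool, so the cache needs no lock.

```cpp
// Illustrative sketch only, assuming the CUDA runtime is available.
#include <cuda_runtime.h>

#include <map>
#include <thread>
#include <vector>

struct PooledEvent {
  cudaEvent_t event{nullptr};
  int device{0};
  int flags{0};
};

class EventPool {
 public:
  PooledEvent create(int device, int flags) {
    if (std::this_thread::get_id() == creation_thread_) {
      auto& bucket = cache_[device][flags];
      if (!bucket.empty()) {
        PooledEvent e = bucket.back();
        bucket.pop_back();
        return e;
      }
    }
    // Cache miss, or called from another thread: create a fresh event with the
    // owning device current, so the event is associated with that device.
    cudaSetDevice(device);
    PooledEvent e{nullptr, device, flags};
    cudaEventCreateWithFlags(&e.event, flags);
    return e;
  }

  void release(PooledEvent e) {
    if (std::this_thread::get_id() != creation_thread_) {
      // Off-thread release: destroy the event instead of touching the
      // unsynchronized cache, trading a little reuse for no race.
      cudaEventDestroy(e.event);
      return;
    }
    cache_[e.device][e.flags].push_back(e);
  }

 private:
  std::thread::id creation_thread_{std::this_thread::get_id()};
  // {device: {flags: [events]}}, mirroring the layout used in the diff below.
  std::map<int, std::map<int, std::vector<PooledEvent>>> cache_;
};
```

The actual change below additionally ties each event handle to a `Device&` so wait/record can switch to the right device, and exposes `CudaEvent::init_pool()` so the pool outlives devices and streams.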
@@ -68,8 +68,8 @@ Device::~Device() {

 void Device::make_current() {
   // We need to set/get current CUDA device very frequently, cache it to reduce
-  // actual calls of CUDA APIs. This function assumes single-thread in host.
-  static int current = 0;
+  // actual calls of CUDA APIs.
+  static thread_local int current = 0;
   if (current != device_) {
     CHECK_CUDA_ERROR(cudaSetDevice(device_));
     current = device_;
@@ -196,6 +196,7 @@ CommandEncoder::CommandEncoder(Device& d)
     : device_(d),
       stream_(d),
       graph_(d),
+      worker_(d),
       graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}

 void CommandEncoder::add_completed_handler(std::function<void()> task) {
@@ -140,7 +140,7 @@ class Device {
   Device(const Device&) = delete;
   Device& operator=(const Device&) = delete;

-  // Make this device the current cuda device, required by some cuda calls.
+  // Make this device the current cuda device, this method is thread-safe.
   void make_current();

   CommandEncoder& get_command_encoder(Stream s);
@@ -15,9 +15,10 @@ bool is_available() {
 }

 void new_stream(Stream s) {
-  // Force initalization of CUDA by creating an event, so the CUDA runtime and
-  // our CUDA event pool get destroyed last.
-  cu::CudaEvent(cudaEventDefault);
+  // Force initalization of CUDA, so CUDA runtime get destroyed at last.
+  cudaFree(nullptr);
+  // Make sure CUDA event pool get destroyed after device and stream.
+  cu::CudaEvent::init_pool();
   // Ensure the static stream objects get created.
   cu::get_command_encoder(s);
 }
@@ -22,11 +22,15 @@ namespace cu {
 namespace {

 // Manage cached cudaEvent_t objects.
-struct CudaEventPool {
-  static CudaEventHandle create(int flags) {
-    auto& cache = cache_for(flags);
+class CudaEventPool {
+ public:
+  CudaEventHandle create(Device& d, int flags) {
+    if (!on_creation_thread()) {
+      return CudaEventHandle(d, flags);
+    }
+    auto& cache = cache_for(d, flags);
     if (cache.empty()) {
-      return CudaEventHandle(flags);
+      return CudaEventHandle(d, flags);
     } else {
       CudaEventHandle ret = std::move(cache.back());
       cache.pop_back();
@@ -34,54 +38,89 @@ struct CudaEventPool {
     }
   }

-  static void release(CudaEventHandle event) {
-    cache_for(event.flags).push_back(std::move(event));
+  void release(CudaEventHandle event) {
+    if (!on_creation_thread()) {
+      // Event will be destroyed directly instead of getting moved to cache.
+      return;
+    }
+    cache_for(event.device, event.flags).push_back(std::move(event));
   }

-  static std::vector<CudaEventHandle>& cache_for(int flags) {
-    static std::map<int, std::vector<CudaEventHandle>> cache;
-    return cache[flags];
+ private:
+  std::vector<CudaEventHandle>& cache_for(Device& d, int flags) {
+    return cache_[d.cuda_device()][flags];
   }
+
+  bool on_creation_thread() {
+    return std::this_thread::get_id() == thread_id_;
+  }
+
+  // The CudaEvent may be created and destroyed on different threads (for
+  // example when waiting on GPU work in CPU stream), we don't want to make
+  // the cache thread-safe as it adds overhead, so we just skip cache when
+  // using events in worker threads.
+  std::thread::id thread_id_{std::this_thread::get_id()};
+
+  // {device: {flags: [events]}}
+  std::map<int, std::map<int, std::vector<CudaEventHandle>>> cache_;
 };

+CudaEventPool& cuda_event_pool() {
+  static CudaEventPool pool;
+  return pool;
+}
+
 } // namespace

-CudaEventHandle::CudaEventHandle(int flags) : flags(flags) {
+CudaEventHandle::CudaEventHandle(Device& d, int flags)
+    : device(d), flags(flags) {
+  device.make_current();
   CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
   assert(handle_ != nullptr);
 }

-CudaEvent::CudaEvent(int flags) : event_(CudaEventPool::create(flags)) {}
+CudaEvent::CudaEvent(Device& d, int flags)
+    : event_(cuda_event_pool().create(d, flags)) {}

 CudaEvent::~CudaEvent() {
-  CudaEventPool::release(std::move(event_));
+  cuda_event_pool().release(std::move(event_));
 }

 void CudaEvent::wait() {
   nvtx3::scoped_range r("cu::CudaEvent::wait");
+  event_.device.make_current();
   cudaEventSynchronize(event_);
 }

 void CudaEvent::wait(cudaStream_t stream) {
+  event_.device.make_current();
   cudaStreamWaitEvent(stream, event_);
 }

 void CudaEvent::record(cudaStream_t stream) {
+  event_.device.make_current();
   cudaEventRecord(event_, stream);
 }

 bool CudaEvent::completed() const {
+  // Note: cudaEventQuery can be safely called from any device.
   return cudaEventQuery(event_) == cudaSuccess;
 }

+// static
+void CudaEvent::init_pool() {
+  cuda_event_pool();
+}
+
 // Wraps CudaEvent with a few features:
 // 1. The class can be copied.
 // 2. Make wait/record work with CPU streams.
 // 3. Add checks for waiting on un-recorded event.
 class CopyableCudaEvent {
  public:
-  CopyableCudaEvent()
+  explicit CopyableCudaEvent(Device& d)
       : event_(std::make_shared<CudaEvent>(
+            d,
             cudaEventDisableTiming | cudaEventBlockingSync)) {}

   void wait() {
@@ -245,7 +284,7 @@ struct EventImpl {
       nvtx3::mark("Using slow AtomicEvent");
       atomic = std::make_unique<cu::AtomicEvent>();
     } else {
-      cuda = std::make_unique<cu::CopyableCudaEvent>();
+      cuda = std::make_unique<cu::CopyableCudaEvent>(cu::device(s.device));
     }
   }
 };
@@ -13,9 +13,12 @@

 namespace mlx::core::cu {

+class Device;
+
 // RAII-managed move-only wrapper of cudaEvent_t.
 struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
-  CudaEventHandle(int flags);
+  CudaEventHandle(Device& d, int flags);
+  Device& device;
   int flags;
 };

@@ -23,7 +26,7 @@ struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
 // on GPU stream in CPU stream, but can not wait on CPU stream.
 class CudaEvent {
  public:
-  explicit CudaEvent(int flags);
+  CudaEvent(Device& d, int flags);
   ~CudaEvent();

   CudaEvent(CudaEvent&&) = default;
@@ -40,6 +43,9 @@ class CudaEvent {
   // returns true if record() has not been called.
   bool completed() const;

+  // Internal: make sure event pool is initialized.
+  static void init_pool();
+
  private:
   CudaEventHandle event_;
 };
@@ -5,9 +5,9 @@

 namespace mlx::core::cu {

-Worker::Worker()
-    : signal_stream_(device(mlx::core::Device::gpu)),
-      signal_event_(cudaEventDisableTiming | cudaEventBlockingSync),
+Worker::Worker(Device& d)
+    : signal_stream_(d),
+      signal_event_(d, cudaEventDisableTiming | cudaEventBlockingSync),
       worker_(&Worker::thread_fn, this) {}

 Worker::~Worker() {
@@ -15,7 +15,7 @@ namespace mlx::core::cu {
 // Run tasks in worker thread, synchronized with cuda stream.
 class Worker {
  public:
-  Worker();
+  explicit Worker(Device& d);
   ~Worker();

   Worker(const Worker&) = delete;