|
|
|
@@ -3,10 +3,12 @@
|
|
|
|
|
#include "mlx/backend/cuda/allocator.h"
|
|
|
|
|
#include "mlx/backend/cuda/device.h"
|
|
|
|
|
#include "mlx/backend/cuda/event.h"
|
|
|
|
|
#include "mlx/backend/cuda/utils.h"
|
|
|
|
|
#include "mlx/event.h"
|
|
|
|
|
#include "mlx/scheduler.h"
|
|
|
|
|
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include <nvtx3/nvtx3.hpp>
|
|
|
|
|
|
|
|
|
|
namespace mlx::core {
|
|
|
|
@@ -17,104 +19,141 @@ namespace cu {
|
|
|
|
|
// CudaEvent implementations
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
// Cuda event managed with RAII.
|
|
|
|
|
class CudaEventHandle {
|
|
|
|
|
public:
|
|
|
|
|
CudaEventHandle() {
|
|
|
|
|
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
|
|
|
|
|
&event_, cudaEventDisableTiming | cudaEventBlockingSync));
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
|
|
// Manage cached cudaEvent_t objects.
|
|
|
|
|
struct CudaEventPool {
|
|
|
|
|
static CudaEventHandle create(int flags) {
|
|
|
|
|
auto& cache = cache_for(flags);
|
|
|
|
|
if (cache.empty()) {
|
|
|
|
|
return CudaEventHandle(flags);
|
|
|
|
|
} else {
|
|
|
|
|
CudaEventHandle ret = std::move(cache.back());
|
|
|
|
|
cache.pop_back();
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~CudaEventHandle() {
|
|
|
|
|
CHECK_CUDA_ERROR(cudaEventDestroy(event_));
|
|
|
|
|
static void release(CudaEventHandle event) {
|
|
|
|
|
cache_for(event.flags).push_back(std::move(event));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CudaEventHandle(const CudaEventHandle&) = delete;
|
|
|
|
|
CudaEventHandle& operator=(const CudaEventHandle&) = delete;
|
|
|
|
|
|
|
|
|
|
operator cudaEvent_t() const {
|
|
|
|
|
return event_;
|
|
|
|
|
static std::vector<CudaEventHandle>& cache_for(int flags) {
|
|
|
|
|
static std::map<int, std::vector<CudaEventHandle>> cache;
|
|
|
|
|
return cache[flags];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
cudaEvent_t event_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
CudaEventHandle::CudaEventHandle(int flags) : flags(flags) {
|
|
|
|
|
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
|
|
|
|
|
assert(handle_ != nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CudaEvent::CudaEvent(int flags) : event_(CudaEventPool::create(flags)) {}
|
|
|
|
|
|
|
|
|
|
CudaEvent::~CudaEvent() {
|
|
|
|
|
CudaEventPool::release(std::move(event_));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CudaEvent::wait() {
|
|
|
|
|
nvtx3::scoped_range r("cu::CudaEvent::wait");
|
|
|
|
|
if (!recorded_) {
|
|
|
|
|
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
|
|
|
|
}
|
|
|
|
|
cudaEventSynchronize(*event_);
|
|
|
|
|
cudaEventSynchronize(event_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CudaEvent::wait(cudaStream_t stream) {
|
|
|
|
|
if (!recorded_) {
|
|
|
|
|
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
|
|
|
|
}
|
|
|
|
|
cudaStreamWaitEvent(stream, *event_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CudaEvent::wait(Stream s) {
|
|
|
|
|
if (s.device == mlx::core::Device::cpu) {
|
|
|
|
|
scheduler::enqueue(s, [*this]() mutable { wait(); });
|
|
|
|
|
} else {
|
|
|
|
|
auto& enc = cu::get_command_encoder(s);
|
|
|
|
|
enc.commit();
|
|
|
|
|
wait(enc.stream());
|
|
|
|
|
}
|
|
|
|
|
cudaStreamWaitEvent(stream, event_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CudaEvent::record(cudaStream_t stream) {
|
|
|
|
|
cudaEventRecord(*event_, stream);
|
|
|
|
|
recorded_ = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CudaEvent::record(Stream s) {
|
|
|
|
|
if (s.device == mlx::core::Device::cpu) {
|
|
|
|
|
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
|
|
|
|
|
} else {
|
|
|
|
|
auto& enc = cu::get_command_encoder(s);
|
|
|
|
|
enc.commit();
|
|
|
|
|
record(enc.stream());
|
|
|
|
|
}
|
|
|
|
|
cudaEventRecord(event_, stream);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CudaEvent::completed() const {
|
|
|
|
|
return cudaEventQuery(*event_) == cudaSuccess;
|
|
|
|
|
return cudaEventQuery(event_) == cudaSuccess;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Wraps CudaEvent with a few features:
|
|
|
|
|
// 1. The class can be copied.
|
|
|
|
|
// 2. Make wait/record work with CPU streams.
|
|
|
|
|
// 3. Add checks for waiting on un-recorded event.
|
|
|
|
|
class CopyableCudaEvent {
|
|
|
|
|
public:
|
|
|
|
|
CopyableCudaEvent()
|
|
|
|
|
: event_(std::make_shared<CudaEvent>(
|
|
|
|
|
cudaEventDisableTiming | cudaEventBlockingSync)) {}
|
|
|
|
|
|
|
|
|
|
void wait() {
|
|
|
|
|
event_->wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void wait(Stream s) {
|
|
|
|
|
if (s.device == mlx::core::Device::cpu) {
|
|
|
|
|
scheduler::enqueue(s, [*this]() mutable {
|
|
|
|
|
check_recorded();
|
|
|
|
|
event_->wait();
|
|
|
|
|
});
|
|
|
|
|
} else {
|
|
|
|
|
check_recorded();
|
|
|
|
|
auto& encoder = cu::get_command_encoder(s);
|
|
|
|
|
encoder.commit();
|
|
|
|
|
event_->wait(encoder.stream());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void record(Stream s) {
|
|
|
|
|
if (s.device == mlx::core::Device::cpu) {
|
|
|
|
|
throw std::runtime_error("CudaEvent can not wait on CPU stream.");
|
|
|
|
|
} else {
|
|
|
|
|
auto& encoder = cu::get_command_encoder(s);
|
|
|
|
|
encoder.commit();
|
|
|
|
|
event_->record(encoder.stream());
|
|
|
|
|
recorded_ = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool is_signaled() const {
|
|
|
|
|
return recorded_ && event_->completed();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void check_recorded() const {
|
|
|
|
|
if (!recorded_) {
|
|
|
|
|
throw std::runtime_error(
|
|
|
|
|
"Should not wait on a CudaEvent before recording.");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<CudaEvent> event_;
|
|
|
|
|
bool recorded_{false};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
// SharedEvent implementations
|
|
|
|
|
// AtomicEvent implementations
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
|
|
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
|
|
|
|
|
__host__ __device__ void event_wait(AtomicEvent::Atomic* ac, uint64_t value) {
|
|
|
|
|
uint64_t current;
|
|
|
|
|
while ((current = ac->load()) < value) {
|
|
|
|
|
ac->wait(current);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
|
|
|
|
|
__host__ __device__ void event_signal(AtomicEvent::Atomic* ac, uint64_t value) {
|
|
|
|
|
ac->store(value);
|
|
|
|
|
ac->notify_all();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
|
|
|
|
__global__ void event_wait_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
|
|
|
|
|
event_wait(ac, value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
|
|
|
|
__global__ void event_signal_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
|
|
|
|
|
event_signal(ac, value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
|
|
|
|
|
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SharedEvent::SharedEvent() {
|
|
|
|
|
AtomicEvent::AtomicEvent() {
|
|
|
|
|
buf_ = std::shared_ptr<Buffer>(
|
|
|
|
|
new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
|
|
|
|
|
allocator().free(*ptr);
|
|
|
|
@@ -123,17 +162,17 @@ SharedEvent::SharedEvent() {
|
|
|
|
|
*static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SharedEvent::wait(uint64_t value) {
|
|
|
|
|
nvtx3::scoped_range r("cu::SharedEvent::wait");
|
|
|
|
|
event_wait(to_atomic(buf_), value);
|
|
|
|
|
void AtomicEvent::wait(uint64_t value) {
|
|
|
|
|
nvtx3::scoped_range r("cu::AtomicEvent::wait");
|
|
|
|
|
event_wait(atomic(), value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
|
|
|
|
|
event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
|
|
|
|
|
void AtomicEvent::wait(cudaStream_t stream, uint64_t value) {
|
|
|
|
|
event_wait_kernel<<<1, 1, 0, stream>>>(atomic(), value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SharedEvent::wait(Stream s, uint64_t value) {
|
|
|
|
|
nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
|
|
|
|
|
void AtomicEvent::wait(Stream s, uint64_t value) {
|
|
|
|
|
nvtx3::scoped_range r("cu::AtomicEvent::wait(s)");
|
|
|
|
|
if (s.device == mlx::core::Device::cpu) {
|
|
|
|
|
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
|
|
|
|
|
} else {
|
|
|
|
@@ -144,17 +183,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SharedEvent::signal(uint64_t value) {
|
|
|
|
|
nvtx3::scoped_range r("cu::SharedEvent::signal");
|
|
|
|
|
event_signal(to_atomic(buf_), value);
|
|
|
|
|
void AtomicEvent::signal(uint64_t value) {
|
|
|
|
|
nvtx3::scoped_range r("cu::AtomicEvent::signal");
|
|
|
|
|
event_signal(atomic(), value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
|
|
|
|
event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
|
|
|
|
|
void AtomicEvent::signal(cudaStream_t stream, uint64_t value) {
|
|
|
|
|
event_signal_kernel<<<1, 1, 0, stream>>>(atomic(), value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SharedEvent::signal(Stream s, uint64_t value) {
|
|
|
|
|
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
|
|
|
|
|
void AtomicEvent::signal(Stream s, uint64_t value) {
|
|
|
|
|
nvtx3::scoped_range r("cu::AtomicEvent::signal(s)");
|
|
|
|
|
if (s.device == mlx::core::Device::cpu) {
|
|
|
|
|
// Signal through a GPU stream so the atomic is updated in GPU - updating
|
|
|
|
|
// the atomic in CPU sometimes does not get GPU notified.
|
|
|
|
@@ -168,14 +207,14 @@ void SharedEvent::signal(Stream s, uint64_t value) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool SharedEvent::is_signaled(uint64_t value) const {
|
|
|
|
|
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
|
|
|
|
|
return to_atomic(buf_)->load() >= value;
|
|
|
|
|
bool AtomicEvent::is_signaled(uint64_t value) const {
|
|
|
|
|
nvtx3::scoped_range r("cu::AtomicEvent::is_signaled");
|
|
|
|
|
return atomic()->load() >= value;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint64_t SharedEvent::value() const {
|
|
|
|
|
nvtx3::scoped_range r("cu::SharedEvent::value");
|
|
|
|
|
return to_atomic(buf_)->load();
|
|
|
|
|
uint64_t AtomicEvent::value() const {
|
|
|
|
|
nvtx3::scoped_range r("cu::AtomicEvent::value");
|
|
|
|
|
return atomic()->load();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace cu
|
|
|
|
@@ -188,14 +227,14 @@ namespace {
|
|
|
|
|
|
|
|
|
|
struct EventImpl {
|
|
|
|
|
// CudaEvent is preferred when possible because it is fast, however we have
|
|
|
|
|
// to fallback to SharedEvent in following cases:
|
|
|
|
|
// to fallback to AtomicEvent in following cases:
|
|
|
|
|
// 1. the event is used to wait/signal a cpu stream;
|
|
|
|
|
// 2. signal value other than 1 has been specified.
|
|
|
|
|
std::unique_ptr<cu::CudaEvent> cuda;
|
|
|
|
|
std::unique_ptr<cu::SharedEvent> shared;
|
|
|
|
|
std::unique_ptr<cu::CopyableCudaEvent> cuda;
|
|
|
|
|
std::unique_ptr<cu::AtomicEvent> atomic;
|
|
|
|
|
|
|
|
|
|
bool is_created() const {
|
|
|
|
|
return cuda || shared;
|
|
|
|
|
return cuda || atomic;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ensure_created(Stream s, uint64_t signal_value) {
|
|
|
|
@@ -203,10 +242,10 @@ struct EventImpl {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
|
|
|
|
|
nvtx3::mark("Using slow SharedEvent");
|
|
|
|
|
shared = std::make_unique<cu::SharedEvent>();
|
|
|
|
|
nvtx3::mark("Using slow AtomicEvent");
|
|
|
|
|
atomic = std::make_unique<cu::AtomicEvent>();
|
|
|
|
|
} else {
|
|
|
|
|
cuda = std::make_unique<cu::CudaEvent>();
|
|
|
|
|
cuda = std::make_unique<cu::CopyableCudaEvent>();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@@ -225,7 +264,7 @@ void Event::wait() {
|
|
|
|
|
assert(value() == 1);
|
|
|
|
|
event->cuda->wait();
|
|
|
|
|
} else {
|
|
|
|
|
event->shared->wait(value());
|
|
|
|
|
event->atomic->wait(value());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -236,7 +275,7 @@ void Event::wait(Stream s) {
|
|
|
|
|
assert(value() == 1);
|
|
|
|
|
event->cuda->wait(s);
|
|
|
|
|
} else {
|
|
|
|
|
event->shared->wait(s, value());
|
|
|
|
|
event->atomic->wait(s, value());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -247,7 +286,7 @@ void Event::signal(Stream s) {
|
|
|
|
|
assert(value() == 1);
|
|
|
|
|
event->cuda->record(s);
|
|
|
|
|
} else {
|
|
|
|
|
event->shared->signal(s, value());
|
|
|
|
|
event->atomic->signal(s, value());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -258,9 +297,9 @@ bool Event::is_signaled() const {
|
|
|
|
|
}
|
|
|
|
|
if (event->cuda) {
|
|
|
|
|
assert(value() == 1);
|
|
|
|
|
return event->cuda->recorded() && event->cuda->completed();
|
|
|
|
|
return event->cuda->is_signaled();
|
|
|
|
|
} else {
|
|
|
|
|
return event->shared->is_signaled(value());
|
|
|
|
|
return event->atomic->is_signaled(value());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|