mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-27 00:08:09 +08:00
[CUDA] Recycle CUDA events (#2604)
* Make CudaEvent a CudaHandle
* Add caching for CudaEvent
* Make sure CUDA events are destroyed last
* Fix headers
* SharedEvent => AtomicEvent
* RawCudaEvent => CudaEventHandle, CudaEventWrapper => CopyableCudaEvent
* Remove unneeded asserts
This commit is contained in:
@@ -15,8 +15,9 @@ bool is_available() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void new_stream(Stream s) {
|
void new_stream(Stream s) {
|
||||||
// Force initalization of cuda, so cuda runtime get destroyed at last.
|
// Force initialization of CUDA by creating an event, so the CUDA runtime and
|
||||||
cudaFree(nullptr);
|
// our CUDA event pool get destroyed last.
|
||||||
|
cu::CudaEvent(cudaEventDefault);
|
||||||
// Ensure the static stream objects get created.
|
// Ensure the static stream objects get created.
|
||||||
cu::get_command_encoder(s);
|
cu::get_command_encoder(s);
|
||||||
}
|
}
|
||||||
|
@@ -3,10 +3,12 @@
|
|||||||
#include "mlx/backend/cuda/allocator.h"
|
#include "mlx/backend/cuda/allocator.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/event.h"
|
#include "mlx/backend/cuda/event.h"
|
||||||
#include "mlx/backend/cuda/utils.h"
|
|
||||||
#include "mlx/event.h"
|
#include "mlx/event.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include <nvtx3/nvtx3.hpp>
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@@ -17,104 +19,141 @@ namespace cu {
|
|||||||
// CudaEvent implementations
|
// CudaEvent implementations
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
// Cuda event managed with RAII.
|
namespace {
|
||||||
class CudaEventHandle {
|
|
||||||
public:
|
// Manage cached cudaEvent_t objects.
|
||||||
CudaEventHandle() {
|
struct CudaEventPool {
|
||||||
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
|
static CudaEventHandle create(int flags) {
|
||||||
&event_, cudaEventDisableTiming | cudaEventBlockingSync));
|
auto& cache = cache_for(flags);
|
||||||
|
if (cache.empty()) {
|
||||||
|
return CudaEventHandle(flags);
|
||||||
|
} else {
|
||||||
|
CudaEventHandle ret = std::move(cache.back());
|
||||||
|
cache.pop_back();
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
~CudaEventHandle() {
|
static void release(CudaEventHandle event) {
|
||||||
CHECK_CUDA_ERROR(cudaEventDestroy(event_));
|
cache_for(event.flags).push_back(std::move(event));
|
||||||
}
|
}
|
||||||
|
|
||||||
CudaEventHandle(const CudaEventHandle&) = delete;
|
static std::vector<CudaEventHandle>& cache_for(int flags) {
|
||||||
CudaEventHandle& operator=(const CudaEventHandle&) = delete;
|
static std::map<int, std::vector<CudaEventHandle>> cache;
|
||||||
|
return cache[flags];
|
||||||
operator cudaEvent_t() const {
|
|
||||||
return event_;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
cudaEvent_t event_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
|
} // namespace
|
||||||
|
|
||||||
|
CudaEventHandle::CudaEventHandle(int flags) : flags(flags) {
|
||||||
|
CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
|
||||||
|
assert(handle_ != nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
CudaEvent::CudaEvent(int flags) : event_(CudaEventPool::create(flags)) {}
|
||||||
|
|
||||||
|
CudaEvent::~CudaEvent() {
|
||||||
|
CudaEventPool::release(std::move(event_));
|
||||||
|
}
|
||||||
|
|
||||||
void CudaEvent::wait() {
|
void CudaEvent::wait() {
|
||||||
nvtx3::scoped_range r("cu::CudaEvent::wait");
|
nvtx3::scoped_range r("cu::CudaEvent::wait");
|
||||||
if (!recorded_) {
|
cudaEventSynchronize(event_);
|
||||||
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
|
||||||
}
|
|
||||||
cudaEventSynchronize(*event_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void CudaEvent::wait(cudaStream_t stream) {
|
void CudaEvent::wait(cudaStream_t stream) {
|
||||||
if (!recorded_) {
|
cudaStreamWaitEvent(stream, event_);
|
||||||
throw std::runtime_error("Should not wait on a CudaEvent before record.");
|
|
||||||
}
|
|
||||||
cudaStreamWaitEvent(stream, *event_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CudaEvent::wait(Stream s) {
|
|
||||||
if (s.device == mlx::core::Device::cpu) {
|
|
||||||
scheduler::enqueue(s, [*this]() mutable { wait(); });
|
|
||||||
} else {
|
|
||||||
auto& enc = cu::get_command_encoder(s);
|
|
||||||
enc.commit();
|
|
||||||
wait(enc.stream());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void CudaEvent::record(cudaStream_t stream) {
|
void CudaEvent::record(cudaStream_t stream) {
|
||||||
cudaEventRecord(*event_, stream);
|
cudaEventRecord(event_, stream);
|
||||||
recorded_ = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void CudaEvent::record(Stream s) {
|
|
||||||
if (s.device == mlx::core::Device::cpu) {
|
|
||||||
throw std::runtime_error("CudaEvent can not wait on cpu stream.");
|
|
||||||
} else {
|
|
||||||
auto& enc = cu::get_command_encoder(s);
|
|
||||||
enc.commit();
|
|
||||||
record(enc.stream());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CudaEvent::completed() const {
|
bool CudaEvent::completed() const {
|
||||||
return cudaEventQuery(*event_) == cudaSuccess;
|
return cudaEventQuery(event_) == cudaSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wraps CudaEvent with a few features:
|
||||||
|
// 1. The class can be copied.
|
||||||
|
// 2. Make wait/record work with CPU streams.
|
||||||
|
// 3. Add checks for waiting on un-recorded event.
|
||||||
|
class CopyableCudaEvent {
|
||||||
|
public:
|
||||||
|
CopyableCudaEvent()
|
||||||
|
: event_(std::make_shared<CudaEvent>(
|
||||||
|
cudaEventDisableTiming | cudaEventBlockingSync)) {}
|
||||||
|
|
||||||
|
void wait() {
|
||||||
|
event_->wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
void wait(Stream s) {
|
||||||
|
if (s.device == mlx::core::Device::cpu) {
|
||||||
|
scheduler::enqueue(s, [*this]() mutable {
|
||||||
|
check_recorded();
|
||||||
|
event_->wait();
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
check_recorded();
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.commit();
|
||||||
|
event_->wait(encoder.stream());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void record(Stream s) {
|
||||||
|
if (s.device == mlx::core::Device::cpu) {
|
||||||
|
throw std::runtime_error("CudaEvent can not wait on CPU stream.");
|
||||||
|
} else {
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
encoder.commit();
|
||||||
|
event_->record(encoder.stream());
|
||||||
|
recorded_ = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_signaled() const {
|
||||||
|
return recorded_ && event_->completed();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void check_recorded() const {
|
||||||
|
if (!recorded_) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Should not wait on a CudaEvent before recording.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<CudaEvent> event_;
|
||||||
|
bool recorded_{false};
|
||||||
|
};
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// SharedEvent implementations
|
// AtomicEvent implementations
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
|
__host__ __device__ void event_wait(AtomicEvent::Atomic* ac, uint64_t value) {
|
||||||
uint64_t current;
|
uint64_t current;
|
||||||
while ((current = ac->load()) < value) {
|
while ((current = ac->load()) < value) {
|
||||||
ac->wait(current);
|
ac->wait(current);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
|
__host__ __device__ void event_signal(AtomicEvent::Atomic* ac, uint64_t value) {
|
||||||
ac->store(value);
|
ac->store(value);
|
||||||
ac->notify_all();
|
ac->notify_all();
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
__global__ void event_wait_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
|
||||||
event_wait(ac, value);
|
event_wait(ac, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
__global__ void event_signal_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
|
||||||
event_signal(ac, value);
|
event_signal(ac, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
|
AtomicEvent::AtomicEvent() {
|
||||||
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
|
|
||||||
}
|
|
||||||
|
|
||||||
SharedEvent::SharedEvent() {
|
|
||||||
buf_ = std::shared_ptr<Buffer>(
|
buf_ = std::shared_ptr<Buffer>(
|
||||||
new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
|
new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
|
||||||
allocator().free(*ptr);
|
allocator().free(*ptr);
|
||||||
@@ -123,17 +162,17 @@ SharedEvent::SharedEvent() {
|
|||||||
*static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
|
*static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::wait(uint64_t value) {
|
void AtomicEvent::wait(uint64_t value) {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::wait");
|
nvtx3::scoped_range r("cu::AtomicEvent::wait");
|
||||||
event_wait(to_atomic(buf_), value);
|
event_wait(atomic(), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
|
void AtomicEvent::wait(cudaStream_t stream, uint64_t value) {
|
||||||
event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
|
event_wait_kernel<<<1, 1, 0, stream>>>(atomic(), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::wait(Stream s, uint64_t value) {
|
void AtomicEvent::wait(Stream s, uint64_t value) {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
|
nvtx3::scoped_range r("cu::AtomicEvent::wait(s)");
|
||||||
if (s.device == mlx::core::Device::cpu) {
|
if (s.device == mlx::core::Device::cpu) {
|
||||||
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
|
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
|
||||||
} else {
|
} else {
|
||||||
@@ -144,17 +183,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::signal(uint64_t value) {
|
void AtomicEvent::signal(uint64_t value) {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::signal");
|
nvtx3::scoped_range r("cu::AtomicEvent::signal");
|
||||||
event_signal(to_atomic(buf_), value);
|
event_signal(atomic(), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
void AtomicEvent::signal(cudaStream_t stream, uint64_t value) {
|
||||||
event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
|
event_signal_kernel<<<1, 1, 0, stream>>>(atomic(), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::signal(Stream s, uint64_t value) {
|
void AtomicEvent::signal(Stream s, uint64_t value) {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
|
nvtx3::scoped_range r("cu::AtomicEvent::signal(s)");
|
||||||
if (s.device == mlx::core::Device::cpu) {
|
if (s.device == mlx::core::Device::cpu) {
|
||||||
// Signal through a GPU stream so the atomic is updated in GPU - updating
|
// Signal through a GPU stream so the atomic is updated in GPU - updating
|
||||||
// the atomic in CPU sometimes does not get GPU notified.
|
// the atomic in CPU sometimes does not get GPU notified.
|
||||||
@@ -168,14 +207,14 @@ void SharedEvent::signal(Stream s, uint64_t value) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SharedEvent::is_signaled(uint64_t value) const {
|
bool AtomicEvent::is_signaled(uint64_t value) const {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
|
nvtx3::scoped_range r("cu::AtomicEvent::is_signaled");
|
||||||
return to_atomic(buf_)->load() >= value;
|
return atomic()->load() >= value;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t SharedEvent::value() const {
|
uint64_t AtomicEvent::value() const {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::value");
|
nvtx3::scoped_range r("cu::AtomicEvent::value");
|
||||||
return to_atomic(buf_)->load();
|
return atomic()->load();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cu
|
} // namespace cu
|
||||||
@@ -188,14 +227,14 @@ namespace {
|
|||||||
|
|
||||||
struct EventImpl {
|
struct EventImpl {
|
||||||
// CudaEvent is preferred when possible because it is fast, however we have
|
// CudaEvent is preferred when possible because it is fast, however we have
|
||||||
// to fallback to SharedEvent in following cases:
|
// to fallback to AtomicEvent in following cases:
|
||||||
// 1. the event is used to wait/signal a cpu stream;
|
// 1. the event is used to wait/signal a cpu stream;
|
||||||
// 2. signal value other than 1 has been specified.
|
// 2. signal value other than 1 has been specified.
|
||||||
std::unique_ptr<cu::CudaEvent> cuda;
|
std::unique_ptr<cu::CopyableCudaEvent> cuda;
|
||||||
std::unique_ptr<cu::SharedEvent> shared;
|
std::unique_ptr<cu::AtomicEvent> atomic;
|
||||||
|
|
||||||
bool is_created() const {
|
bool is_created() const {
|
||||||
return cuda || shared;
|
return cuda || atomic;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ensure_created(Stream s, uint64_t signal_value) {
|
void ensure_created(Stream s, uint64_t signal_value) {
|
||||||
@@ -203,10 +242,10 @@ struct EventImpl {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
|
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
|
||||||
nvtx3::mark("Using slow SharedEvent");
|
nvtx3::mark("Using slow AtomicEvent");
|
||||||
shared = std::make_unique<cu::SharedEvent>();
|
atomic = std::make_unique<cu::AtomicEvent>();
|
||||||
} else {
|
} else {
|
||||||
cuda = std::make_unique<cu::CudaEvent>();
|
cuda = std::make_unique<cu::CopyableCudaEvent>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -225,7 +264,7 @@ void Event::wait() {
|
|||||||
assert(value() == 1);
|
assert(value() == 1);
|
||||||
event->cuda->wait();
|
event->cuda->wait();
|
||||||
} else {
|
} else {
|
||||||
event->shared->wait(value());
|
event->atomic->wait(value());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -236,7 +275,7 @@ void Event::wait(Stream s) {
|
|||||||
assert(value() == 1);
|
assert(value() == 1);
|
||||||
event->cuda->wait(s);
|
event->cuda->wait(s);
|
||||||
} else {
|
} else {
|
||||||
event->shared->wait(s, value());
|
event->atomic->wait(s, value());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -247,7 +286,7 @@ void Event::signal(Stream s) {
|
|||||||
assert(value() == 1);
|
assert(value() == 1);
|
||||||
event->cuda->record(s);
|
event->cuda->record(s);
|
||||||
} else {
|
} else {
|
||||||
event->shared->signal(s, value());
|
event->atomic->signal(s, value());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -258,9 +297,9 @@ bool Event::is_signaled() const {
|
|||||||
}
|
}
|
||||||
if (event->cuda) {
|
if (event->cuda) {
|
||||||
assert(value() == 1);
|
assert(value() == 1);
|
||||||
return event->cuda->recorded() && event->cuda->completed();
|
return event->cuda->is_signaled();
|
||||||
} else {
|
} else {
|
||||||
return event->shared->is_signaled(value());
|
return event->atomic->is_signaled(value());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -3,49 +3,54 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/backend/cuda/utils.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <cuda/atomic>
|
#include <cuda/atomic>
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
class CudaEventHandle;
|
// RAII-managed move-only wrapper of cudaEvent_t.
|
||||||
|
struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
|
||||||
|
CudaEventHandle(int flags);
|
||||||
|
int flags;
|
||||||
|
};
|
||||||
|
|
||||||
// Wrapper of native cuda event. It can synchronize between GPU streams, or wait
|
// Wrapper of native cuda event. It can synchronize between GPU streams, or wait
|
||||||
// on GPU stream in CPU stream, but can not wait on CPU stream.
|
// on GPU stream in CPU stream, but can not wait on CPU stream.
|
||||||
class CudaEvent {
|
class CudaEvent {
|
||||||
public:
|
public:
|
||||||
CudaEvent();
|
explicit CudaEvent(int flags);
|
||||||
|
~CudaEvent();
|
||||||
|
|
||||||
|
CudaEvent(CudaEvent&&) = default;
|
||||||
|
CudaEvent& operator=(CudaEvent&&) = default;
|
||||||
|
|
||||||
|
CudaEvent(const CudaEvent&) = delete;
|
||||||
|
CudaEvent& operator=(const CudaEvent&) = delete;
|
||||||
|
|
||||||
void wait();
|
void wait();
|
||||||
void wait(cudaStream_t stream);
|
void wait(cudaStream_t stream);
|
||||||
void wait(Stream s);
|
|
||||||
void record(cudaStream_t stream);
|
void record(cudaStream_t stream);
|
||||||
void record(Stream s);
|
|
||||||
|
|
||||||
// Return whether the recorded kernels have completed. Note that this method
|
// Return whether the recorded kernels have completed. Note that this method
|
||||||
// returns true if record() has not been called.
|
// returns true if record() has not been called.
|
||||||
bool completed() const;
|
bool completed() const;
|
||||||
|
|
||||||
bool recorded() const {
|
|
||||||
return recorded_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool recorded_{false};
|
CudaEventHandle event_;
|
||||||
std::shared_ptr<CudaEventHandle> event_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Event that can synchronize between CPU and GPU. It is much slower than
|
// Event that can synchronize between CPU and GPU. It is much slower than
|
||||||
// CudaEvent so the latter should always be preferred when possible.
|
// CudaEvent so the latter should always be preferred when possible.
|
||||||
class SharedEvent {
|
class AtomicEvent {
|
||||||
public:
|
public:
|
||||||
using Atomic = cuda::atomic<uint64_t>;
|
using Atomic = cuda::atomic<uint64_t>;
|
||||||
|
|
||||||
SharedEvent();
|
AtomicEvent();
|
||||||
|
|
||||||
void wait(uint64_t value);
|
void wait(uint64_t value);
|
||||||
void wait(cudaStream_t stream, uint64_t value);
|
void wait(cudaStream_t stream, uint64_t value);
|
||||||
@@ -57,7 +62,11 @@ class SharedEvent {
|
|||||||
uint64_t value() const;
|
uint64_t value() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<mlx::core::allocator::Buffer> buf_;
|
Atomic* atomic() const {
|
||||||
|
return static_cast<AtomicEvent::Atomic*>(buf_->raw_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<allocator::Buffer> buf_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
|
@@ -7,7 +7,7 @@ namespace mlx::core {
|
|||||||
|
|
||||||
struct FenceImpl {
|
struct FenceImpl {
|
||||||
uint32_t count;
|
uint32_t count;
|
||||||
cu::SharedEvent event;
|
cu::AtomicEvent event;
|
||||||
};
|
};
|
||||||
|
|
||||||
Fence::Fence(Stream s) {
|
Fence::Fence(Stream s) {
|
||||||
|
@@ -7,6 +7,7 @@ namespace mlx::core::cu {
|
|||||||
|
|
||||||
Worker::Worker()
|
Worker::Worker()
|
||||||
: signal_stream_(device(mlx::core::Device::gpu)),
|
: signal_stream_(device(mlx::core::Device::gpu)),
|
||||||
|
signal_event_(cudaEventDisableTiming | cudaEventBlockingSync),
|
||||||
worker_(&Worker::thread_fn, this) {}
|
worker_(&Worker::thread_fn, this) {}
|
||||||
|
|
||||||
Worker::~Worker() {
|
Worker::~Worker() {
|
||||||
|
@@ -3,7 +3,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/backend/cuda/event.h"
|
#include "mlx/backend/cuda/event.h"
|
||||||
#include "mlx/backend/cuda/utils.h"
|
|
||||||
|
|
||||||
#include <condition_variable>
|
#include <condition_variable>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
Reference in New Issue
Block a user