mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Avoid invoking allocator::malloc when creating CUDA event (#2232)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
@@ -111,12 +112,12 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
|
||||
SharedEvent::SharedEvent() {
|
||||
// Allocate cuda::atomic on managed memory.
|
||||
allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
|
||||
Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
|
||||
Atomic* ac;
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
|
||||
new (ac) Atomic(0);
|
||||
ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
|
||||
ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
|
||||
ptr->~Atomic();
|
||||
allocator::free(buffer);
|
||||
allocator().cuda_free(ptr);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user