mirror of https://github.com/ml-explore/mlx.git
simplify allocator and fix race with small pool
@@ -94,18 +94,21 @@ Buffer CudaAllocator::malloc(size_t size) {
   if (!buf) {
     // If we have a lot of memory pressure or are over the maximum cache size,
     // try to reclaim memory from the cache.
-    size_t mem_required = get_active_memory() + get_cache_memory() + size;
-    if (mem_required >= memory_limit_) {
-      buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
+    int64_t mem_to_free =
+        get_active_memory() + get_cache_memory() + size - memory_limit_;
+    mem_to_free = std::max(
+        static_cast<int64_t>(get_cache_memory() - max_pool_size_), mem_to_free);
+    if (mem_to_free > 0) {
+      buffer_cache_.release_cached_buffers(mem_to_free);
     }
 
-    lock.unlock();
     buf = new CudaBuffer{nullptr, size};
 
     // Try the scalar pool first
     if (size <= small_block_size) {
       buf->data = scalar_pool_.malloc();
     }
+    lock.unlock();
     if (!buf->data) {
       cudaError_t err = cudaMallocManaged(&buf->data, size);
       if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
@@ -113,7 +116,6 @@ Buffer CudaAllocator::malloc(size_t size) {
             "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
       }
     }
-
     lock.lock();
   }
   active_memory_ += size;
@@ -123,7 +125,6 @@ Buffer CudaAllocator::malloc(size_t size) {
   if (get_cache_memory() > max_pool_size_) {
     buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
   }
-
   return Buffer{buf};
 }
 
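Read on its own, the new reclaim step computes a single signed amount that covers both memory pressure and the cache cap, and only releases when that amount is positive. Below is a minimal standalone sketch of the arithmetic; the helper name bytes_to_reclaim and the numbers are made up, and only the max/positive-check logic mirrors the hunk above.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Standalone model of the reclaim decision in the new malloc() path.
// active/cached/request/memory_limit/max_pool_size are illustrative stand-ins
// for the allocator's counters, not the MLX API.
int64_t bytes_to_reclaim(
    int64_t active,
    int64_t cached,
    int64_t request,
    int64_t memory_limit,
    int64_t max_pool_size) {
  // Free enough to get back under the overall memory limit...
  int64_t mem_to_free = active + cached + request - memory_limit;
  // ...and at least enough to bring the cache back under the pool cap.
  mem_to_free = std::max(cached - max_pool_size, mem_to_free);
  // Only reclaim when the result is actually positive.
  return std::max<int64_t>(mem_to_free, 0);
}

int main() {
  // Example numbers only: 6 GiB active, 3 GiB cached, 1 GiB request,
  // 8 GiB memory limit, 2 GiB max pool size -> reclaim 2 GiB.
  int64_t gib = int64_t{1} << 30;
  long long n = bytes_to_reclaim(6 * gib, 3 * gib, 1 * gib, 8 * gib, 2 * gib);
  std::printf("reclaim %lld bytes\n", n);
}

The signed int64_t keeps the subtractions well defined when usage is below the limits, which is where the old size_t formulation needed the extra >= guard before releasing.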
@@ -135,13 +136,7 @@ void CudaAllocator::free(Buffer buffer) {
 
   std::unique_lock lock(mutex_);
   active_memory_ -= buf->size;
-  if (get_cache_memory() < max_pool_size_) {
-    buffer_cache_.recycle_to_cache(buf);
-  } else {
-    lock.unlock();
-    cuda_free(buf->data);
-    delete buf;
-  }
+  buffer_cache_.recycle_to_cache(buf);
 }
 
 size_t CudaAllocator::size(Buffer buffer) const {
@@ -152,26 +147,8 @@ size_t CudaAllocator::size(Buffer buffer) const {
   return buf->size;
 }
 
-void CudaAllocator::register_this_thread() {
-  std::lock_guard lock(worker_mutex_);
-  allowed_threads_.insert(std::this_thread::get_id());
-}
-
+// This must be called with mutex_ aquired
 void CudaAllocator::cuda_free(void* buf) {
-  // If cuda_free() is called from a unregistered thread, reschedule the call to
-  // worker.
-  {
-    std::lock_guard lock(worker_mutex_);
-    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-      if (!worker_) {
-        worker_.reset(new Worker);
-      }
-      worker_->add_task([this, buf]() { this->cuda_free(buf); });
-      worker_->end_batch();
-      worker_->commit();
-      return;
-    }
-  }
   if (scalar_pool_.in_pool(buf)) {
     scalar_pool_.free(buf);
   } else {
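The race fix from the commit title shows up here: cuda_free() no longer bounces frees to a separate worker thread and instead becomes a helper that is only called with mutex_ held, so the scalar pool is never touched from two threads at once. A rough sketch of that locking convention, with toy names standing in for the MLX types:

#include <cstdio>
#include <mutex>
#include <vector>

// Toy small-block pool following the convention the new code adopts:
// every pool operation happens while the owning allocator's mutex is held,
// so pool malloc/free can no longer race with each other.
class SmallPool {
 public:
  void* malloc() {  // caller must hold the allocator mutex
    if (free_list_.empty()) {
      return nullptr;  // the real allocator falls back to cudaMallocManaged
    }
    void* p = free_list_.back();
    free_list_.pop_back();
    return p;
  }
  void free(void* p) {  // caller must hold the allocator mutex
    free_list_.push_back(p);
  }

 private:
  std::vector<void*> free_list_;
};

std::mutex allocator_mutex;
SmallPool pool;

int main() {
  int block;  // stand-in for a pooled allocation
  {
    std::lock_guard<std::mutex> lock(allocator_mutex);
    pool.free(&block);
  }
  std::lock_guard<std::mutex> lock(allocator_mutex);
  std::printf("reused: %d\n", pool.malloc() == &block);
}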
@@ -7,13 +7,10 @@
 
 #include <mutex>
 #include <set>
-#include <thread>
 #include <utility>
 
 namespace mlx::core::cu {
 
-class Worker;
-
 using allocator::Buffer;
 
 // Stores cuda-managed unified memory.
@@ -50,15 +47,6 @@ class CudaAllocator : public allocator::Allocator {
   void free(Buffer buffer) override;
   size_t size(Buffer buffer) const override;
 
-  // Register current thread as safe to free buffers.
-  // In cuda freeing a buffer implicitly synchronizes stream, and for threads
-  // that may be waited by gpu stream (for example cpu stream threads), freeing
-  // buffers there would result in dead lock.
-  void register_this_thread();
-
-  // Call cudaFree in the safe thread.
-  void cuda_free(void* buf);
-
   size_t get_active_memory() const;
   size_t get_peak_memory() const;
   void reset_peak_memory();
@@ -69,13 +57,11 @@ class CudaAllocator : public allocator::Allocator {
   void clear_cache();
 
  private:
+  void cuda_free(void* buf);
+
   CudaAllocator();
   friend CudaAllocator& allocator();
 
-  std::mutex worker_mutex_;
-  std::unique_ptr<Worker> worker_;
-  std::set<std::thread::id> allowed_threads_;
-
   std::mutex mutex_;
   size_t memory_limit_;
   size_t max_pool_size_;
@@ -315,7 +315,6 @@ void CommandEncoder::synchronize() {
   auto p = std::make_shared<std::promise<void>>();
   std::future<void> f = p->get_future();
   add_completed_handler([p = std::move(p)]() { p->set_value(); });
-  worker_.end_batch();
   commit();
   f.wait();
 }
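For context, synchronize() blocks by queuing a handler that fulfills a promise and then waiting on the matching future. A self-contained sketch of that pattern; a plain std::thread stands in for the stream's completion callback, and everything except the promise/future/handler wiring is illustrative.

#include <functional>
#include <future>
#include <memory>
#include <thread>
#include <vector>

int main() {
  std::vector<std::function<void()>> completed_handlers;

  // Queue a handler that fulfills the promise, as synchronize() does.
  auto p = std::make_shared<std::promise<void>>();
  std::future<void> f = p->get_future();
  completed_handlers.push_back([p = std::move(p)]() { p->set_value(); });

  // Stand-in for the stream: run the completed handlers on another thread.
  std::thread worker([handlers = std::move(completed_handlers)]() mutable {
    for (auto& h : handlers) {
      h();  // runs once the queued work has "completed"
    }
  });

  f.wait();  // blocks until the handler fires
  worker.join();
}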
@@ -19,8 +19,6 @@ void new_stream(Stream s) {
   cudaFree(nullptr);
   // Ensure the static stream objects get created.
   cu::get_command_encoder(s);
-  // The main thread is safe to free buffers.
-  cu::allocator().register_this_thread();
 }
 
 void eval(array& arr) {
@@ -110,24 +110,26 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
   event_signal(ac, value);
 }
 
+SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
+  return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
+}
+
 SharedEvent::SharedEvent() {
-  // Allocate cuda::atomic on managed memory.
-  Atomic* ac;
-  CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
-  new (ac) Atomic(0);
-  ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
-    ptr->~Atomic();
-    allocator().cuda_free(ptr);
-  });
+  buf_ = std::shared_ptr<Buffer>(
+      new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
+        allocator().free(*ptr);
+        delete ptr;
+      });
+  *static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
 }
 
 void SharedEvent::wait(uint64_t value) {
   nvtx3::scoped_range r("cu::SharedEvent::wait");
-  event_wait(ac_.get(), value);
+  event_wait(to_atomic(buf_), value);
 }
 
 void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
-  event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
+  event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
 }
 
 void SharedEvent::wait(Stream s, uint64_t value) {
@@ -138,17 +140,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
     auto& encoder = get_command_encoder(s);
     encoder.commit();
     wait(encoder.stream(), value);
-    encoder.add_completed_handler([ac = ac_]() {});
+    encoder.add_completed_handler([buf = buf_]() {});
   }
 }
 
 void SharedEvent::signal(uint64_t value) {
   nvtx3::scoped_range r("cu::SharedEvent::signal");
-  event_signal(ac_.get(), value);
+  event_signal(to_atomic(buf_), value);
 }
 
 void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
-  event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
+  event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
 }
 
 void SharedEvent::signal(Stream s, uint64_t value) {
@@ -162,18 +164,18 @@ void SharedEvent::signal(Stream s, uint64_t value) {
     auto& encoder = get_command_encoder(s);
     encoder.commit();
     signal(encoder.stream(), value);
-    encoder.add_completed_handler([ac = ac_]() {});
+    encoder.add_completed_handler([buf = buf_]() {});
   }
 }
 
 bool SharedEvent::is_signaled(uint64_t value) const {
   nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
-  return ac_->load() >= value;
+  return to_atomic(buf_)->load() >= value;
 }
 
 uint64_t SharedEvent::value() const {
   nvtx3::scoped_range r("cu::SharedEvent::value");
-  return ac_->load();
+  return to_atomic(buf_)->load();
 }
 
 } // namespace cu
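The no-op [buf = buf_]() {} handlers above exist purely to keep the event's buffer alive until the stream has drained its completed handlers. A small plain-C++ illustration of that shared_ptr-capture trick; the types and handler list are hypothetical, only the capture pattern mirrors the diff.

#include <cstdio>
#include <functional>
#include <memory>
#include <vector>

int main() {
  std::vector<std::function<void()>> completed_handlers;
  {
    // Stand-in for the event's Buffer: a shared_ptr with a custom deleter.
    auto buf = std::shared_ptr<int>(new int(42), [](int* p) {
      std::printf("buffer released\n");
      delete p;
    });
    // A no-op handler that captures the shared_ptr extends its lifetime.
    completed_handlers.push_back([buf]() {});
  }  // buf goes out of scope here, yet the allocation survives

  std::printf("before handlers run\n");
  completed_handlers.clear();  // "completion": last reference dropped here
}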
@@ -2,6 +2,7 @@
 
 #pragma once
 
+#include "mlx/allocator.h"
 #include "mlx/stream.h"
 
 #include <cuda_runtime.h>
@@ -55,12 +56,8 @@ class SharedEvent {
   bool is_signaled(uint64_t value) const;
   uint64_t value() const;
 
-  const std::shared_ptr<Atomic>& atomic() const {
-    return ac_;
-  }
-
  private:
-  std::shared_ptr<Atomic> ac_;
+  std::shared_ptr<mlx::core::allocator::Buffer> buf_;
 };
 
 } // namespace mlx::core::cu
@@ -1,7 +1,6 @@
 // Copyright © 2025 Apple Inc.
 
 #include "mlx/backend/cuda/worker.h"
-#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device.h"
 
 namespace mlx::core::cu {
@@ -23,13 +22,6 @@ void Worker::add_task(std::function<void()> task) {
   pending_tasks_.push_back(std::move(task));
 }
 
-void Worker::consume_in_this_thread() {
-  for (auto& task : pending_tasks_) {
-    task();
-  }
-  pending_tasks_.clear();
-}
-
 void Worker::end_batch() {
   batch_++;
   {
@@ -39,14 +31,6 @@ void Worker::end_batch() {
   uncommited_batches_++;
 }
 
-void Worker::commit() {
-  if (uncommited_batches_ == 0) {
-    return;
-  }
-  uncommited_batches_ = 0;
-  worker_event_.signal(batch_);
-}
-
 void Worker::commit(cudaStream_t stream) {
   if (uncommited_batches_ == 0) {
     return;
@@ -60,9 +44,6 @@ void Worker::commit(cudaStream_t stream) {
 }
 
 void Worker::thread_fn() {
-  // The worker thread is safe to free buffers.
-  allocator().register_this_thread();
-
   while (!stop_) {
     uint64_t batch = worker_event_.value();
     Tasks tasks;
@@ -24,15 +24,9 @@ class Worker {
   // Add a pending |task| that will run when consumed or commited.
   void add_task(std::function<void()> task);
 
-  // Run pending tasks immediately in current thread.
-  void consume_in_this_thread();
-
   // Put pending tasks in a batch.
   void end_batch();
 
-  // Inform worker thread to run current batches now.
-  void commit();
-
   // Inform worker thread to run current batches after kernels in |stream|
   // finish running.
   void commit(cudaStream_t stream);
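With consume_in_this_thread() and the bare commit() gone, the surviving surface is add_task / end_batch / commit(stream). Below is a rough standalone model of that batch-then-signal flow; it is not the MLX Worker: a std::condition_variable stands in for the CUDA event, and the signal fires immediately rather than after a stream's kernels finish.

#include <condition_variable>
#include <cstdio>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

// add_task() queues work, end_batch() seals the pending tasks into the
// committed queue's tail, and commit() signals the worker to run everything
// sealed so far.
class TinyWorker {
 public:
  TinyWorker() : thread_([this] { loop(); }) {}
  ~TinyWorker() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      stop_ = true;
    }
    cv_.notify_one();
    thread_.join();
  }

  void add_task(std::function<void()> task) {
    pending_.push_back(std::move(task));  // called from one producer thread
  }

  void end_batch() {
    std::lock_guard<std::mutex> lock(mutex_);
    for (auto& t : pending_) {
      sealed_.push_back(std::move(t));
    }
    pending_.clear();
  }

  // In MLX the signal is deferred until the stream finishes; here it is not.
  void commit() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      committed_ = sealed_.size();
    }
    cv_.notify_one();
  }

 private:
  void loop() {
    size_t done = 0;
    std::unique_lock<std::mutex> lock(mutex_);
    while (true) {
      cv_.wait(lock, [&] { return stop_ || committed_ > done; });
      while (done < committed_) {
        auto task = std::move(sealed_[done++]);
        lock.unlock();
        task();  // run committed work outside the lock
        lock.lock();
      }
      if (stop_) {
        break;
      }
    }
  }

  std::vector<std::function<void()>> pending_;
  std::vector<std::function<void()>> sealed_;
  size_t committed_ = 0;
  bool stop_ = false;
  std::mutex mutex_;
  std::condition_variable cv_;
  std::thread thread_;  // started last, after the other members
};

int main() {
  TinyWorker w;
  w.add_task([] { std::printf("task 1\n"); });
  w.add_task([] { std::printf("task 2\n"); });
  w.end_batch();
  w.commit();
}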