mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-29 05:31:18 +08:00
[CUDA] Simplify allocator (#2392)
* simplify allocator and fix race with small pool * Don't use shared event in worker * use cuda buffer in small pool * comment * comment
This commit is contained in:
parent
74eccbf3fa
commit
1e496ddb82
@ -2,7 +2,6 @@
|
||||
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
@ -25,52 +24,58 @@ constexpr int small_block_size = 8;
|
||||
constexpr int small_pool_size = 4 * page_size;
|
||||
|
||||
SmallSizePool::SmallSizePool() {
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
|
||||
end_ = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(buffer_) + small_pool_size);
|
||||
next_free_ = reinterpret_cast<Block*>(buffer_);
|
||||
|
||||
auto num_blocks = small_pool_size / small_block_size;
|
||||
buffer_ = new Block[num_blocks];
|
||||
|
||||
next_free_ = buffer_;
|
||||
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
|
||||
CHECK_CUDA_ERROR(
|
||||
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
|
||||
|
||||
auto curr = next_free_;
|
||||
for (size_t i = 0; i < num_blocks - 1; ++i) {
|
||||
curr->next = reinterpret_cast<Block*>(
|
||||
reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
|
||||
for (size_t i = 1; i < num_blocks; ++i) {
|
||||
curr->next = buffer_ + i;
|
||||
curr = curr->next;
|
||||
}
|
||||
curr->next = nullptr;
|
||||
}
|
||||
|
||||
SmallSizePool::~SmallSizePool() {
|
||||
CHECK_CUDA_ERROR(cudaFree(buffer_));
|
||||
CHECK_CUDA_ERROR(cudaFree(data_));
|
||||
delete[] buffer_;
|
||||
}
|
||||
|
||||
void* SmallSizePool::malloc() {
|
||||
CudaBuffer* SmallSizePool::malloc() {
|
||||
if (next_free_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
Block* b = next_free_;
|
||||
uint64_t i = next_free_ - buffer_;
|
||||
next_free_ = next_free_->next;
|
||||
return static_cast<void*>(b);
|
||||
b->buf.data = static_cast<char*>(data_) + i * small_block_size;
|
||||
b->buf.size = small_block_size;
|
||||
return &b->buf;
|
||||
}
|
||||
|
||||
void SmallSizePool::free(void* p) {
|
||||
auto b = static_cast<Block*>(p);
|
||||
void SmallSizePool::free(CudaBuffer* buf) {
|
||||
auto b = reinterpret_cast<Block*>(buf);
|
||||
b->next = next_free_;
|
||||
next_free_ = b;
|
||||
}
|
||||
|
||||
bool SmallSizePool::in_pool(void* p) {
|
||||
return (p >= buffer_) && (p < end_);
|
||||
bool SmallSizePool::in_pool(CudaBuffer* buf) {
|
||||
constexpr int num_blocks = (small_pool_size / small_block_size);
|
||||
auto b = reinterpret_cast<Block*>(buf);
|
||||
int64_t block_num = b - buffer_;
|
||||
return block_num >= 0 && block_num < num_blocks;
|
||||
}
|
||||
|
||||
CudaAllocator::CudaAllocator()
|
||||
: buffer_cache_(
|
||||
page_size,
|
||||
[](CudaBuffer* buf) { return buf->size; },
|
||||
[this](CudaBuffer* buf) {
|
||||
cuda_free(buf->data);
|
||||
delete buf;
|
||||
}) {
|
||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||
// TODO: Set memory limit for multi-device.
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
@ -92,28 +97,26 @@ Buffer CudaAllocator::malloc(size_t size) {
|
||||
|
||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||
if (!buf) {
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
// try to reclaim memory from the cache.
|
||||
size_t mem_required = get_active_memory() + get_cache_memory() + size;
|
||||
if (mem_required >= memory_limit_) {
|
||||
buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
|
||||
// If we have a lot of memory pressure try to reclaim memory from the cache.
|
||||
int64_t mem_to_free =
|
||||
get_active_memory() + get_cache_memory() + size - memory_limit_;
|
||||
if (mem_to_free > 0) {
|
||||
buffer_cache_.release_cached_buffers(mem_to_free);
|
||||
}
|
||||
|
||||
lock.unlock();
|
||||
buf = new CudaBuffer{nullptr, size};
|
||||
|
||||
// Try the scalar pool first
|
||||
if (size <= small_block_size) {
|
||||
buf->data = scalar_pool_.malloc();
|
||||
buf = scalar_pool_.malloc();
|
||||
}
|
||||
if (!buf->data) {
|
||||
lock.unlock();
|
||||
if (!buf) {
|
||||
buf = new CudaBuffer{nullptr, size};
|
||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||
}
|
||||
}
|
||||
|
||||
lock.lock();
|
||||
}
|
||||
active_memory_ += size;
|
||||
@ -123,7 +126,6 @@ Buffer CudaAllocator::malloc(size_t size) {
|
||||
if (get_cache_memory() > max_pool_size_) {
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
|
||||
return Buffer{buf};
|
||||
}
|
||||
|
||||
@ -138,9 +140,7 @@ void CudaAllocator::free(Buffer buffer) {
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
lock.unlock();
|
||||
cuda_free(buf->data);
|
||||
delete buf;
|
||||
cuda_free(buf);
|
||||
}
|
||||
}
|
||||
|
||||
@ -152,30 +152,13 @@ size_t CudaAllocator::size(Buffer buffer) const {
|
||||
return buf->size;
|
||||
}
|
||||
|
||||
void CudaAllocator::register_this_thread() {
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
allowed_threads_.insert(std::this_thread::get_id());
|
||||
}
|
||||
|
||||
void CudaAllocator::cuda_free(void* buf) {
|
||||
// If cuda_free() is called from an unregistered thread, reschedule the call to
|
||||
// worker.
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
||||
if (!worker_) {
|
||||
worker_.reset(new Worker);
|
||||
}
|
||||
worker_->add_task([this, buf]() { this->cuda_free(buf); });
|
||||
worker_->end_batch();
|
||||
worker_->commit();
|
||||
return;
|
||||
}
|
||||
}
|
||||
// This must be called with mutex_ acquired
|
||||
void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
||||
if (scalar_pool_.in_pool(buf)) {
|
||||
scalar_pool_.free(buf);
|
||||
} else {
|
||||
cudaFree(buf);
|
||||
cudaFree(buf->data);
|
||||
delete buf;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,13 +7,10 @@
|
||||
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
class Worker;
|
||||
|
||||
using allocator::Buffer;
|
||||
|
||||
// Stores cuda-managed unified memory.
|
||||
@ -24,13 +21,14 @@ struct CudaBuffer {
|
||||
|
||||
class SmallSizePool {
|
||||
private:
|
||||
struct Block {
|
||||
union Block {
|
||||
Block* next;
|
||||
CudaBuffer buf;
|
||||
};
|
||||
|
||||
void* buffer_{nullptr};
|
||||
Block* buffer_{nullptr};
|
||||
void* data_{nullptr};
|
||||
Block* next_free_{nullptr};
|
||||
void* end_{nullptr};
|
||||
|
||||
public:
|
||||
SmallSizePool();
|
||||
@ -39,9 +37,9 @@ class SmallSizePool {
|
||||
SmallSizePool(const SmallSizePool&) = delete;
|
||||
SmallSizePool& operator=(const SmallSizePool&) = delete;
|
||||
|
||||
void* malloc();
|
||||
void free(void* p);
|
||||
bool in_pool(void* p);
|
||||
CudaBuffer* malloc();
|
||||
void free(CudaBuffer* buf);
|
||||
bool in_pool(CudaBuffer* buf);
|
||||
};
|
||||
|
||||
class CudaAllocator : public allocator::Allocator {
|
||||
@ -50,15 +48,6 @@ class CudaAllocator : public allocator::Allocator {
|
||||
void free(Buffer buffer) override;
|
||||
size_t size(Buffer buffer) const override;
|
||||
|
||||
// Register current thread as safe to free buffers.
|
||||
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
|
||||
// that may be waited by gpu stream (for example cpu stream threads), freeing
|
||||
// buffers there would result in dead lock.
|
||||
void register_this_thread();
|
||||
|
||||
// Call cudaFree in the safe thread.
|
||||
void cuda_free(void* buf);
|
||||
|
||||
size_t get_active_memory() const;
|
||||
size_t get_peak_memory() const;
|
||||
void reset_peak_memory();
|
||||
@ -69,13 +58,11 @@ class CudaAllocator : public allocator::Allocator {
|
||||
void clear_cache();
|
||||
|
||||
private:
|
||||
void cuda_free(CudaBuffer* buf);
|
||||
|
||||
CudaAllocator();
|
||||
friend CudaAllocator& allocator();
|
||||
|
||||
std::mutex worker_mutex_;
|
||||
std::unique_ptr<Worker> worker_;
|
||||
std::set<std::thread::id> allowed_threads_;
|
||||
|
||||
std::mutex mutex_;
|
||||
size_t memory_limit_;
|
||||
size_t max_pool_size_;
|
||||
|
@ -306,7 +306,6 @@ void CommandEncoder::commit() {
|
||||
}
|
||||
|
||||
// Put completion handlers in a batch.
|
||||
worker_.end_batch();
|
||||
worker_.commit(stream_);
|
||||
}
|
||||
|
||||
@ -315,7 +314,6 @@ void CommandEncoder::synchronize() {
|
||||
auto p = std::make_shared<std::promise<void>>();
|
||||
std::future<void> f = p->get_future();
|
||||
add_completed_handler([p = std::move(p)]() { p->set_value(); });
|
||||
worker_.end_batch();
|
||||
commit();
|
||||
f.wait();
|
||||
}
|
||||
|
@ -19,8 +19,6 @@ void new_stream(Stream s) {
|
||||
cudaFree(nullptr);
|
||||
// Ensure the static stream objects get created.
|
||||
cu::get_command_encoder(s);
|
||||
// The main thread is safe to free buffers.
|
||||
cu::allocator().register_this_thread();
|
||||
}
|
||||
|
||||
void eval(array& arr) {
|
||||
|
@ -110,24 +110,26 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
||||
event_signal(ac, value);
|
||||
}
|
||||
|
||||
SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
|
||||
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
|
||||
}
|
||||
|
||||
SharedEvent::SharedEvent() {
|
||||
// Allocate cuda::atomic on managed memory.
|
||||
Atomic* ac;
|
||||
CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
|
||||
new (ac) Atomic(0);
|
||||
ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
|
||||
ptr->~Atomic();
|
||||
allocator().cuda_free(ptr);
|
||||
});
|
||||
buf_ = std::shared_ptr<Buffer>(
|
||||
new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
|
||||
allocator().free(*ptr);
|
||||
delete ptr;
|
||||
});
|
||||
*static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
|
||||
}
|
||||
|
||||
void SharedEvent::wait(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::wait");
|
||||
event_wait(ac_.get(), value);
|
||||
event_wait(to_atomic(buf_), value);
|
||||
}
|
||||
|
||||
void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
|
||||
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
|
||||
event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
|
||||
}
|
||||
|
||||
void SharedEvent::wait(Stream s, uint64_t value) {
|
||||
@ -138,17 +140,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
|
||||
auto& encoder = get_command_encoder(s);
|
||||
encoder.commit();
|
||||
wait(encoder.stream(), value);
|
||||
encoder.add_completed_handler([ac = ac_]() {});
|
||||
encoder.add_completed_handler([buf = buf_]() {});
|
||||
}
|
||||
}
|
||||
|
||||
void SharedEvent::signal(uint64_t value) {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::signal");
|
||||
event_signal(ac_.get(), value);
|
||||
event_signal(to_atomic(buf_), value);
|
||||
}
|
||||
|
||||
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
||||
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
|
||||
event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
|
||||
}
|
||||
|
||||
void SharedEvent::signal(Stream s, uint64_t value) {
|
||||
@ -162,18 +164,18 @@ void SharedEvent::signal(Stream s, uint64_t value) {
|
||||
auto& encoder = get_command_encoder(s);
|
||||
encoder.commit();
|
||||
signal(encoder.stream(), value);
|
||||
encoder.add_completed_handler([ac = ac_]() {});
|
||||
encoder.add_completed_handler([buf = buf_]() {});
|
||||
}
|
||||
}
|
||||
|
||||
bool SharedEvent::is_signaled(uint64_t value) const {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
|
||||
return ac_->load() >= value;
|
||||
return to_atomic(buf_)->load() >= value;
|
||||
}
|
||||
|
||||
uint64_t SharedEvent::value() const {
|
||||
nvtx3::scoped_range r("cu::SharedEvent::value");
|
||||
return ac_->load();
|
||||
return to_atomic(buf_)->load();
|
||||
}
|
||||
|
||||
} // namespace cu
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/stream.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
@ -55,12 +56,8 @@ class SharedEvent {
|
||||
bool is_signaled(uint64_t value) const;
|
||||
uint64_t value() const;
|
||||
|
||||
const std::shared_ptr<Atomic>& atomic() const {
|
||||
return ac_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Atomic> ac_;
|
||||
std::shared_ptr<mlx::core::allocator::Buffer> buf_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
@ -1,7 +1,6 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cuda/worker.h"
|
||||
#include "mlx/backend/cuda/allocator.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
|
||||
namespace mlx::core::cu {
|
||||
@ -12,10 +11,10 @@ Worker::Worker()
|
||||
|
||||
Worker::~Worker() {
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
std::lock_guard lock(mtx_);
|
||||
stop_ = true;
|
||||
}
|
||||
worker_event_.signal(batch_ + 1);
|
||||
cond_.notify_one();
|
||||
worker_.join();
|
||||
}
|
||||
|
||||
@ -23,53 +22,41 @@ void Worker::add_task(std::function<void()> task) {
|
||||
pending_tasks_.push_back(std::move(task));
|
||||
}
|
||||
|
||||
void Worker::consume_in_this_thread() {
|
||||
for (auto& task : pending_tasks_) {
|
||||
task();
|
||||
}
|
||||
pending_tasks_.clear();
|
||||
}
|
||||
|
||||
void Worker::end_batch() {
|
||||
batch_++;
|
||||
void Worker::signal(void* data) {
|
||||
auto w = static_cast<Worker*>(data);
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
worker_tasks_[batch_] = std::move(pending_tasks_);
|
||||
std::lock_guard lock(w->mtx_);
|
||||
w->signaled_batch_++;
|
||||
}
|
||||
uncommited_batches_++;
|
||||
}
|
||||
|
||||
void Worker::commit() {
|
||||
if (uncommited_batches_ == 0) {
|
||||
return;
|
||||
}
|
||||
uncommited_batches_ = 0;
|
||||
worker_event_.signal(batch_);
|
||||
w->cond_.notify_one();
|
||||
}
|
||||
|
||||
void Worker::commit(cudaStream_t stream) {
|
||||
if (uncommited_batches_ == 0) {
|
||||
// Move pending tasks into tasks
|
||||
if (pending_tasks_.empty()) {
|
||||
return;
|
||||
}
|
||||
uncommited_batches_ = 0;
|
||||
// Signal the |worker_event_| in |signal_stream_| after the kernels in
|
||||
// |stream_| finish running.
|
||||
{
|
||||
std::lock_guard lock(mtx_);
|
||||
// Move pending tasks into ready tasks
|
||||
worker_tasks_[++committed_batch_] = std::move(pending_tasks_);
|
||||
}
|
||||
signal_event_.record(stream);
|
||||
signal_event_.wait(signal_stream_);
|
||||
worker_event_.signal(signal_stream_, batch_);
|
||||
cudaLaunchHostFunc(signal_stream_, signal, this);
|
||||
}
|
||||
|
||||
void Worker::thread_fn() {
|
||||
// The worker thread is safe to free buffers.
|
||||
allocator().register_this_thread();
|
||||
|
||||
while (!stop_) {
|
||||
uint64_t batch = worker_event_.value();
|
||||
uint64_t current_batch = 0;
|
||||
Tasks tasks;
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
// Move tasks in signaled batches.
|
||||
auto end = worker_tasks_.upper_bound(batch);
|
||||
std::unique_lock<std::mutex> lk(mtx_);
|
||||
cond_.wait(lk, [this, &current_batch] {
|
||||
return this->signaled_batch_ > current_batch || this->stop_;
|
||||
});
|
||||
current_batch = signaled_batch_;
|
||||
auto end = worker_tasks_.upper_bound(current_batch);
|
||||
for (auto it = worker_tasks_.begin(); it != end; ++it) {
|
||||
if (tasks.empty()) {
|
||||
tasks = std::move(it->second);
|
||||
@ -85,7 +72,6 @@ void Worker::thread_fn() {
|
||||
auto task = std::move(tasks[i]);
|
||||
task();
|
||||
}
|
||||
worker_event_.wait(batch + 1);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include "mlx/backend/cuda/event.h"
|
||||
#include "mlx/backend/cuda/utils.h"
|
||||
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
@ -24,38 +25,24 @@ class Worker {
|
||||
// Add a pending |task| that will run when consumed or committed.
|
||||
void add_task(std::function<void()> task);
|
||||
|
||||
// Run pending tasks immediately in current thread.
|
||||
void consume_in_this_thread();
|
||||
|
||||
// Put pending tasks in a batch.
|
||||
void end_batch();
|
||||
|
||||
// Inform worker thread to run current batches now.
|
||||
void commit();
|
||||
|
||||
// Inform worker thread to run current batches after kernels in |stream|
|
||||
// finish running.
|
||||
void commit(cudaStream_t stream);
|
||||
|
||||
// Return how many batches have been added but not committed yet.
|
||||
size_t uncommited_batches() const {
|
||||
return uncommited_batches_;
|
||||
}
|
||||
|
||||
private:
|
||||
void thread_fn();
|
||||
static void signal(void*);
|
||||
|
||||
uint64_t batch_{0};
|
||||
size_t uncommited_batches_{0};
|
||||
void thread_fn();
|
||||
std::mutex mtx_;
|
||||
std::condition_variable cond_;
|
||||
|
||||
uint64_t committed_batch_{0};
|
||||
uint64_t signaled_batch_{0};
|
||||
|
||||
// Cuda stream and event for signaling kernel completion.
|
||||
CudaStream signal_stream_;
|
||||
CudaEvent signal_event_;
|
||||
|
||||
// Worker thread.
|
||||
SharedEvent worker_event_;
|
||||
std::thread worker_;
|
||||
std::mutex worker_mutex_;
|
||||
bool stop_{false};
|
||||
|
||||
// Tasks are put in |pending_tasks_| first, and then moved to
|
||||
@ -63,6 +50,7 @@ class Worker {
|
||||
using Tasks = std::vector<std::function<void()>>;
|
||||
Tasks pending_tasks_;
|
||||
std::map<uint64_t, Tasks> worker_tasks_;
|
||||
std::thread worker_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core::cu
|
||||
|
@ -128,8 +128,7 @@ Buffer MetalAllocator::malloc(size_t size) {
|
||||
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
// try to reclaim memory from the cache
|
||||
// If we have a lot of memory pressure try to reclaim memory from the cache
|
||||
if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
|
||||
num_resources_ -=
|
||||
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
|
||||
|
Loading…
Reference in New Issue
Block a user