Add memory cache to cuda backend allocator

Cheng
2025-05-24 02:34:39 +00:00
parent 82e4fbc1fd
commit fe6aba20f7
2 changed files with 90 additions and 33 deletions

mlx/backend/cuda/allocator.cpp

@@ -6,6 +6,7 @@
 #include <cuda_runtime.h>
 #include <fmt/format.h>
+#include <unistd.h>
 #include <cassert>
@@ -13,24 +14,47 @@ namespace mlx::core {

 namespace cu {

-CudaAllocator::CudaAllocator() {
+CudaAllocator::CudaAllocator()
+    : buffer_cache_(
+          getpagesize(),
+          [](CudaBuffer* buf) { return buf->size; },
+          [this](CudaBuffer* buf) { cuda_free(buf); }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
   memory_limit_ = total * 0.8;
+  max_pool_size_ = memory_limit_;
 }

 Buffer CudaAllocator::malloc(size_t size) {
-  // TODO: Check memory limit.
-  auto* buf = new CudaBuffer{nullptr, size};
-  cudaError_t err = cudaMallocManaged(&buf->data, size);
-  if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
-    throw std::runtime_error(
-        fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+  // Find available buffer from cache.
+  std::unique_lock lock(mutex_);
+  CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
+  if (!buf) {
+    // If we have a lot of memory pressure or are over the maximum cache size,
+    // try to reclaim memory from the cache.
+    size_t mem_required = get_active_memory() + get_cache_memory() + size;
+    if (mem_required >= memory_limit_) {
+      buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
+    }
+
+    lock.unlock();
+    buf = new CudaBuffer{nullptr, size};
+    cudaError_t err = cudaMallocManaged(&buf->data, size);
+    if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
+      throw std::runtime_error(fmt::format(
+          "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+    }
+    lock.lock();
   }
-  std::lock_guard lock(mutex_);
   active_memory_ += size;
   peak_memory_ = std::max(active_memory_, peak_memory_);
+
+  // Maintain the cache below the requested limit.
+  if (get_cache_memory() > max_pool_size_) {
+    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
+  }
+
   return Buffer{buf};
 }
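To make the new reclaim logic in malloc() concrete, here is a worked example with made-up numbers; only the function names and the 0.8 factor (set in the constructor) come from the diff:

    #include <cstddef>

    // Worked example of malloc()'s reclaim threshold (illustrative numbers).
    void reclaim_example() {
      size_t memory_limit = 8UL << 30;  // memory_limit_: 80% of a 10 GiB device
      size_t active = 6UL << 30;        // get_active_memory()
      size_t cached = 3UL << 30;        // get_cache_memory()
      size_t request = 1UL << 30;       // incoming malloc(size)

      size_t mem_required = active + cached + request;  // 10 GiB, over the limit
      if (mem_required >= memory_limit) {
        // malloc() would call release_cached_buffers(2 GiB) here, evicting
        // cached buffers until the projected footprint fits under the limit,
        // before falling through to cudaMallocManaged.
      }
    }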
@@ -40,26 +64,14 @@ void CudaAllocator::free(Buffer buffer) {
     return;
   }

-  // If free() is called from a unregistered thread, reschedule the call to
-  // worker.
-  {
-    std::lock_guard lock(worker_mutex_);
-    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-      if (!worker_) {
-        worker_.reset(new Worker);
-      }
-      worker_->add_task([buffer]() { allocator().free(buffer); });
-      worker_->end_batch();
-      worker_->commit();
-      return;
-    }
+  std::unique_lock lock(mutex_);
+  active_memory_ -= buf->size;
+  if (get_cache_memory() < max_pool_size_) {
+    buffer_cache_.recycle_to_cache(buf);
+  } else {
+    lock.unlock();
+    cuda_free(buf);
   }
-
-  size_t size = buf->size;
-  cudaFree(buf->data);
-  delete buf;
-  std::lock_guard lock(mutex_);
-  active_memory_ -= size;
 }

 size_t CudaAllocator::size(Buffer buffer) const {
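The BufferCache<CudaBuffer> template this code relies on is not part of the diff; it comes from mlx/backend/common/buffer_cache.h, added to the header below. The call sites (the three constructor arguments plus reuse_from_cache, recycle_to_cache, release_cached_buffers, cache_size, and clear) imply roughly the following shape. This is a minimal single-file sketch for orientation, assuming a size-bucketed multimap; the real template is presumably more elaborate (page-size rounding of bucket keys, LRU eviction order):

    #include <cstddef>
    #include <functional>
    #include <map>
    #include <utility>

    // Sketch only: a size-bucketed cache matching the call sites in this diff.
    template <typename T>
    class BufferCacheSketch {
     public:
      BufferCacheSketch(
          size_t page_size,
          std::function<size_t(T*)> get_size,
          std::function<void(T*)> free_fn)
          : page_size_(page_size),
            get_size_(std::move(get_size)),
            free_(std::move(free_fn)) {}

      // Pop a cached buffer large enough for |size|, or return nullptr.
      T* reuse_from_cache(size_t size) {
        auto it = cache_.lower_bound(size);
        if (it == cache_.end()) {
          return nullptr;
        }
        T* buf = it->second;
        cache_size_ -= it->first;
        cache_.erase(it);
        return buf;
      }

      // Park a freed buffer in the cache instead of releasing it.
      void recycle_to_cache(T* buf) {
        size_t size = get_size_(buf);
        cache_.emplace(size, buf);
        cache_size_ += size;
      }

      // Actually free cached buffers until at least |min_bytes| are
      // released (or the cache is empty).
      void release_cached_buffers(size_t min_bytes) {
        size_t released = 0;
        while (released < min_bytes && !cache_.empty()) {
          auto it = cache_.begin();
          released += it->first;
          cache_size_ -= it->first;
          free_(it->second);
          cache_.erase(it);
        }
      }

      size_t cache_size() const {
        return cache_size_;
      }

      void clear() {
        release_cached_buffers(cache_size_);
      }

     private:
      size_t page_size_;  // granularity hint; unused in this toy version
      std::function<size_t(T*)> get_size_;
      std::function<void(T*)> free_;
      std::multimap<size_t, T*> cache_;  // size bucket -> buffer
      size_t cache_size_{0};
    };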
@@ -98,6 +110,41 @@ size_t CudaAllocator::set_memory_limit(size_t limit) {
   return limit;
 }

+size_t CudaAllocator::get_cache_memory() const {
+  return buffer_cache_.cache_size();
+}
+
+size_t CudaAllocator::set_cache_limit(size_t limit) {
+  std::lock_guard lk(mutex_);
+  std::swap(limit, max_pool_size_);
+  return limit;
+}
+
+void CudaAllocator::clear_cache() {
+  std::lock_guard lk(mutex_);
+  buffer_cache_.clear();
+}
+
+void CudaAllocator::cuda_free(CudaBuffer* buf) {
+  // If cuda_free() is called from an unregistered thread, reschedule the
+  // call to the worker.
+  {
+    std::lock_guard lock(worker_mutex_);
+    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
+      if (!worker_) {
+        worker_.reset(new Worker);
+      }
+      worker_->add_task([this, buf]() { this->cuda_free(buf); });
+      worker_->end_batch();
+      worker_->commit();
+      return;
+    }
+  }
+  cudaFree(buf->data);
+  delete buf;
+}
+
 CudaAllocator& allocator() {
   // By creating the |allocator_| on heap, the destructor of CudaAllocator
   // will not be called on exit and buffers in the cache will be leaked. This
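The context lines above cut off mid-comment; the pattern they describe is the usual intentionally-leaked heap singleton, presumably along these lines (a sketch of the pattern, not the verbatim function body):

    CudaAllocator& allocator() {
      // Heap-allocate and never delete: the destructor would otherwise run
      // at process exit, after the CUDA runtime may already be torn down,
      // and would have to cudaFree every buffer still held by the cache.
      static CudaAllocator* allocator_ = new CudaAllocator;
      return *allocator_;
    }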
@@ -138,17 +185,19 @@ size_t set_memory_limit(size_t limit) {

 size_t get_memory_limit() {
   return cu::allocator().get_memory_limit();
 }

-// TODO: Implement buffer cache.
 size_t get_cache_memory() {
-  return 0;
+  return cu::allocator().get_cache_memory();
 }
-size_t set_cache_limit(size_t) {
-  return 0;
+
+size_t set_cache_limit(size_t limit) {
+  return cu::allocator().set_cache_limit(limit);
 }
+
+void clear_cache() {
+  cu::allocator().clear_cache();
+}
+
+// Not supported in CUDA.
 size_t set_wired_limit(size_t) {
   return 0;
 }
-void clear_cache() {}

 } // namespace mlx::core
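With these wrappers in place, the cache can be steered through the public mlx::core memory API. A hypothetical usage sketch (the header path and the 2 GiB figure are assumptions, not part of the diff; note that set_cache_limit() returns the previous limit because of the std::swap above):

    #include <cstddef>
    #include "mlx/memory.h"  // assumed location of the public declarations

    void tune_cache() {
      using namespace mlx::core;
      size_t old_limit = set_cache_limit(2UL << 30);  // cap the cache at 2 GiB
      // ... run a memory-heavy workload ...
      size_t held = get_cache_memory();  // bytes currently parked in the cache
      clear_cache();                     // hand every cached buffer back to CUDA
      set_cache_limit(old_limit);        // restore the previous limit
      (void)held;
    }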

mlx/backend/cuda/allocator.h

@@ -3,6 +3,7 @@

 #pragma once

 #include "mlx/allocator.h"
+#include "mlx/backend/common/buffer_cache.h"

 #include <mutex>
 #include <set>
@@ -38,17 +39,24 @@ class CudaAllocator : public allocator::Allocator {
   void reset_peak_memory();
   size_t get_memory_limit();
   size_t set_memory_limit(size_t limit);
+  size_t get_cache_memory() const;
+  size_t set_cache_limit(size_t limit);
+  void clear_cache();

  private:
   CudaAllocator();
   friend CudaAllocator& allocator();

+  void cuda_free(CudaBuffer* buf);
+
   std::mutex worker_mutex_;
   std::unique_ptr<Worker> worker_;
   std::set<std::thread::id> allowed_threads_;

   std::mutex mutex_;
   size_t memory_limit_;
+  size_t max_pool_size_;
+  BufferCache<CudaBuffer> buffer_cache_;
   size_t active_memory_{0};
   size_t peak_memory_{0};
 };
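Taken together, the two files give malloc()/free() a recycle fast path. A hypothetical round trip through the new allocator (assuming Buffer comes from the allocator namespace in mlx/allocator.h, per the base class above):

    #include "mlx/backend/cuda/allocator.h"

    void round_trip() {
      using namespace mlx::core;
      // Cache is empty: malloc() falls through to cudaMallocManaged.
      allocator::Buffer a = cu::allocator().malloc(1 << 20);
      // The cache is under max_pool_size_, so free() recycles the buffer
      // into buffer_cache_ instead of calling cuda_free().
      cu::allocator().free(a);
      // A same-size request is now served by reuse_from_cache() without
      // touching the CUDA runtime at all.
      allocator::Buffer b = cu::allocator().malloc(1 << 20);
      cu::allocator().free(b);
    }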