Add memory cache to CUDA backend (#2221)

* Move BufferCache out of allocator

* Add memory cache to cuda backend allocator

* Simplify BufferCache assuming buf cannot be null
Author: Cheng, 2025-05-31 04:12:54 +09:00 (committed by GitHub)
parent 6ef2f67e7f
commit db5a7c6192
5 changed files with 259 additions and 203 deletions

mlx/backend/common/buffer_cache.h (new file)

@@ -0,0 +1,157 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include <algorithm>
#include <cassert>
#include <functional>
#include <map>
namespace mlx::core {
template <typename T>
class BufferCache {
public:
BufferCache(
size_t page_size,
std::function<size_t(T*)> get_size,
std::function<void(T*)> free)
: page_size_(page_size),
get_size_(std::move(get_size)),
free_(std::move(free)) {}
~BufferCache() {
clear();
}
BufferCache(const BufferCache&) = delete;
BufferCache& operator=(const BufferCache&) = delete;
T* reuse_from_cache(size_t size) {
// Find the closest buffer in pool.
auto it = buffer_pool_.lower_bound(size);
if (it == buffer_pool_.end() ||
it->first >= std::min(2 * size, size + 2 * page_size_)) {
return nullptr;
}
// Collect from the cache.
T* buf = it->second->buf;
pool_size_ -= it->first;
// Remove from record.
remove_from_list(it->second);
buffer_pool_.erase(it);
return buf;
}
void recycle_to_cache(T* buf) {
assert(buf);
// Add to cache.
BufferHolder* bh = new BufferHolder(buf);
add_at_head(bh);
size_t size = get_size_(buf);
pool_size_ += size;
buffer_pool_.emplace(size, bh);
}
int release_cached_buffers(size_t min_bytes_to_free) {
if (min_bytes_to_free >= 0.9 * pool_size_) {
return clear();
} else {
int n_release = 0;
size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
// Release buffer.
size_t size = get_size_(tail_->buf);
total_bytes_freed += size;
free_(tail_->buf);
n_release++;
// Remove from record.
auto its = buffer_pool_.equal_range(size);
auto it = std::find_if(its.first, its.second, [this](const auto& el) {
return el.second == tail_;
});
assert(it != buffer_pool_.end());
buffer_pool_.erase(it);
remove_from_list(tail_);
}
pool_size_ -= total_bytes_freed;
return n_release;
}
}
int clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) {
free_(holder->buf);
n_release++;
delete holder;
}
buffer_pool_.clear();
pool_size_ = 0;
head_ = nullptr;
tail_ = nullptr;
return n_release;
}
size_t cache_size() const {
return pool_size_;
}
size_t page_size() const {
return page_size_;
}
private:
struct BufferHolder {
public:
explicit BufferHolder(T* buf_) : buf(buf_) {}
BufferHolder* prev{nullptr};
BufferHolder* next{nullptr};
T* buf;
};
void add_at_head(BufferHolder* to_add) {
if (!head_) {
head_ = to_add;
tail_ = to_add;
} else {
head_->prev = to_add;
to_add->next = head_;
head_ = to_add;
}
}
void remove_from_list(BufferHolder* to_remove) {
if (to_remove->prev && to_remove->next) { // if middle
to_remove->prev->next = to_remove->next;
to_remove->next->prev = to_remove->prev;
} else if (to_remove->prev && to_remove == tail_) { // if tail
tail_ = to_remove->prev;
tail_->next = nullptr;
} else if (to_remove == head_ && to_remove->next) { // if head
head_ = to_remove->next;
head_->prev = nullptr;
} else if (to_remove == head_ && to_remove == tail_) { // if only element
head_ = nullptr;
tail_ = nullptr;
}
delete to_remove;
}
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_{nullptr};
BufferHolder* tail_{nullptr};
size_t pool_size_{0};
const size_t page_size_;
std::function<size_t(T*)> get_size_;
std::function<void(T*)> free_;
};
} // namespace mlx::core
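
For context (not part of the commit), here is a minimal sketch of how a backend can drive this template. HostBuffer and its malloc/free wiring are hypothetical stand-ins for a real backend's buffer type, not MLX code:

// Illustrative only: HostBuffer is a made-up buffer type.
#include <cstdlib>
#include "mlx/backend/common/buffer_cache.h"

struct HostBuffer {
  void* data;
  size_t size;
};

int main() {
  mlx::core::BufferCache<HostBuffer> cache(
      /* page_size */ 4096,
      /* get_size */ [](HostBuffer* buf) { return buf->size; },
      /* free */ [](HostBuffer* buf) {
        std::free(buf->data);
        delete buf;
      });

  // Cache miss: the caller allocates as usual.
  HostBuffer* buf = cache.reuse_from_cache(1 << 20);
  if (!buf) {
    buf = new HostBuffer{std::malloc(1 << 20), 1 << 20};
  }

  // Instead of freeing, hand the buffer back to the cache.
  cache.recycle_to_cache(buf);

  // A later request hits if a cached buffer's size is below
  // min(2 * size, size + 2 * page_size), so a slightly smaller
  // request reuses the same allocation.
  HostBuffer* again = cache.reuse_from_cache((1 << 20) - 512);
  // again == buf here; the cache no longer owns it.

  cache.recycle_to_cache(again);
  cache.clear(); // Frees everything still cached.
  return 0;
}

The two std::function hooks are what let the same list-plus-multimap bookkeeping serve both cudaMallocManaged buffers and MTL::Buffer objects in the files below.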

mlx/backend/cuda/allocator.cpp

@@ -6,6 +6,7 @@
 #include <cuda_runtime.h>
 #include <fmt/format.h>
+#include <unistd.h>
 
 #include <cassert>
@@ -13,24 +14,47 @@ namespace mlx::core {
 namespace cu {
 
-CudaAllocator::CudaAllocator() {
+CudaAllocator::CudaAllocator()
+    : buffer_cache_(
+          getpagesize(),
+          [](CudaBuffer* buf) { return buf->size; },
+          [this](CudaBuffer* buf) { cuda_free(buf); }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
   memory_limit_ = total * 0.8;
+  max_pool_size_ = memory_limit_;
 }
 
 Buffer CudaAllocator::malloc(size_t size) {
-  // TODO: Check memory limit.
-  auto* buf = new CudaBuffer{nullptr, size};
-  cudaError_t err = cudaMallocManaged(&buf->data, size);
-  if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
-    throw std::runtime_error(
-        fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+  // Find available buffer from cache.
+  std::unique_lock lock(mutex_);
+  CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
+  if (!buf) {
+    // If we have a lot of memory pressure or are over the maximum cache size,
+    // try to reclaim memory from the cache.
+    size_t mem_required = get_active_memory() + get_cache_memory() + size;
+    if (mem_required >= memory_limit_) {
+      buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
+    }
+    lock.unlock();
+    buf = new CudaBuffer{nullptr, size};
+    cudaError_t err = cudaMallocManaged(&buf->data, size);
+    if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
+      throw std::runtime_error(fmt::format(
+          "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+    }
+    lock.lock();
   }
-  std::lock_guard lock(mutex_);
   active_memory_ += size;
   peak_memory_ = std::max(active_memory_, peak_memory_);
+  // Maintain the cache below the requested limit.
+  if (get_cache_memory() > max_pool_size_) {
+    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
+  }
   return Buffer{buf};
 }
@@ -40,26 +64,14 @@ void CudaAllocator::free(Buffer buffer) {
     return;
   }
-  // If free() is called from a unregistered thread, reschedule the call to
-  // worker.
-  {
-    std::lock_guard lock(worker_mutex_);
-    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-      if (!worker_) {
-        worker_.reset(new Worker);
-      }
-      worker_->add_task([buffer]() { allocator().free(buffer); });
-      worker_->end_batch();
-      worker_->commit();
-      return;
-    }
+  std::unique_lock lock(mutex_);
+  active_memory_ -= buf->size;
+  if (get_cache_memory() < max_pool_size_) {
+    buffer_cache_.recycle_to_cache(buf);
+  } else {
+    lock.unlock();
+    cuda_free(buf);
   }
-  size_t size = buf->size;
-  cudaFree(buf->data);
-  delete buf;
-  std::lock_guard lock(mutex_);
-  active_memory_ -= size;
 }
 
 size_t CudaAllocator::size(Buffer buffer) const {
@@ -98,6 +110,41 @@ size_t CudaAllocator::set_memory_limit(size_t limit) {
   return limit;
 }
 
+size_t CudaAllocator::get_cache_memory() const {
+  return buffer_cache_.cache_size();
+}
+
+size_t CudaAllocator::set_cache_limit(size_t limit) {
+  std::lock_guard lk(mutex_);
+  std::swap(limit, max_pool_size_);
+  return limit;
+}
+
+void CudaAllocator::clear_cache() {
+  std::lock_guard lk(mutex_);
+  buffer_cache_.clear();
+}
+
+void CudaAllocator::cuda_free(CudaBuffer* buf) {
+  // If cuda_free() is called from an unregistered thread, reschedule the call
+  // to the worker.
+  {
+    std::lock_guard lock(worker_mutex_);
+    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
+      if (!worker_) {
+        worker_.reset(new Worker);
+      }
+      worker_->add_task([this, buf]() { this->cuda_free(buf); });
+      worker_->end_batch();
+      worker_->commit();
+      return;
+    }
+  }
+  cudaFree(buf->data);
+  delete buf;
+}
+
 CudaAllocator& allocator() {
   // By creating the |allocator_| on heap, the destructor of CudaAllocator
   // will not be called on exit and buffers in the cache will be leaked. This
@@ -138,17 +185,19 @@ size_t set_memory_limit(size_t limit) {
 size_t get_memory_limit() {
   return cu::allocator().get_memory_limit();
 }
 
-// TODO: Implement buffer cache.
 size_t get_cache_memory() {
-  return 0;
+  return cu::allocator().get_cache_memory();
 }
-size_t set_cache_limit(size_t) {
-  return 0;
+size_t set_cache_limit(size_t limit) {
+  return cu::allocator().set_cache_limit(limit);
 }
+void clear_cache() {
+  cu::allocator().clear_cache();
+}
 
+// Not supported in CUDA.
 size_t set_wired_limit(size_t) {
   return 0;
 }
-void clear_cache() {}
 
 } // namespace mlx::core
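
With these hooks wired up, the generic MLX memory API now reports and controls the CUDA cache. A short usage sketch; the mlx/memory.h include path is an assumption here, while the function names come straight from the hunk above:

#include <cstdio>
#include "mlx/memory.h" // Assumed location of the memory API declarations.

void tune_cache() {
  using namespace mlx::core;

  // Cap the cache at 1 GiB; the previous limit is returned.
  size_t old_limit = set_cache_limit(1UL << 30);

  // ... run work that allocates and frees buffers ...

  std::printf(
      "cached: %zu bytes (previous limit: %zu)\n",
      get_cache_memory(),
      old_limit);

  // Drop all cached buffers, e.g. before another library needs the memory.
  clear_cache();
}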

mlx/backend/cuda/allocator.h

@@ -3,6 +3,7 @@
 #pragma once
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/buffer_cache.h"
 
 #include <mutex>
 #include <set>
@@ -38,17 +39,24 @@ class CudaAllocator : public allocator::Allocator {
   void reset_peak_memory();
   size_t get_memory_limit();
   size_t set_memory_limit(size_t limit);
+  size_t get_cache_memory() const;
+  size_t set_cache_limit(size_t limit);
+  void clear_cache();
 
  private:
   CudaAllocator();
   friend CudaAllocator& allocator();
 
+  void cuda_free(CudaBuffer* buf);
+
   std::mutex worker_mutex_;
   std::unique_ptr<Worker> worker_;
   std::set<std::thread::id> allowed_threads_;
 
   std::mutex mutex_;
   size_t memory_limit_;
+  size_t max_pool_size_;
+  BufferCache<CudaBuffer> buffer_cache_;
   size_t active_memory_{0};
   size_t peak_memory_{0};
 };

mlx/backend/metal/allocator.cpp

@@ -30,141 +30,18 @@ void* Buffer::raw_ptr() {
 namespace metal {
 
-namespace {
-
-BufferCache::BufferCache(ResidencySet& residency_set)
-    : head_(nullptr),
-      tail_(nullptr),
-      pool_size_(0),
-      residency_set_(residency_set) {}
-
-BufferCache::~BufferCache() {
-  auto pool = metal::new_scoped_memory_pool();
-  clear();
-}
-
-int BufferCache::clear() {
-  int n_release = 0;
-  for (auto& [size, holder] : buffer_pool_) {
-    if (holder->buf) {
-      if (!holder->buf->heap()) {
-        residency_set_.erase(holder->buf);
-      }
-      holder->buf->release();
-      n_release++;
-    }
-    delete holder;
-  }
-  buffer_pool_.clear();
-  pool_size_ = 0;
-  head_ = nullptr;
-  tail_ = nullptr;
-  return n_release;
-}
-
-MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
-  // Find the closest buffer in pool
-  MTL::Buffer* pbuf = nullptr;
-  auto it = buffer_pool_.lower_bound(size);
-
-  // Make sure we use most of the available memory
-  while (!pbuf && it != buffer_pool_.end() &&
-         it->first < std::min(2 * size, size + 2 * vm_page_size)) {
-    // Collect from the cache
-    pbuf = it->second->buf;
-
-    // Remove from cache
-    remove_from_list(it->second);
-    delete it->second;
-    it = buffer_pool_.erase(it);
-  }
-
-  if (pbuf) {
-    pool_size_ -= pbuf->length();
-  }
-  return pbuf;
-}
-
-void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
-  // Add to cache
-  if (buf) {
-    BufferHolder* bh = new BufferHolder(buf);
-    add_at_head(bh);
-    pool_size_ += buf->length();
-    buffer_pool_.insert({buf->length(), bh});
-  }
-}
-
-int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
-  if (min_bytes_to_free >= 0.9 * pool_size_) {
-    return clear();
-  } else {
-    int n_release = 0;
-    size_t total_bytes_freed = 0;
-
-    while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
-      if (tail_->buf) {
-        total_bytes_freed += tail_->buf->length();
-        if (!tail_->buf->heap()) {
-          residency_set_.erase(tail_->buf);
-        }
-        tail_->buf->release();
-        tail_->buf = nullptr;
-        n_release++;
-      }
-      remove_from_list(tail_);
-    }
-    pool_size_ -= total_bytes_freed;
-    return n_release;
-  }
-}
-
-void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
-  if (!to_add)
-    return;
-
-  if (!head_) {
-    head_ = to_add;
-    tail_ = to_add;
-  } else {
-    head_->prev = to_add;
-    to_add->next = head_;
-    head_ = to_add;
-  }
-}
-
-void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
-  if (!to_remove) {
-    return;
-  }
-
-  // If in the middle
-  if (to_remove->prev && to_remove->next) {
-    to_remove->prev->next = to_remove->next;
-    to_remove->next->prev = to_remove->prev;
-  } else if (to_remove->prev && to_remove == tail_) { // If tail
-    tail_ = to_remove->prev;
-    tail_->next = nullptr;
-  } else if (to_remove == head_ && to_remove->next) { // If head
-    head_ = to_remove->next;
-    head_->prev = nullptr;
-  } else if (to_remove == head_ && to_remove == tail_) { // If only element
-    head_ = nullptr;
-    tail_ = nullptr;
-  }
-
-  to_remove->prev = nullptr;
-  to_remove->next = nullptr;
-}
-
-} // namespace
-
 MetalAllocator::MetalAllocator()
     : device_(device(mlx::core::Device::gpu).mtl_device()),
       residency_set_(device_),
-      buffer_cache_(residency_set_) {
+      buffer_cache_(
+          vm_page_size,
+          [](MTL::Buffer* buf) { return buf->length(); },
+          [this](MTL::Buffer* buf) {
+            if (!buf->heap()) {
+              residency_set_.erase(buf);
+            }
+            buf->release();
+          }) {
   auto pool = metal::new_scoped_memory_pool();
   auto memsize = std::get<size_t>(device_info().at("memory_size"));
   auto max_rec_size =
@@ -193,6 +70,7 @@ MetalAllocator::~MetalAllocator() {
   if (heap_) {
     heap_->release();
   }
+  buffer_cache_.clear();
 }
 
 size_t MetalAllocator::set_cache_limit(size_t limit) {

mlx/backend/metal/allocator.h

@@ -7,6 +7,7 @@
 #include <vector>
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/buffer_cache.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/resident.h"
@@ -14,43 +15,6 @@ namespace mlx::core::metal {
 using allocator::Buffer;
 
-namespace {
-
-class BufferCache {
- public:
-  BufferCache(ResidencySet& residency_set);
-  ~BufferCache();
-
-  MTL::Buffer* reuse_from_cache(size_t size);
-  void recycle_to_cache(MTL::Buffer* buf);
-  int release_cached_buffers(size_t min_bytes_to_free);
-  size_t cache_size() {
-    return pool_size_;
-  }
-  int clear();
-
- private:
-  struct BufferHolder {
-   public:
-    BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
-
-    BufferHolder* prev;
-    BufferHolder* next;
-    MTL::Buffer* buf;
-  };
-
-  void add_at_head(BufferHolder* to_add);
-  void remove_from_list(BufferHolder* to_remove);
-
-  std::multimap<size_t, BufferHolder*> buffer_pool_;
-  BufferHolder* head_;
-  BufferHolder* tail_;
-  size_t pool_size_;
-
-  ResidencySet& residency_set_;
-};
-
-} // namespace
-
 class MetalAllocator : public allocator::Allocator {
   /** Allocator for Metal GPUs. */
  public:
@@ -90,7 +54,7 @@ class MetalAllocator : public allocator::Allocator {
   friend MetalAllocator& allocator();
 
   // Caching allocator
-  BufferCache buffer_cache_;
+  BufferCache<MTL::Buffer> buffer_cache_;
   ResidencySet residency_set_;
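
Because the cache logic now lives in a backend-agnostic template, its eviction policy can be exercised without any GPU. A self-contained sketch (FakeBuf is a made-up type for illustration) of the release_cached_buffers behavior defined in the new header: buffers are freed oldest-first until the requested byte count is met, and a request for 90% or more of the pool falls back to clear():

#include <cassert>
#include "mlx/backend/common/buffer_cache.h"

struct FakeBuf {
  size_t size;
};

int main() {
  int freed = 0;
  mlx::core::BufferCache<FakeBuf> cache(
      4096,
      [](FakeBuf* b) { return b->size; },
      [&freed](FakeBuf* b) {
        freed++;
        delete b;
      });

  // Recycle three buffers; the first one recycled sits at the list tail.
  for (size_t sz : {4096u, 8192u, 16384u}) {
    cache.recycle_to_cache(new FakeBuf{sz});
  }
  assert(cache.cache_size() == 4096 + 8192 + 16384);

  // A small request frees only the oldest entry (the 4096-byte one).
  cache.release_cached_buffers(4096);
  assert(freed == 1);

  // Requesting >= 90% of the remaining pool clears it entirely.
  cache.release_cached_buffers(cache.cache_size());
  assert(cache.cache_size() == 0 && freed == 3);
  return 0;
}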