mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Add memory cache to CUDA backend (#2221)
* Move BufferCache out of allocator * Add memory cache to cuda backend allocator * Simplify BufferCache assuming buf can not be null
This commit is contained in:
parent
6ef2f67e7f
commit
db5a7c6192
157
mlx/backend/common/buffer_cache.h
Normal file
157
mlx/backend/common/buffer_cache.h
Normal file
@ -0,0 +1,157 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
class BufferCache {
|
||||
public:
|
||||
BufferCache(
|
||||
size_t page_size,
|
||||
std::function<size_t(T*)> get_size,
|
||||
std::function<void(T*)> free)
|
||||
: page_size_(page_size),
|
||||
get_size_(std::move(get_size)),
|
||||
free_(std::move(free)) {}
|
||||
|
||||
~BufferCache() {
|
||||
clear();
|
||||
}
|
||||
|
||||
BufferCache(const BufferCache&) = delete;
|
||||
BufferCache& operator=(const BufferCache&) = delete;
|
||||
|
||||
T* reuse_from_cache(size_t size) {
|
||||
// Find the closest buffer in pool.
|
||||
auto it = buffer_pool_.lower_bound(size);
|
||||
if (it == buffer_pool_.end() ||
|
||||
it->first >= std::min(2 * size, size + 2 * page_size_)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Collect from the cache.
|
||||
T* buf = it->second->buf;
|
||||
pool_size_ -= it->first;
|
||||
|
||||
// Remove from record.
|
||||
remove_from_list(it->second);
|
||||
buffer_pool_.erase(it);
|
||||
return buf;
|
||||
}
|
||||
|
||||
void recycle_to_cache(T* buf) {
|
||||
assert(buf);
|
||||
// Add to cache.
|
||||
BufferHolder* bh = new BufferHolder(buf);
|
||||
add_at_head(bh);
|
||||
size_t size = get_size_(buf);
|
||||
pool_size_ += size;
|
||||
buffer_pool_.emplace(size, bh);
|
||||
}
|
||||
|
||||
int release_cached_buffers(size_t min_bytes_to_free) {
|
||||
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||
return clear();
|
||||
} else {
|
||||
int n_release = 0;
|
||||
size_t total_bytes_freed = 0;
|
||||
|
||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||
// Release buffer.
|
||||
size_t size = get_size_(tail_->buf);
|
||||
total_bytes_freed += size;
|
||||
free_(tail_->buf);
|
||||
n_release++;
|
||||
|
||||
// Remove from record.
|
||||
auto its = buffer_pool_.equal_range(size);
|
||||
auto it = std::find_if(its.first, its.second, [this](const auto& el) {
|
||||
return el.second == tail_;
|
||||
});
|
||||
assert(it != buffer_pool_.end());
|
||||
buffer_pool_.erase(it);
|
||||
remove_from_list(tail_);
|
||||
}
|
||||
|
||||
pool_size_ -= total_bytes_freed;
|
||||
return n_release;
|
||||
}
|
||||
}
|
||||
|
||||
int clear() {
|
||||
int n_release = 0;
|
||||
for (auto& [size, holder] : buffer_pool_) {
|
||||
free_(holder->buf);
|
||||
n_release++;
|
||||
delete holder;
|
||||
}
|
||||
buffer_pool_.clear();
|
||||
pool_size_ = 0;
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
return n_release;
|
||||
}
|
||||
|
||||
size_t cache_size() const {
|
||||
return pool_size_;
|
||||
}
|
||||
|
||||
size_t page_size() const {
|
||||
return page_size_;
|
||||
}
|
||||
|
||||
private:
|
||||
struct BufferHolder {
|
||||
public:
|
||||
explicit BufferHolder(T* buf_) : buf(buf_) {}
|
||||
|
||||
BufferHolder* prev{nullptr};
|
||||
BufferHolder* next{nullptr};
|
||||
T* buf;
|
||||
};
|
||||
|
||||
void add_at_head(BufferHolder* to_add) {
|
||||
if (!head_) {
|
||||
head_ = to_add;
|
||||
tail_ = to_add;
|
||||
} else {
|
||||
head_->prev = to_add;
|
||||
to_add->next = head_;
|
||||
head_ = to_add;
|
||||
}
|
||||
}
|
||||
|
||||
void remove_from_list(BufferHolder* to_remove) {
|
||||
if (to_remove->prev && to_remove->next) { // if middle
|
||||
to_remove->prev->next = to_remove->next;
|
||||
to_remove->next->prev = to_remove->prev;
|
||||
} else if (to_remove->prev && to_remove == tail_) { // if tail
|
||||
tail_ = to_remove->prev;
|
||||
tail_->next = nullptr;
|
||||
} else if (to_remove == head_ && to_remove->next) { // if head
|
||||
head_ = to_remove->next;
|
||||
head_->prev = nullptr;
|
||||
} else if (to_remove == head_ && to_remove == tail_) { // if only element
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
}
|
||||
|
||||
delete to_remove;
|
||||
}
|
||||
|
||||
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||
BufferHolder* head_{nullptr};
|
||||
BufferHolder* tail_{nullptr};
|
||||
size_t pool_size_{0};
|
||||
|
||||
const size_t page_size_;
|
||||
std::function<size_t(T*)> get_size_;
|
||||
std::function<void(T*)> free_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <fmt/format.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
@ -13,24 +14,47 @@ namespace mlx::core {
|
||||
|
||||
namespace cu {
|
||||
|
||||
CudaAllocator::CudaAllocator() {
|
||||
CudaAllocator::CudaAllocator()
|
||||
: buffer_cache_(
|
||||
getpagesize(),
|
||||
[](CudaBuffer* buf) { return buf->size; },
|
||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
||||
// TODO: Set memory limit for multi-device.
|
||||
size_t free, total;
|
||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||
memory_limit_ = total * 0.8;
|
||||
max_pool_size_ = memory_limit_;
|
||||
}
|
||||
|
||||
Buffer CudaAllocator::malloc(size_t size) {
|
||||
// TODO: Check memory limit.
|
||||
auto* buf = new CudaBuffer{nullptr, size};
|
||||
// Find available buffer from cache.
|
||||
std::unique_lock lock(mutex_);
|
||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||
if (!buf) {
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
// try to reclaim memory from the cache.
|
||||
size_t mem_required = get_active_memory() + get_cache_memory() + size;
|
||||
if (mem_required >= memory_limit_) {
|
||||
buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
|
||||
}
|
||||
|
||||
lock.unlock();
|
||||
buf = new CudaBuffer{nullptr, size};
|
||||
cudaError_t err = cudaMallocManaged(&buf->data, size);
|
||||
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||
throw std::runtime_error(fmt::format(
|
||||
"cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
|
||||
}
|
||||
lock.lock();
|
||||
}
|
||||
std::lock_guard lock(mutex_);
|
||||
active_memory_ += size;
|
||||
peak_memory_ = std::max(active_memory_, peak_memory_);
|
||||
|
||||
// Maintain the cache below the requested limit.
|
||||
if (get_cache_memory() > max_pool_size_) {
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
|
||||
return Buffer{buf};
|
||||
}
|
||||
|
||||
@ -40,26 +64,14 @@ void CudaAllocator::free(Buffer buffer) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If free() is called from a unregistered thread, reschedule the call to
|
||||
// worker.
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
||||
if (!worker_) {
|
||||
worker_.reset(new Worker);
|
||||
std::unique_lock lock(mutex_);
|
||||
active_memory_ -= buf->size;
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
lock.unlock();
|
||||
cuda_free(buf);
|
||||
}
|
||||
worker_->add_task([buffer]() { allocator().free(buffer); });
|
||||
worker_->end_batch();
|
||||
worker_->commit();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size = buf->size;
|
||||
cudaFree(buf->data);
|
||||
delete buf;
|
||||
std::lock_guard lock(mutex_);
|
||||
active_memory_ -= size;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::size(Buffer buffer) const {
|
||||
@ -98,6 +110,41 @@ size_t CudaAllocator::set_memory_limit(size_t limit) {
|
||||
return limit;
|
||||
}
|
||||
|
||||
size_t CudaAllocator::get_cache_memory() const {
|
||||
return buffer_cache_.cache_size();
|
||||
}
|
||||
|
||||
size_t CudaAllocator::set_cache_limit(size_t limit) {
|
||||
std::lock_guard lk(mutex_);
|
||||
std::swap(limit, max_pool_size_);
|
||||
return limit;
|
||||
}
|
||||
|
||||
void CudaAllocator::clear_cache() {
|
||||
std::lock_guard lk(mutex_);
|
||||
buffer_cache_.clear();
|
||||
}
|
||||
|
||||
void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
||||
// If cuda_free() is called from a unregistered thread, reschedule the call to
|
||||
// worker.
|
||||
{
|
||||
std::lock_guard lock(worker_mutex_);
|
||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
||||
if (!worker_) {
|
||||
worker_.reset(new Worker);
|
||||
}
|
||||
worker_->add_task([this, buf]() { this->cuda_free(buf); });
|
||||
worker_->end_batch();
|
||||
worker_->commit();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
cudaFree(buf->data);
|
||||
delete buf;
|
||||
}
|
||||
|
||||
CudaAllocator& allocator() {
|
||||
// By creating the |allocator_| on heap, the destructor of CudaAllocator
|
||||
// will not be called on exit and buffers in the cache will be leaked. This
|
||||
@ -138,17 +185,19 @@ size_t set_memory_limit(size_t limit) {
|
||||
size_t get_memory_limit() {
|
||||
return cu::allocator().get_memory_limit();
|
||||
}
|
||||
|
||||
// TODO: Implement buffer cache.
|
||||
size_t get_cache_memory() {
|
||||
return 0;
|
||||
return cu::allocator().get_cache_memory();
|
||||
}
|
||||
size_t set_cache_limit(size_t) {
|
||||
return 0;
|
||||
size_t set_cache_limit(size_t limit) {
|
||||
return cu::allocator().set_cache_limit(limit);
|
||||
}
|
||||
void clear_cache() {
|
||||
cu::allocator().clear_cache();
|
||||
}
|
||||
|
||||
// Not supported in CUDA.
|
||||
size_t set_wired_limit(size_t) {
|
||||
return 0;
|
||||
}
|
||||
void clear_cache() {}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/buffer_cache.h"
|
||||
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
@ -38,17 +39,24 @@ class CudaAllocator : public allocator::Allocator {
|
||||
void reset_peak_memory();
|
||||
size_t get_memory_limit();
|
||||
size_t set_memory_limit(size_t limit);
|
||||
size_t get_cache_memory() const;
|
||||
size_t set_cache_limit(size_t limit);
|
||||
void clear_cache();
|
||||
|
||||
private:
|
||||
CudaAllocator();
|
||||
friend CudaAllocator& allocator();
|
||||
|
||||
void cuda_free(CudaBuffer* buf);
|
||||
|
||||
std::mutex worker_mutex_;
|
||||
std::unique_ptr<Worker> worker_;
|
||||
std::set<std::thread::id> allowed_threads_;
|
||||
|
||||
std::mutex mutex_;
|
||||
size_t memory_limit_;
|
||||
size_t max_pool_size_;
|
||||
BufferCache<CudaBuffer> buffer_cache_;
|
||||
size_t active_memory_{0};
|
||||
size_t peak_memory_{0};
|
||||
};
|
||||
|
@ -30,141 +30,18 @@ void* Buffer::raw_ptr() {
|
||||
|
||||
namespace metal {
|
||||
|
||||
namespace {
|
||||
|
||||
BufferCache::BufferCache(ResidencySet& residency_set)
|
||||
: head_(nullptr),
|
||||
tail_(nullptr),
|
||||
pool_size_(0),
|
||||
residency_set_(residency_set) {}
|
||||
|
||||
BufferCache::~BufferCache() {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
clear();
|
||||
}
|
||||
|
||||
int BufferCache::clear() {
|
||||
int n_release = 0;
|
||||
for (auto& [size, holder] : buffer_pool_) {
|
||||
if (holder->buf) {
|
||||
if (!holder->buf->heap()) {
|
||||
residency_set_.erase(holder->buf);
|
||||
}
|
||||
holder->buf->release();
|
||||
n_release++;
|
||||
}
|
||||
delete holder;
|
||||
}
|
||||
buffer_pool_.clear();
|
||||
pool_size_ = 0;
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
return n_release;
|
||||
}
|
||||
|
||||
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
|
||||
// Find the closest buffer in pool
|
||||
MTL::Buffer* pbuf = nullptr;
|
||||
|
||||
auto it = buffer_pool_.lower_bound(size);
|
||||
|
||||
// Make sure we use most of the available memory
|
||||
while (!pbuf && it != buffer_pool_.end() &&
|
||||
it->first < std::min(2 * size, size + 2 * vm_page_size)) {
|
||||
// Collect from the cache
|
||||
pbuf = it->second->buf;
|
||||
|
||||
// Remove from cache
|
||||
remove_from_list(it->second);
|
||||
delete it->second;
|
||||
it = buffer_pool_.erase(it);
|
||||
}
|
||||
|
||||
if (pbuf) {
|
||||
pool_size_ -= pbuf->length();
|
||||
}
|
||||
|
||||
return pbuf;
|
||||
}
|
||||
|
||||
void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
|
||||
// Add to cache
|
||||
if (buf) {
|
||||
BufferHolder* bh = new BufferHolder(buf);
|
||||
add_at_head(bh);
|
||||
pool_size_ += buf->length();
|
||||
buffer_pool_.insert({buf->length(), bh});
|
||||
}
|
||||
}
|
||||
|
||||
int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||
return clear();
|
||||
} else {
|
||||
int n_release = 0;
|
||||
size_t total_bytes_freed = 0;
|
||||
|
||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||
if (tail_->buf) {
|
||||
total_bytes_freed += tail_->buf->length();
|
||||
if (!tail_->buf->heap()) {
|
||||
residency_set_.erase(tail_->buf);
|
||||
}
|
||||
tail_->buf->release();
|
||||
tail_->buf = nullptr;
|
||||
n_release++;
|
||||
}
|
||||
remove_from_list(tail_);
|
||||
}
|
||||
pool_size_ -= total_bytes_freed;
|
||||
return n_release;
|
||||
}
|
||||
}
|
||||
|
||||
void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
|
||||
if (!to_add)
|
||||
return;
|
||||
|
||||
if (!head_) {
|
||||
head_ = to_add;
|
||||
tail_ = to_add;
|
||||
} else {
|
||||
head_->prev = to_add;
|
||||
to_add->next = head_;
|
||||
head_ = to_add;
|
||||
}
|
||||
}
|
||||
|
||||
void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
|
||||
if (!to_remove) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If in the middle
|
||||
if (to_remove->prev && to_remove->next) {
|
||||
to_remove->prev->next = to_remove->next;
|
||||
to_remove->next->prev = to_remove->prev;
|
||||
} else if (to_remove->prev && to_remove == tail_) { // If tail
|
||||
tail_ = to_remove->prev;
|
||||
tail_->next = nullptr;
|
||||
} else if (to_remove == head_ && to_remove->next) { // If head
|
||||
head_ = to_remove->next;
|
||||
head_->prev = nullptr;
|
||||
} else if (to_remove == head_ && to_remove == tail_) { // If only element
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
}
|
||||
|
||||
to_remove->prev = nullptr;
|
||||
to_remove->next = nullptr;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
MetalAllocator::MetalAllocator()
|
||||
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
||||
residency_set_(device_),
|
||||
buffer_cache_(residency_set_) {
|
||||
buffer_cache_(
|
||||
vm_page_size,
|
||||
[](MTL::Buffer* buf) { return buf->length(); },
|
||||
[this](MTL::Buffer* buf) {
|
||||
if (!buf->heap()) {
|
||||
residency_set_.erase(buf);
|
||||
}
|
||||
buf->release();
|
||||
}) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto memsize = std::get<size_t>(device_info().at("memory_size"));
|
||||
auto max_rec_size =
|
||||
@ -193,6 +70,7 @@ MetalAllocator::~MetalAllocator() {
|
||||
if (heap_) {
|
||||
heap_->release();
|
||||
}
|
||||
buffer_cache_.clear();
|
||||
}
|
||||
|
||||
size_t MetalAllocator::set_cache_limit(size_t limit) {
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/buffer_cache.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/resident.h"
|
||||
|
||||
@ -14,43 +15,6 @@ namespace mlx::core::metal {
|
||||
|
||||
using allocator::Buffer;
|
||||
|
||||
namespace {
|
||||
|
||||
class BufferCache {
|
||||
public:
|
||||
BufferCache(ResidencySet& residency_set);
|
||||
~BufferCache();
|
||||
|
||||
MTL::Buffer* reuse_from_cache(size_t size);
|
||||
void recycle_to_cache(MTL::Buffer* buf);
|
||||
int release_cached_buffers(size_t min_bytes_to_free);
|
||||
size_t cache_size() {
|
||||
return pool_size_;
|
||||
}
|
||||
int clear();
|
||||
|
||||
private:
|
||||
struct BufferHolder {
|
||||
public:
|
||||
BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
|
||||
|
||||
BufferHolder* prev;
|
||||
BufferHolder* next;
|
||||
MTL::Buffer* buf;
|
||||
};
|
||||
|
||||
void add_at_head(BufferHolder* to_add);
|
||||
void remove_from_list(BufferHolder* to_remove);
|
||||
|
||||
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||
BufferHolder* head_;
|
||||
BufferHolder* tail_;
|
||||
size_t pool_size_;
|
||||
ResidencySet& residency_set_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
class MetalAllocator : public allocator::Allocator {
|
||||
/** Allocator for Metal GPUs. */
|
||||
public:
|
||||
@ -90,7 +54,7 @@ class MetalAllocator : public allocator::Allocator {
|
||||
friend MetalAllocator& allocator();
|
||||
|
||||
// Caching allocator
|
||||
BufferCache buffer_cache_;
|
||||
BufferCache<MTL::Buffer> buffer_cache_;
|
||||
|
||||
ResidencySet residency_set_;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user