mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Move BufferCache out of allocator
This commit is contained in:
163
mlx/backend/common/buffer_cache.h
Normal file
163
mlx/backend/common/buffer_cache.h
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class BufferCache {
|
||||||
|
public:
|
||||||
|
BufferCache(
|
||||||
|
size_t page_size,
|
||||||
|
std::function<size_t(T*)> get_size,
|
||||||
|
std::function<void(T*)> free)
|
||||||
|
: page_size_(page_size),
|
||||||
|
get_size_(std::move(get_size)),
|
||||||
|
free_(std::move(free)) {}
|
||||||
|
|
||||||
|
~BufferCache() {
|
||||||
|
clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
BufferCache(const BufferCache&) = delete;
|
||||||
|
BufferCache& operator=(const BufferCache&) = delete;
|
||||||
|
|
||||||
|
T* reuse_from_cache(size_t size) {
|
||||||
|
// Find the closest buffer in pool.
|
||||||
|
T* pbuf = nullptr;
|
||||||
|
|
||||||
|
auto it = buffer_pool_.lower_bound(size);
|
||||||
|
|
||||||
|
// Make sure we use most of the available memory.
|
||||||
|
while (!pbuf && it != buffer_pool_.end() &&
|
||||||
|
it->first < std::min(2 * size, size + 2 * page_size_)) {
|
||||||
|
// Collect from the cache.
|
||||||
|
pbuf = it->second->buf;
|
||||||
|
|
||||||
|
// Remove from cache.
|
||||||
|
remove_from_list(it->second);
|
||||||
|
it = buffer_pool_.erase(it);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pbuf) {
|
||||||
|
pool_size_ -= get_size_(pbuf);
|
||||||
|
}
|
||||||
|
|
||||||
|
return pbuf;
|
||||||
|
}
|
||||||
|
|
||||||
|
void recycle_to_cache(T* buf) {
|
||||||
|
// Add to cache.
|
||||||
|
if (buf) {
|
||||||
|
BufferHolder* bh = new BufferHolder(buf);
|
||||||
|
add_at_head(bh);
|
||||||
|
size_t size = get_size_(buf);
|
||||||
|
pool_size_ += size;
|
||||||
|
buffer_pool_.insert({size, bh});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int release_cached_buffers(size_t min_bytes_to_free) {
|
||||||
|
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||||
|
return clear();
|
||||||
|
} else {
|
||||||
|
int n_release = 0;
|
||||||
|
size_t total_bytes_freed = 0;
|
||||||
|
|
||||||
|
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||||
|
if (tail_->buf) {
|
||||||
|
total_bytes_freed += get_size_(tail_->buf);
|
||||||
|
free_(tail_->buf);
|
||||||
|
tail_->buf = nullptr;
|
||||||
|
n_release++;
|
||||||
|
}
|
||||||
|
remove_from_list(tail_);
|
||||||
|
for (auto it = buffer_pool_.begin(); it != buffer_pool_.end(); ++it) {
|
||||||
|
if (it->second == tail_) {
|
||||||
|
buffer_pool_.erase(it);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pool_size_ -= total_bytes_freed;
|
||||||
|
return n_release;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int clear() {
|
||||||
|
int n_release = 0;
|
||||||
|
for (auto& [size, holder] : buffer_pool_) {
|
||||||
|
if (holder->buf) {
|
||||||
|
free_(holder->buf);
|
||||||
|
n_release++;
|
||||||
|
}
|
||||||
|
delete holder;
|
||||||
|
}
|
||||||
|
buffer_pool_.clear();
|
||||||
|
pool_size_ = 0;
|
||||||
|
head_ = nullptr;
|
||||||
|
tail_ = nullptr;
|
||||||
|
return n_release;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t cache_size() const {
|
||||||
|
return pool_size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t page_size() const {
|
||||||
|
return page_size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct BufferHolder {
|
||||||
|
public:
|
||||||
|
explicit BufferHolder(T* buf_) : buf(buf_) {}
|
||||||
|
|
||||||
|
BufferHolder* prev{nullptr};
|
||||||
|
BufferHolder* next{nullptr};
|
||||||
|
T* buf;
|
||||||
|
};
|
||||||
|
|
||||||
|
void add_at_head(BufferHolder* to_add) {
|
||||||
|
if (!head_) {
|
||||||
|
head_ = to_add;
|
||||||
|
tail_ = to_add;
|
||||||
|
} else {
|
||||||
|
head_->prev = to_add;
|
||||||
|
to_add->next = head_;
|
||||||
|
head_ = to_add;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void remove_from_list(BufferHolder* to_remove) {
|
||||||
|
if (to_remove->prev && to_remove->next) { // if middle
|
||||||
|
to_remove->prev->next = to_remove->next;
|
||||||
|
to_remove->next->prev = to_remove->prev;
|
||||||
|
} else if (to_remove->prev && to_remove == tail_) { // if tail
|
||||||
|
tail_ = to_remove->prev;
|
||||||
|
tail_->next = nullptr;
|
||||||
|
} else if (to_remove == head_ && to_remove->next) { // if head
|
||||||
|
head_ = to_remove->next;
|
||||||
|
head_->prev = nullptr;
|
||||||
|
} else if (to_remove == head_ && to_remove == tail_) { // if only element
|
||||||
|
head_ = nullptr;
|
||||||
|
tail_ = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
delete to_remove;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||||
|
BufferHolder* head_{nullptr};
|
||||||
|
BufferHolder* tail_{nullptr};
|
||||||
|
size_t pool_size_{0};
|
||||||
|
|
||||||
|
const size_t page_size_;
|
||||||
|
std::function<size_t(T*)> get_size_;
|
||||||
|
std::function<void(T*)> free_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -30,141 +30,18 @@ void* Buffer::raw_ptr() {
|
|||||||
|
|
||||||
namespace metal {
|
namespace metal {
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
BufferCache::BufferCache(ResidencySet& residency_set)
|
|
||||||
: head_(nullptr),
|
|
||||||
tail_(nullptr),
|
|
||||||
pool_size_(0),
|
|
||||||
residency_set_(residency_set) {}
|
|
||||||
|
|
||||||
BufferCache::~BufferCache() {
|
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
|
||||||
clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
int BufferCache::clear() {
|
|
||||||
int n_release = 0;
|
|
||||||
for (auto& [size, holder] : buffer_pool_) {
|
|
||||||
if (holder->buf) {
|
|
||||||
if (!holder->buf->heap()) {
|
|
||||||
residency_set_.erase(holder->buf);
|
|
||||||
}
|
|
||||||
holder->buf->release();
|
|
||||||
n_release++;
|
|
||||||
}
|
|
||||||
delete holder;
|
|
||||||
}
|
|
||||||
buffer_pool_.clear();
|
|
||||||
pool_size_ = 0;
|
|
||||||
head_ = nullptr;
|
|
||||||
tail_ = nullptr;
|
|
||||||
return n_release;
|
|
||||||
}
|
|
||||||
|
|
||||||
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
|
|
||||||
// Find the closest buffer in pool
|
|
||||||
MTL::Buffer* pbuf = nullptr;
|
|
||||||
|
|
||||||
auto it = buffer_pool_.lower_bound(size);
|
|
||||||
|
|
||||||
// Make sure we use most of the available memory
|
|
||||||
while (!pbuf && it != buffer_pool_.end() &&
|
|
||||||
it->first < std::min(2 * size, size + 2 * vm_page_size)) {
|
|
||||||
// Collect from the cache
|
|
||||||
pbuf = it->second->buf;
|
|
||||||
|
|
||||||
// Remove from cache
|
|
||||||
remove_from_list(it->second);
|
|
||||||
delete it->second;
|
|
||||||
it = buffer_pool_.erase(it);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (pbuf) {
|
|
||||||
pool_size_ -= pbuf->length();
|
|
||||||
}
|
|
||||||
|
|
||||||
return pbuf;
|
|
||||||
}
|
|
||||||
|
|
||||||
void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
|
|
||||||
// Add to cache
|
|
||||||
if (buf) {
|
|
||||||
BufferHolder* bh = new BufferHolder(buf);
|
|
||||||
add_at_head(bh);
|
|
||||||
pool_size_ += buf->length();
|
|
||||||
buffer_pool_.insert({buf->length(), bh});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
|
||||||
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
|
||||||
return clear();
|
|
||||||
} else {
|
|
||||||
int n_release = 0;
|
|
||||||
size_t total_bytes_freed = 0;
|
|
||||||
|
|
||||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
|
||||||
if (tail_->buf) {
|
|
||||||
total_bytes_freed += tail_->buf->length();
|
|
||||||
if (!tail_->buf->heap()) {
|
|
||||||
residency_set_.erase(tail_->buf);
|
|
||||||
}
|
|
||||||
tail_->buf->release();
|
|
||||||
tail_->buf = nullptr;
|
|
||||||
n_release++;
|
|
||||||
}
|
|
||||||
remove_from_list(tail_);
|
|
||||||
}
|
|
||||||
pool_size_ -= total_bytes_freed;
|
|
||||||
return n_release;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
|
|
||||||
if (!to_add)
|
|
||||||
return;
|
|
||||||
|
|
||||||
if (!head_) {
|
|
||||||
head_ = to_add;
|
|
||||||
tail_ = to_add;
|
|
||||||
} else {
|
|
||||||
head_->prev = to_add;
|
|
||||||
to_add->next = head_;
|
|
||||||
head_ = to_add;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
|
|
||||||
if (!to_remove) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If in the middle
|
|
||||||
if (to_remove->prev && to_remove->next) {
|
|
||||||
to_remove->prev->next = to_remove->next;
|
|
||||||
to_remove->next->prev = to_remove->prev;
|
|
||||||
} else if (to_remove->prev && to_remove == tail_) { // If tail
|
|
||||||
tail_ = to_remove->prev;
|
|
||||||
tail_->next = nullptr;
|
|
||||||
} else if (to_remove == head_ && to_remove->next) { // If head
|
|
||||||
head_ = to_remove->next;
|
|
||||||
head_->prev = nullptr;
|
|
||||||
} else if (to_remove == head_ && to_remove == tail_) { // If only element
|
|
||||||
head_ = nullptr;
|
|
||||||
tail_ = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
to_remove->prev = nullptr;
|
|
||||||
to_remove->next = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
MetalAllocator::MetalAllocator()
|
MetalAllocator::MetalAllocator()
|
||||||
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
||||||
residency_set_(device_),
|
residency_set_(device_),
|
||||||
buffer_cache_(residency_set_) {
|
buffer_cache_(
|
||||||
|
vm_page_size,
|
||||||
|
[](MTL::Buffer* buf) { return buf->length(); },
|
||||||
|
[this](MTL::Buffer* buf) {
|
||||||
|
if (!buf->heap()) {
|
||||||
|
residency_set_.erase(buf);
|
||||||
|
}
|
||||||
|
buf->release();
|
||||||
|
}) {
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
auto memsize = std::get<size_t>(device_info().at("memory_size"));
|
auto memsize = std::get<size_t>(device_info().at("memory_size"));
|
||||||
auto max_rec_size =
|
auto max_rec_size =
|
||||||
@@ -193,6 +70,7 @@ MetalAllocator::~MetalAllocator() {
|
|||||||
if (heap_) {
|
if (heap_) {
|
||||||
heap_->release();
|
heap_->release();
|
||||||
}
|
}
|
||||||
|
buffer_cache_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t MetalAllocator::set_cache_limit(size_t limit) {
|
size_t MetalAllocator::set_cache_limit(size_t limit) {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/backend/common/buffer_cache.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/resident.h"
|
#include "mlx/backend/metal/resident.h"
|
||||||
|
|
||||||
@@ -14,43 +15,6 @@ namespace mlx::core::metal {
|
|||||||
|
|
||||||
using allocator::Buffer;
|
using allocator::Buffer;
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
class BufferCache {
|
|
||||||
public:
|
|
||||||
BufferCache(ResidencySet& residency_set);
|
|
||||||
~BufferCache();
|
|
||||||
|
|
||||||
MTL::Buffer* reuse_from_cache(size_t size);
|
|
||||||
void recycle_to_cache(MTL::Buffer* buf);
|
|
||||||
int release_cached_buffers(size_t min_bytes_to_free);
|
|
||||||
size_t cache_size() {
|
|
||||||
return pool_size_;
|
|
||||||
}
|
|
||||||
int clear();
|
|
||||||
|
|
||||||
private:
|
|
||||||
struct BufferHolder {
|
|
||||||
public:
|
|
||||||
BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
|
|
||||||
|
|
||||||
BufferHolder* prev;
|
|
||||||
BufferHolder* next;
|
|
||||||
MTL::Buffer* buf;
|
|
||||||
};
|
|
||||||
|
|
||||||
void add_at_head(BufferHolder* to_add);
|
|
||||||
void remove_from_list(BufferHolder* to_remove);
|
|
||||||
|
|
||||||
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
|
||||||
BufferHolder* head_;
|
|
||||||
BufferHolder* tail_;
|
|
||||||
size_t pool_size_;
|
|
||||||
ResidencySet& residency_set_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
class MetalAllocator : public allocator::Allocator {
|
class MetalAllocator : public allocator::Allocator {
|
||||||
/** Allocator for Metal GPUs. */
|
/** Allocator for Metal GPUs. */
|
||||||
public:
|
public:
|
||||||
@@ -90,7 +54,7 @@ class MetalAllocator : public allocator::Allocator {
|
|||||||
friend MetalAllocator& allocator();
|
friend MetalAllocator& allocator();
|
||||||
|
|
||||||
// Caching allocator
|
// Caching allocator
|
||||||
BufferCache buffer_cache_;
|
BufferCache<MTL::Buffer> buffer_cache_;
|
||||||
|
|
||||||
ResidencySet residency_set_;
|
ResidencySet residency_set_;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user