mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
wire cache (#2006)
This commit is contained in:
parent
0da8506552
commit
916fd273ea
@ -33,8 +33,11 @@ namespace metal {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
BufferCache::BufferCache(MTL::Device* device)
|
BufferCache::BufferCache(ResidencySet& residency_set)
|
||||||
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
|
: head_(nullptr),
|
||||||
|
tail_(nullptr),
|
||||||
|
pool_size_(0),
|
||||||
|
residency_set_(residency_set) {}
|
||||||
|
|
||||||
BufferCache::~BufferCache() {
|
BufferCache::~BufferCache() {
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
@ -102,6 +105,9 @@ int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
|||||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||||
if (tail_->buf) {
|
if (tail_->buf) {
|
||||||
total_bytes_freed += tail_->buf->length();
|
total_bytes_freed += tail_->buf->length();
|
||||||
|
if (!tail_->buf->heap()) {
|
||||||
|
residency_set_.erase(tail_->buf);
|
||||||
|
}
|
||||||
tail_->buf->release();
|
tail_->buf->release();
|
||||||
tail_->buf = nullptr;
|
tail_->buf = nullptr;
|
||||||
n_release++;
|
n_release++;
|
||||||
@ -156,7 +162,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
|
|||||||
MetalAllocator::MetalAllocator()
|
MetalAllocator::MetalAllocator()
|
||||||
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
||||||
residency_set_(device_),
|
residency_set_(device_),
|
||||||
buffer_cache_(device_) {
|
buffer_cache_(residency_set_) {
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
auto memsize = std::get<size_t>(device_info().at("memory_size"));
|
auto memsize = std::get<size_t>(device_info().at("memory_size"));
|
||||||
auto max_rec_size =
|
auto max_rec_size =
|
||||||
@ -298,14 +304,14 @@ void MetalAllocator::free(Buffer buffer) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
std::unique_lock lk(mutex_);
|
std::unique_lock lk(mutex_);
|
||||||
if (!buf->heap()) {
|
|
||||||
residency_set_.erase(buf);
|
|
||||||
}
|
|
||||||
active_memory_ -= buf->length();
|
active_memory_ -= buf->length();
|
||||||
if (get_cache_memory() < max_pool_size_) {
|
if (get_cache_memory() < max_pool_size_) {
|
||||||
buffer_cache_.recycle_to_cache(buf);
|
buffer_cache_.recycle_to_cache(buf);
|
||||||
} else {
|
} else {
|
||||||
num_resources_--;
|
num_resources_--;
|
||||||
|
if (!buf->heap()) {
|
||||||
|
residency_set_.erase(buf);
|
||||||
|
}
|
||||||
lk.unlock();
|
lk.unlock();
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
buf->release();
|
buf->release();
|
||||||
|
@ -18,7 +18,7 @@ namespace {
|
|||||||
|
|
||||||
class BufferCache {
|
class BufferCache {
|
||||||
public:
|
public:
|
||||||
BufferCache(MTL::Device* device);
|
BufferCache(ResidencySet& residency_set);
|
||||||
~BufferCache();
|
~BufferCache();
|
||||||
|
|
||||||
MTL::Buffer* reuse_from_cache(size_t size);
|
MTL::Buffer* reuse_from_cache(size_t size);
|
||||||
@ -42,13 +42,11 @@ class BufferCache {
|
|||||||
void add_at_head(BufferHolder* to_add);
|
void add_at_head(BufferHolder* to_add);
|
||||||
void remove_from_list(BufferHolder* to_remove);
|
void remove_from_list(BufferHolder* to_remove);
|
||||||
|
|
||||||
MTL::Device* device_;
|
|
||||||
MTL::Heap* heap_{nullptr};
|
|
||||||
|
|
||||||
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||||
BufferHolder* head_;
|
BufferHolder* head_;
|
||||||
BufferHolder* tail_;
|
BufferHolder* tail_;
|
||||||
size_t pool_size_;
|
size_t pool_size_;
|
||||||
|
ResidencySet& residency_set_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user