mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Simplify BufferCache assuming buf can not be null
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
@@ -27,37 +28,30 @@ class BufferCache {
|
|||||||
|
|
||||||
T* reuse_from_cache(size_t size) {
|
T* reuse_from_cache(size_t size) {
|
||||||
// Find the closest buffer in pool.
|
// Find the closest buffer in pool.
|
||||||
T* pbuf = nullptr;
|
|
||||||
|
|
||||||
auto it = buffer_pool_.lower_bound(size);
|
auto it = buffer_pool_.lower_bound(size);
|
||||||
|
if (it == buffer_pool_.end() ||
|
||||||
|
it->first >= std::min(2 * size, size + 2 * page_size_)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
// Make sure we use most of the available memory.
|
|
||||||
while (!pbuf && it != buffer_pool_.end() &&
|
|
||||||
it->first < std::min(2 * size, size + 2 * page_size_)) {
|
|
||||||
// Collect from the cache.
|
// Collect from the cache.
|
||||||
pbuf = it->second->buf;
|
T* buf = it->second->buf;
|
||||||
|
pool_size_ -= it->first;
|
||||||
|
|
||||||
// Remove from cache.
|
// Remove from record.
|
||||||
remove_from_list(it->second);
|
remove_from_list(it->second);
|
||||||
it = buffer_pool_.erase(it);
|
buffer_pool_.erase(it);
|
||||||
}
|
return buf;
|
||||||
|
|
||||||
if (pbuf) {
|
|
||||||
pool_size_ -= get_size_(pbuf);
|
|
||||||
}
|
|
||||||
|
|
||||||
return pbuf;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void recycle_to_cache(T* buf) {
|
void recycle_to_cache(T* buf) {
|
||||||
|
assert(buf);
|
||||||
// Add to cache.
|
// Add to cache.
|
||||||
if (buf) {
|
|
||||||
BufferHolder* bh = new BufferHolder(buf);
|
BufferHolder* bh = new BufferHolder(buf);
|
||||||
add_at_head(bh);
|
add_at_head(bh);
|
||||||
size_t size = get_size_(buf);
|
size_t size = get_size_(buf);
|
||||||
pool_size_ += size;
|
pool_size_ += size;
|
||||||
buffer_pool_.insert({size, bh});
|
buffer_pool_.emplace(size, bh);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int release_cached_buffers(size_t min_bytes_to_free) {
|
int release_cached_buffers(size_t min_bytes_to_free) {
|
||||||
@@ -68,20 +62,22 @@ class BufferCache {
|
|||||||
size_t total_bytes_freed = 0;
|
size_t total_bytes_freed = 0;
|
||||||
|
|
||||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||||
if (tail_->buf) {
|
// Release buffer.
|
||||||
total_bytes_freed += get_size_(tail_->buf);
|
size_t size = get_size_(tail_->buf);
|
||||||
|
total_bytes_freed += size;
|
||||||
free_(tail_->buf);
|
free_(tail_->buf);
|
||||||
tail_->buf = nullptr;
|
|
||||||
n_release++;
|
n_release++;
|
||||||
}
|
|
||||||
remove_from_list(tail_);
|
// Remove from record.
|
||||||
for (auto it = buffer_pool_.begin(); it != buffer_pool_.end(); ++it) {
|
auto its = buffer_pool_.equal_range(size);
|
||||||
if (it->second == tail_) {
|
auto it = std::find_if(its.first, its.second, [this](const auto& el) {
|
||||||
|
return el.second == tail_;
|
||||||
|
});
|
||||||
|
assert(it != buffer_pool_.end());
|
||||||
buffer_pool_.erase(it);
|
buffer_pool_.erase(it);
|
||||||
break;
|
remove_from_list(tail_);
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pool_size_ -= total_bytes_freed;
|
pool_size_ -= total_bytes_freed;
|
||||||
return n_release;
|
return n_release;
|
||||||
}
|
}
|
||||||
@@ -90,10 +86,8 @@ class BufferCache {
|
|||||||
int clear() {
|
int clear() {
|
||||||
int n_release = 0;
|
int n_release = 0;
|
||||||
for (auto& [size, holder] : buffer_pool_) {
|
for (auto& [size, holder] : buffer_pool_) {
|
||||||
if (holder->buf) {
|
|
||||||
free_(holder->buf);
|
free_(holder->buf);
|
||||||
n_release++;
|
n_release++;
|
||||||
}
|
|
||||||
delete holder;
|
delete holder;
|
||||||
}
|
}
|
||||||
buffer_pool_.clear();
|
buffer_pool_.clear();
|
||||||
|
|||||||
Reference in New Issue
Block a user