mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
track resource limit and throw if exceeded (#1718)
This commit is contained in:
@@ -34,16 +34,20 @@ BufferCache::~BufferCache() {
|
||||
clear();
|
||||
}
|
||||
|
||||
void BufferCache::clear() {
|
||||
int BufferCache::clear() {
|
||||
int n_release = 0;
|
||||
for (auto& [size, holder] : buffer_pool_) {
|
||||
if (holder->buf)
|
||||
if (holder->buf) {
|
||||
holder->buf->release();
|
||||
n_release++;
|
||||
}
|
||||
delete holder;
|
||||
}
|
||||
buffer_pool_.clear();
|
||||
pool_size_ = 0;
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
return n_release;
|
||||
}
|
||||
|
||||
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
|
||||
@@ -81,10 +85,11 @@ void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
|
||||
}
|
||||
}
|
||||
|
||||
void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||
int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||
clear();
|
||||
return clear();
|
||||
} else {
|
||||
int n_release = 0;
|
||||
size_t total_bytes_freed = 0;
|
||||
|
||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||
@@ -92,10 +97,12 @@ void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||
total_bytes_freed += tail_->buf->length();
|
||||
tail_->buf->release();
|
||||
tail_->buf = nullptr;
|
||||
n_release++;
|
||||
}
|
||||
remove_from_list(tail_);
|
||||
}
|
||||
pool_size_ -= total_bytes_freed;
|
||||
return n_release;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,11 +151,11 @@ MetalAllocator::MetalAllocator()
|
||||
residency_set_(device_),
|
||||
buffer_cache_(device_) {
|
||||
auto memsize = std::get<size_t>(device_info()["memory_size"]);
|
||||
block_limit_ =
|
||||
std::min(1.5 * device_->recommendedMaxWorkingSetSize(), 0.95 * memsize);
|
||||
gc_limit_ = std::min(
|
||||
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
|
||||
block_limit_);
|
||||
auto max_rec_size =
|
||||
std::get<size_t>(device_info()["max_recommended_working_set_size"]);
|
||||
resource_limit_ = std::get<size_t>(device_info()["resource_limit"]);
|
||||
block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
|
||||
gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);
|
||||
max_pool_size_ = block_limit_;
|
||||
device(mlx::core::Device::gpu)
|
||||
.set_residency_set(residency_set_.mtl_residency_set());
|
||||
@@ -186,7 +193,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
// More helpful message if maximum buffer length is exceeded
|
||||
if (size > device_->maxBufferLength()) {
|
||||
std::ostringstream msg;
|
||||
msg << "Attempting to allocate " << size << " bytes which is greater than"
|
||||
msg << "[metal::malloc] Attempting to allocate " << size
|
||||
<< " bytes which is greater than"
|
||||
<< " the maximum allowed buffer size of " << device_->maxBufferLength()
|
||||
<< " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
@@ -212,16 +220,26 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
|
||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||
// try to reclaim memory from the cache
|
||||
if (mem_required >= gc_limit_) {
|
||||
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
|
||||
if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
|
||||
num_resources_ -=
|
||||
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
|
||||
}
|
||||
|
||||
// Allocate new buffer if needed
|
||||
size_t res_opt = MTL::ResourceStorageModeShared;
|
||||
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
|
||||
if (num_resources_ >= resource_limit_) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::malloc] Resource limit (" << resource_limit_
|
||||
<< ") exceeded.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
lk.unlock();
|
||||
buf = device_->newBuffer(size, res_opt);
|
||||
lk.lock();
|
||||
if (buf) {
|
||||
num_resources_++;
|
||||
}
|
||||
}
|
||||
|
||||
active_memory_ += buf->length();
|
||||
@@ -230,7 +248,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
// Maintain the cache below the requested limit
|
||||
if (get_cache_memory() >= max_pool_size_) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
||||
num_resources_ -= buffer_cache_.release_cached_buffers(
|
||||
get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
|
||||
residency_set_.insert(buf);
|
||||
@@ -241,7 +260,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
void MetalAllocator::clear_cache() {
|
||||
std::unique_lock lk(mutex_);
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
buffer_cache_.clear();
|
||||
num_resources_ -= buffer_cache_.clear();
|
||||
}
|
||||
|
||||
void MetalAllocator::free(Buffer buffer) {
|
||||
@@ -255,6 +274,7 @@ void MetalAllocator::free(Buffer buffer) {
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
} else {
|
||||
num_resources_--;
|
||||
lk.unlock();
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
buf->release();
|
||||
|
||||
Reference in New Issue
Block a user