track resource limit and throw if exceeded (#1718)

This commit is contained in:
Awni Hannun 2024-12-18 18:45:58 -08:00 committed by GitHub
parent 8bae22b0fa
commit 7480059306
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 47 additions and 19 deletions

View File

@ -34,16 +34,20 @@ BufferCache::~BufferCache() {
clear();
}
void BufferCache::clear() {
int BufferCache::clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) {
if (holder->buf)
if (holder->buf) {
holder->buf->release();
n_release++;
}
delete holder;
}
buffer_pool_.clear();
pool_size_ = 0;
head_ = nullptr;
tail_ = nullptr;
return n_release;
}
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
@ -81,10 +85,11 @@ void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
}
}
void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
if (min_bytes_to_free >= 0.9 * pool_size_) {
clear();
return clear();
} else {
int n_release = 0;
size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
@ -92,10 +97,12 @@ void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
total_bytes_freed += tail_->buf->length();
tail_->buf->release();
tail_->buf = nullptr;
n_release++;
}
remove_from_list(tail_);
}
pool_size_ -= total_bytes_freed;
return n_release;
}
}
@ -144,11 +151,11 @@ MetalAllocator::MetalAllocator()
residency_set_(device_),
buffer_cache_(device_) {
auto memsize = std::get<size_t>(device_info()["memory_size"]);
block_limit_ =
std::min(1.5 * device_->recommendedMaxWorkingSetSize(), 0.95 * memsize);
gc_limit_ = std::min(
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
block_limit_);
auto max_rec_size =
std::get<size_t>(device_info()["max_recommended_working_set_size"]);
resource_limit_ = std::get<size_t>(device_info()["resource_limit"]);
block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);
max_pool_size_ = block_limit_;
device(mlx::core::Device::gpu)
.set_residency_set(residency_set_.mtl_residency_set());
@ -186,7 +193,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// More helpful message if maximum buffer length is exceeded
if (size > device_->maxBufferLength()) {
std::ostringstream msg;
msg << "Attempting to allocate " << size << " bytes which is greater than"
msg << "[metal::malloc] Attempting to allocate " << size
<< " bytes which is greater than"
<< " the maximum allowed buffer size of " << device_->maxBufferLength()
<< " bytes.";
throw std::runtime_error(msg.str());
@ -212,16 +220,26 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache
if (mem_required >= gc_limit_) {
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
num_resources_ -=
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
}
// Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
if (num_resources_ >= resource_limit_) {
std::ostringstream msg;
msg << "[metal::malloc] Resource limit (" << resource_limit_
<< ") exceeded.";
throw std::runtime_error(msg.str());
}
lk.unlock();
buf = device_->newBuffer(size, res_opt);
lk.lock();
if (buf) {
num_resources_++;
}
}
active_memory_ += buf->length();
@ -230,7 +248,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Maintain the cache below the requested limit
if (get_cache_memory() >= max_pool_size_) {
auto pool = metal::new_scoped_memory_pool();
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
num_resources_ -= buffer_cache_.release_cached_buffers(
get_cache_memory() - max_pool_size_);
}
residency_set_.insert(buf);
@ -241,7 +260,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
void MetalAllocator::clear_cache() {
std::unique_lock lk(mutex_);
auto pool = metal::new_scoped_memory_pool();
buffer_cache_.clear();
num_resources_ -= buffer_cache_.clear();
}
void MetalAllocator::free(Buffer buffer) {
@ -255,6 +274,7 @@ void MetalAllocator::free(Buffer buffer) {
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
num_resources_--;
lk.unlock();
auto pool = metal::new_scoped_memory_pool();
buf->release();

View File

@ -23,11 +23,11 @@ class BufferCache {
MTL::Buffer* reuse_from_cache(size_t size);
void recycle_to_cache(MTL::Buffer* buf);
void release_cached_buffers(size_t min_bytes_to_free);
int release_cached_buffers(size_t min_bytes_to_free);
size_t cache_size() {
return pool_size_;
}
void clear();
int clear();
private:
struct BufferHolder {
@ -94,6 +94,8 @@ class MetalAllocator : public allocator::Allocator {
size_t max_pool_size_;
size_t wired_limit_{0};
bool relaxed_{true};
size_t num_resources_{0};
size_t resource_limit_{0};
std::mutex mutex_;
};

View File

@ -651,18 +651,23 @@ device_info() {
auto raw_device = device(default_device()).mtl_device();
auto arch = std::string(raw_device->architecture()->name()->utf8String());
int mib[] = {CTL_HW, HW_MEMSIZE};
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
sysctl(mib, 2, &memsize, &length, NULL, 0);
size_t rsrc_limit = 0;
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
if (rsrc_limit == 0) {
rsrc_limit = 499000;
}
return {
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize}};
{"memory_size", memsize},
{"resource_limit", rsrc_limit}};
};
static auto device_info_ = init_device_info();
return device_info_;

View File

@ -168,6 +168,7 @@ void init_metal(nb::module_& m) {
* ``max_buffer_length``
* ``max_recommended_working_set_size``
* ``memory_size``
* ``resource_limit``
Returns:
dict: A dictionary with string keys and string or integer values.