track resource limit and throw if exceeded (#1718)

This commit is contained in:
Awni Hannun 2024-12-18 18:45:58 -08:00 committed by GitHub
parent 8bae22b0fa
commit 7480059306
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 47 additions and 19 deletions

View File

@ -34,16 +34,20 @@ BufferCache::~BufferCache() {
clear(); clear();
} }
void BufferCache::clear() { int BufferCache::clear() {
int n_release = 0;
for (auto& [size, holder] : buffer_pool_) { for (auto& [size, holder] : buffer_pool_) {
if (holder->buf) if (holder->buf) {
holder->buf->release(); holder->buf->release();
n_release++;
}
delete holder; delete holder;
} }
buffer_pool_.clear(); buffer_pool_.clear();
pool_size_ = 0; pool_size_ = 0;
head_ = nullptr; head_ = nullptr;
tail_ = nullptr; tail_ = nullptr;
return n_release;
} }
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) { MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
@ -81,10 +85,11 @@ void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
} }
} }
void BufferCache::release_cached_buffers(size_t min_bytes_to_free) { int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
if (min_bytes_to_free >= 0.9 * pool_size_) { if (min_bytes_to_free >= 0.9 * pool_size_) {
clear(); return clear();
} else { } else {
int n_release = 0;
size_t total_bytes_freed = 0; size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) { while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
@ -92,10 +97,12 @@ void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
total_bytes_freed += tail_->buf->length(); total_bytes_freed += tail_->buf->length();
tail_->buf->release(); tail_->buf->release();
tail_->buf = nullptr; tail_->buf = nullptr;
n_release++;
} }
remove_from_list(tail_); remove_from_list(tail_);
} }
pool_size_ -= total_bytes_freed; pool_size_ -= total_bytes_freed;
return n_release;
} }
} }
@ -144,11 +151,11 @@ MetalAllocator::MetalAllocator()
residency_set_(device_), residency_set_(device_),
buffer_cache_(device_) { buffer_cache_(device_) {
auto memsize = std::get<size_t>(device_info()["memory_size"]); auto memsize = std::get<size_t>(device_info()["memory_size"]);
block_limit_ = auto max_rec_size =
std::min(1.5 * device_->recommendedMaxWorkingSetSize(), 0.95 * memsize); std::get<size_t>(device_info()["max_recommended_working_set_size"]);
gc_limit_ = std::min( resource_limit_ = std::get<size_t>(device_info()["resource_limit"]);
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()), block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
block_limit_); gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);
max_pool_size_ = block_limit_; max_pool_size_ = block_limit_;
device(mlx::core::Device::gpu) device(mlx::core::Device::gpu)
.set_residency_set(residency_set_.mtl_residency_set()); .set_residency_set(residency_set_.mtl_residency_set());
@ -186,7 +193,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// More helpful message if maximum buffer length is exceeded // More helpful message if maximum buffer length is exceeded
if (size > device_->maxBufferLength()) { if (size > device_->maxBufferLength()) {
std::ostringstream msg; std::ostringstream msg;
msg << "Attempting to allocate " << size << " bytes which is greater than" msg << "[metal::malloc] Attempting to allocate " << size
<< " bytes which is greater than"
<< " the maximum allowed buffer size of " << device_->maxBufferLength() << " the maximum allowed buffer size of " << device_->maxBufferLength()
<< " bytes."; << " bytes.";
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
@ -212,16 +220,26 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// If we have a lot of memory pressure or are over the maximum cache size, // If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache // try to reclaim memory from the cache
if (mem_required >= gc_limit_) { if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
buffer_cache_.release_cached_buffers(mem_required - gc_limit_); num_resources_ -=
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
} }
// Allocate new buffer if needed // Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared; size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeUntracked; res_opt |= MTL::ResourceHazardTrackingModeUntracked;
if (num_resources_ >= resource_limit_) {
std::ostringstream msg;
msg << "[metal::malloc] Resource limit (" << resource_limit_
<< ") exceeded.";
throw std::runtime_error(msg.str());
}
lk.unlock(); lk.unlock();
buf = device_->newBuffer(size, res_opt); buf = device_->newBuffer(size, res_opt);
lk.lock(); lk.lock();
if (buf) {
num_resources_++;
}
} }
active_memory_ += buf->length(); active_memory_ += buf->length();
@ -230,7 +248,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Maintain the cache below the requested limit // Maintain the cache below the requested limit
if (get_cache_memory() >= max_pool_size_) { if (get_cache_memory() >= max_pool_size_) {
auto pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_); num_resources_ -= buffer_cache_.release_cached_buffers(
get_cache_memory() - max_pool_size_);
} }
residency_set_.insert(buf); residency_set_.insert(buf);
@ -241,7 +260,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
void MetalAllocator::clear_cache() { void MetalAllocator::clear_cache() {
std::unique_lock lk(mutex_); std::unique_lock lk(mutex_);
auto pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
buffer_cache_.clear(); num_resources_ -= buffer_cache_.clear();
} }
void MetalAllocator::free(Buffer buffer) { void MetalAllocator::free(Buffer buffer) {
@ -255,6 +274,7 @@ void MetalAllocator::free(Buffer buffer) {
if (get_cache_memory() < max_pool_size_) { if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf); buffer_cache_.recycle_to_cache(buf);
} else { } else {
num_resources_--;
lk.unlock(); lk.unlock();
auto pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
buf->release(); buf->release();

View File

@ -23,11 +23,11 @@ class BufferCache {
MTL::Buffer* reuse_from_cache(size_t size); MTL::Buffer* reuse_from_cache(size_t size);
void recycle_to_cache(MTL::Buffer* buf); void recycle_to_cache(MTL::Buffer* buf);
void release_cached_buffers(size_t min_bytes_to_free); int release_cached_buffers(size_t min_bytes_to_free);
size_t cache_size() { size_t cache_size() {
return pool_size_; return pool_size_;
} }
void clear(); int clear();
private: private:
struct BufferHolder { struct BufferHolder {
@ -94,6 +94,8 @@ class MetalAllocator : public allocator::Allocator {
size_t max_pool_size_; size_t max_pool_size_;
size_t wired_limit_{0}; size_t wired_limit_{0};
bool relaxed_{true}; bool relaxed_{true};
size_t num_resources_{0};
size_t resource_limit_{0};
std::mutex mutex_; std::mutex mutex_;
}; };

View File

@ -651,18 +651,23 @@ device_info() {
auto raw_device = device(default_device()).mtl_device(); auto raw_device = device(default_device()).mtl_device();
auto arch = std::string(raw_device->architecture()->name()->utf8String()); auto arch = std::string(raw_device->architecture()->name()->utf8String());
int mib[] = {CTL_HW, HW_MEMSIZE};
size_t memsize = 0; size_t memsize = 0;
size_t length = sizeof(memsize); size_t length = sizeof(memsize);
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
sysctl(mib, 2, &memsize, &length, NULL, 0); size_t rsrc_limit = 0;
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
if (rsrc_limit == 0) {
rsrc_limit = 499000;
}
return { return {
{"architecture", arch}, {"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()}, {"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size", {"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()}, raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize}}; {"memory_size", memsize},
{"resource_limit", rsrc_limit}};
}; };
static auto device_info_ = init_device_info(); static auto device_info_ = init_device_info();
return device_info_; return device_info_;

View File

@ -168,6 +168,7 @@ void init_metal(nb::module_& m) {
* ``max_buffer_size`` * ``max_buffer_size``
* ``max_recommended_working_set_size`` * ``max_recommended_working_set_size``
* ``memory_size`` * ``memory_size``
* ``resource_limit``
Returns: Returns:
dict: A dictionary with string keys and string or integer values. dict: A dictionary with string keys and string or integer values.