mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-06 03:26:45 +08:00
track resource limit and throw if exceeded (#1718)
This commit is contained in:
parent
8bae22b0fa
commit
7480059306
@ -34,16 +34,20 @@ BufferCache::~BufferCache() {
|
|||||||
clear();
|
clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void BufferCache::clear() {
|
int BufferCache::clear() {
|
||||||
|
int n_release = 0;
|
||||||
for (auto& [size, holder] : buffer_pool_) {
|
for (auto& [size, holder] : buffer_pool_) {
|
||||||
if (holder->buf)
|
if (holder->buf) {
|
||||||
holder->buf->release();
|
holder->buf->release();
|
||||||
|
n_release++;
|
||||||
|
}
|
||||||
delete holder;
|
delete holder;
|
||||||
}
|
}
|
||||||
buffer_pool_.clear();
|
buffer_pool_.clear();
|
||||||
pool_size_ = 0;
|
pool_size_ = 0;
|
||||||
head_ = nullptr;
|
head_ = nullptr;
|
||||||
tail_ = nullptr;
|
tail_ = nullptr;
|
||||||
|
return n_release;
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
|
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
|
||||||
@ -81,10 +85,11 @@ void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||||
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||||
clear();
|
return clear();
|
||||||
} else {
|
} else {
|
||||||
|
int n_release = 0;
|
||||||
size_t total_bytes_freed = 0;
|
size_t total_bytes_freed = 0;
|
||||||
|
|
||||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||||
@ -92,10 +97,12 @@ void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
|||||||
total_bytes_freed += tail_->buf->length();
|
total_bytes_freed += tail_->buf->length();
|
||||||
tail_->buf->release();
|
tail_->buf->release();
|
||||||
tail_->buf = nullptr;
|
tail_->buf = nullptr;
|
||||||
|
n_release++;
|
||||||
}
|
}
|
||||||
remove_from_list(tail_);
|
remove_from_list(tail_);
|
||||||
}
|
}
|
||||||
pool_size_ -= total_bytes_freed;
|
pool_size_ -= total_bytes_freed;
|
||||||
|
return n_release;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -144,11 +151,11 @@ MetalAllocator::MetalAllocator()
|
|||||||
residency_set_(device_),
|
residency_set_(device_),
|
||||||
buffer_cache_(device_) {
|
buffer_cache_(device_) {
|
||||||
auto memsize = std::get<size_t>(device_info()["memory_size"]);
|
auto memsize = std::get<size_t>(device_info()["memory_size"]);
|
||||||
block_limit_ =
|
auto max_rec_size =
|
||||||
std::min(1.5 * device_->recommendedMaxWorkingSetSize(), 0.95 * memsize);
|
std::get<size_t>(device_info()["max_recommended_working_set_size"]);
|
||||||
gc_limit_ = std::min(
|
resource_limit_ = std::get<size_t>(device_info()["resource_limit"]);
|
||||||
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
|
block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
|
||||||
block_limit_);
|
gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);
|
||||||
max_pool_size_ = block_limit_;
|
max_pool_size_ = block_limit_;
|
||||||
device(mlx::core::Device::gpu)
|
device(mlx::core::Device::gpu)
|
||||||
.set_residency_set(residency_set_.mtl_residency_set());
|
.set_residency_set(residency_set_.mtl_residency_set());
|
||||||
@ -186,7 +193,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
|||||||
// More helpful message if maximum buffer length is exceeded
|
// More helpful message if maximum buffer length is exceeded
|
||||||
if (size > device_->maxBufferLength()) {
|
if (size > device_->maxBufferLength()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "Attempting to allocate " << size << " bytes which is greater than"
|
msg << "[metal::malloc] Attempting to allocate " << size
|
||||||
|
<< " bytes which is greater than"
|
||||||
<< " the maximum allowed buffer size of " << device_->maxBufferLength()
|
<< " the maximum allowed buffer size of " << device_->maxBufferLength()
|
||||||
<< " bytes.";
|
<< " bytes.";
|
||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
@ -212,16 +220,26 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
|||||||
|
|
||||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||||
// try to reclaim memory from the cache
|
// try to reclaim memory from the cache
|
||||||
if (mem_required >= gc_limit_) {
|
if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
|
||||||
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
|
num_resources_ -=
|
||||||
|
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate new buffer if needed
|
// Allocate new buffer if needed
|
||||||
size_t res_opt = MTL::ResourceStorageModeShared;
|
size_t res_opt = MTL::ResourceStorageModeShared;
|
||||||
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
|
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
|
||||||
|
if (num_resources_ >= resource_limit_) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[metal::malloc] Resource limit (" << resource_limit_
|
||||||
|
<< ") exceeded.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
lk.unlock();
|
lk.unlock();
|
||||||
buf = device_->newBuffer(size, res_opt);
|
buf = device_->newBuffer(size, res_opt);
|
||||||
lk.lock();
|
lk.lock();
|
||||||
|
if (buf) {
|
||||||
|
num_resources_++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
active_memory_ += buf->length();
|
active_memory_ += buf->length();
|
||||||
@ -230,7 +248,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
|||||||
// Maintain the cache below the requested limit
|
// Maintain the cache below the requested limit
|
||||||
if (get_cache_memory() >= max_pool_size_) {
|
if (get_cache_memory() >= max_pool_size_) {
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
|
num_resources_ -= buffer_cache_.release_cached_buffers(
|
||||||
|
get_cache_memory() - max_pool_size_);
|
||||||
}
|
}
|
||||||
|
|
||||||
residency_set_.insert(buf);
|
residency_set_.insert(buf);
|
||||||
@ -241,7 +260,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
|||||||
void MetalAllocator::clear_cache() {
|
void MetalAllocator::clear_cache() {
|
||||||
std::unique_lock lk(mutex_);
|
std::unique_lock lk(mutex_);
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
buffer_cache_.clear();
|
num_resources_ -= buffer_cache_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void MetalAllocator::free(Buffer buffer) {
|
void MetalAllocator::free(Buffer buffer) {
|
||||||
@ -255,6 +274,7 @@ void MetalAllocator::free(Buffer buffer) {
|
|||||||
if (get_cache_memory() < max_pool_size_) {
|
if (get_cache_memory() < max_pool_size_) {
|
||||||
buffer_cache_.recycle_to_cache(buf);
|
buffer_cache_.recycle_to_cache(buf);
|
||||||
} else {
|
} else {
|
||||||
|
num_resources_--;
|
||||||
lk.unlock();
|
lk.unlock();
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
buf->release();
|
buf->release();
|
||||||
|
@ -23,11 +23,11 @@ class BufferCache {
|
|||||||
|
|
||||||
MTL::Buffer* reuse_from_cache(size_t size);
|
MTL::Buffer* reuse_from_cache(size_t size);
|
||||||
void recycle_to_cache(MTL::Buffer* buf);
|
void recycle_to_cache(MTL::Buffer* buf);
|
||||||
void release_cached_buffers(size_t min_bytes_to_free);
|
int release_cached_buffers(size_t min_bytes_to_free);
|
||||||
size_t cache_size() {
|
size_t cache_size() {
|
||||||
return pool_size_;
|
return pool_size_;
|
||||||
}
|
}
|
||||||
void clear();
|
int clear();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
struct BufferHolder {
|
struct BufferHolder {
|
||||||
@ -94,6 +94,8 @@ class MetalAllocator : public allocator::Allocator {
|
|||||||
size_t max_pool_size_;
|
size_t max_pool_size_;
|
||||||
size_t wired_limit_{0};
|
size_t wired_limit_{0};
|
||||||
bool relaxed_{true};
|
bool relaxed_{true};
|
||||||
|
size_t num_resources_{0};
|
||||||
|
size_t resource_limit_{0};
|
||||||
|
|
||||||
std::mutex mutex_;
|
std::mutex mutex_;
|
||||||
};
|
};
|
||||||
|
@ -651,18 +651,23 @@ device_info() {
|
|||||||
auto raw_device = device(default_device()).mtl_device();
|
auto raw_device = device(default_device()).mtl_device();
|
||||||
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
||||||
|
|
||||||
int mib[] = {CTL_HW, HW_MEMSIZE};
|
|
||||||
size_t memsize = 0;
|
size_t memsize = 0;
|
||||||
size_t length = sizeof(memsize);
|
size_t length = sizeof(memsize);
|
||||||
|
sysctlbyname("hw.memsize", &memsize, &length, NULL, 0);
|
||||||
|
|
||||||
sysctl(mib, 2, &memsize, &length, NULL, 0);
|
size_t rsrc_limit = 0;
|
||||||
|
sysctlbyname("iogpu.rsrc_limit", &rsrc_limit, &length, NULL, 0);
|
||||||
|
if (rsrc_limit == 0) {
|
||||||
|
rsrc_limit = 499000;
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
{"architecture", arch},
|
{"architecture", arch},
|
||||||
{"max_buffer_length", raw_device->maxBufferLength()},
|
{"max_buffer_length", raw_device->maxBufferLength()},
|
||||||
{"max_recommended_working_set_size",
|
{"max_recommended_working_set_size",
|
||||||
raw_device->recommendedMaxWorkingSetSize()},
|
raw_device->recommendedMaxWorkingSetSize()},
|
||||||
{"memory_size", memsize}};
|
{"memory_size", memsize},
|
||||||
|
{"resource_limit", rsrc_limit}};
|
||||||
};
|
};
|
||||||
static auto device_info_ = init_device_info();
|
static auto device_info_ = init_device_info();
|
||||||
return device_info_;
|
return device_info_;
|
||||||
|
@ -168,6 +168,7 @@ void init_metal(nb::module_& m) {
|
|||||||
* ``max_buffer_length``
|
* ``max_buffer_length``
|
||||||
* ``max_recommended_working_set_size``
|
* ``max_recommended_working_set_size``
|
||||||
* ``memory_size``
|
* ``memory_size``
|
||||||
|
* ``resource_limit``
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary with string keys and string or integer values.
|
dict: A dictionary with string keys and string or integer values.
|
||||||
|
Loading…
Reference in New Issue
Block a user