mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use a heap for small sizes (#1911)
* use a heap for small sizes
* check if VM
This commit is contained in:
@@ -10,6 +10,9 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Resource options passed both to the heap descriptor and to every
// newBuffer call (heap-backed or device-backed) in this file: shared
// storage mode combined with untracked hazard tracking.
constexpr size_t resource_options =
|
||||
MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeUntracked;
|
||||
|
||||
namespace allocator {
|
||||
|
||||
Allocator& allocator() {
|
||||
@@ -150,15 +153,34 @@ MetalAllocator::MetalAllocator()
|
||||
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
||||
residency_set_(device_),
|
||||
buffer_cache_(device_) {
|
||||
auto memsize = std::get<size_t>(device_info()["memory_size"]);
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
auto memsize = std::get<size_t>(device_info().at("memory_size"));
|
||||
auto max_rec_size =
|
||||
std::get<size_t>(device_info()["max_recommended_working_set_size"]);
|
||||
resource_limit_ = std::get<size_t>(device_info()["resource_limit"]);
|
||||
std::get<size_t>(device_info().at("max_recommended_working_set_size"));
|
||||
resource_limit_ = std::get<size_t>(device_info().at("resource_limit"));
|
||||
block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
|
||||
gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);
|
||||
max_pool_size_ = block_limit_;
|
||||
device(mlx::core::Device::gpu)
|
||||
.set_residency_set(residency_set_.mtl_residency_set());
|
||||
bool is_vm = std::get<std::string>(device_info().at("device_name")) ==
|
||||
"Apple Paravirtual device";
|
||||
if (is_vm) {
|
||||
return;
|
||||
}
|
||||
auto heap_desc = MTL::HeapDescriptor::alloc()->init();
|
||||
heap_desc->setResourceOptions(resource_options);
|
||||
heap_desc->setSize(heap_size_);
|
||||
heap_ = device_->newHeap(heap_desc);
|
||||
heap_desc->release();
|
||||
residency_set_.insert(heap_);
|
||||
}
|
||||
|
||||
// Destructor: releases the heap held in heap_, if there is one.
MetalAllocator::~MetalAllocator() {
|
||||
// Scoped memory pool kept alive for the duration of the destructor body.
auto pool = metal::new_scoped_memory_pool();
|
||||
// The constructor returns before creating the heap when the device name is
// "Apple Paravirtual device", so heap_ may not have been created.
// NOTE(review): presumably heap_ is default-initialized to null in the
// class declaration -- confirm in the header.
if (heap_) {
|
||||
heap_->release();
|
||||
}
|
||||
}
|
||||
|
||||
size_t MetalAllocator::set_cache_limit(size_t limit) {
|
||||
@@ -226,8 +248,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
}
|
||||
|
||||
// Allocate new buffer if needed
|
||||
size_t res_opt = MTL::ResourceStorageModeShared;
|
||||
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
|
||||
if (num_resources_ >= resource_limit_) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::malloc] Resource limit (" << resource_limit_
|
||||
@@ -235,7 +255,12 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
lk.unlock();
|
||||
buf = device_->newBuffer(size, res_opt);
|
||||
if (size < small_size_ && heap_) {
|
||||
buf = heap_->newBuffer(size, resource_options);
|
||||
}
|
||||
if (!buf) {
|
||||
buf = device_->newBuffer(size, resource_options);
|
||||
}
|
||||
lk.lock();
|
||||
if (buf) {
|
||||
num_resources_++;
|
||||
@@ -246,13 +271,15 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||
peak_memory_ = std::max(peak_memory_, active_memory_);
|
||||
|
||||
// Maintain the cache below the requested limit
|
||||
if (get_cache_memory() >= max_pool_size_) {
|
||||
if (get_cache_memory() > max_pool_size_) {
|
||||
auto pool = metal::new_scoped_memory_pool();
|
||||
num_resources_ -= buffer_cache_.release_cached_buffers(
|
||||
get_cache_memory() - max_pool_size_);
|
||||
}
|
||||
|
||||
residency_set_.insert(buf);
|
||||
if (!buf->heap()) {
|
||||
residency_set_.insert(buf);
|
||||
}
|
||||
|
||||
return Buffer{static_cast<void*>(buf)};
|
||||
}
|
||||
@@ -269,7 +296,9 @@ void MetalAllocator::free(Buffer buffer) {
|
||||
return;
|
||||
}
|
||||
std::unique_lock lk(mutex_);
|
||||
residency_set_.erase(buf);
|
||||
if (!buf->heap()) {
|
||||
residency_set_.erase(buf);
|
||||
}
|
||||
active_memory_ -= buf->length();
|
||||
if (get_cache_memory() < max_pool_size_) {
|
||||
buffer_cache_.recycle_to_cache(buf);
|
||||
@@ -301,7 +330,7 @@ size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
|
||||
}
|
||||
size_t set_wired_limit(size_t limit) {
|
||||
if (limit >
|
||||
std::get<size_t>(device_info()["max_recommended_working_set_size"])) {
|
||||
std::get<size_t>(device_info().at("max_recommended_working_set_size"))) {
|
||||
throw std::invalid_argument(
|
||||
"[metal::set_wired_limit] Setting a wired limit larger than "
|
||||
"the maximum working set size is not allowed.");
|
||||
|
||||
Reference in New Issue
Block a user