Use a heap for small sizes (#1911)

* use a heap for small sizes

* check if VM
This commit is contained in:
Awni Hannun 2025-03-03 06:50:57 -08:00 committed by GitHub
parent 4e7cd31d12
commit ba12e4999a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 53 additions and 13 deletions

View File

@ -10,6 +10,9 @@
namespace mlx::core {
constexpr size_t resource_options =
MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeUntracked;
namespace allocator {
Allocator& allocator() {
@ -150,15 +153,34 @@ MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(device_) {
auto memsize = std::get<size_t>(device_info()["memory_size"]);
auto pool = metal::new_scoped_memory_pool();
auto memsize = std::get<size_t>(device_info().at("memory_size"));
auto max_rec_size =
std::get<size_t>(device_info()["max_recommended_working_set_size"]);
resource_limit_ = std::get<size_t>(device_info()["resource_limit"]);
std::get<size_t>(device_info().at("max_recommended_working_set_size"));
resource_limit_ = std::get<size_t>(device_info().at("resource_limit"));
block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);
max_pool_size_ = block_limit_;
device(mlx::core::Device::gpu)
.set_residency_set(residency_set_.mtl_residency_set());
bool is_vm = std::get<std::string>(device_info().at("device_name")) ==
"Apple Paravirtual device";
if (is_vm) {
return;
}
auto heap_desc = MTL::HeapDescriptor::alloc()->init();
heap_desc->setResourceOptions(resource_options);
heap_desc->setSize(heap_size_);
heap_ = device_->newHeap(heap_desc);
heap_desc->release();
residency_set_.insert(heap_);
}
MetalAllocator::~MetalAllocator() {
auto pool = metal::new_scoped_memory_pool();
if (heap_) {
heap_->release();
}
}
size_t MetalAllocator::set_cache_limit(size_t limit) {
@ -226,8 +248,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
}
// Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
if (num_resources_ >= resource_limit_) {
std::ostringstream msg;
msg << "[metal::malloc] Resource limit (" << resource_limit_
@ -235,7 +255,12 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
throw std::runtime_error(msg.str());
}
lk.unlock();
buf = device_->newBuffer(size, res_opt);
if (size < small_size_ && heap_) {
buf = heap_->newBuffer(size, resource_options);
}
if (!buf) {
buf = device_->newBuffer(size, resource_options);
}
lk.lock();
if (buf) {
num_resources_++;
@ -246,13 +271,15 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
peak_memory_ = std::max(peak_memory_, active_memory_);
// Maintain the cache below the requested limit
if (get_cache_memory() >= max_pool_size_) {
if (get_cache_memory() > max_pool_size_) {
auto pool = metal::new_scoped_memory_pool();
num_resources_ -= buffer_cache_.release_cached_buffers(
get_cache_memory() - max_pool_size_);
}
residency_set_.insert(buf);
if (!buf->heap()) {
residency_set_.insert(buf);
}
return Buffer{static_cast<void*>(buf)};
}
@ -269,7 +296,9 @@ void MetalAllocator::free(Buffer buffer) {
return;
}
std::unique_lock lk(mutex_);
residency_set_.erase(buf);
if (!buf->heap()) {
residency_set_.erase(buf);
}
active_memory_ -= buf->length();
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
@ -301,7 +330,7 @@ size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
}
size_t set_wired_limit(size_t limit) {
if (limit >
std::get<size_t>(device_info()["max_recommended_working_set_size"])) {
std::get<size_t>(device_info().at("max_recommended_working_set_size"))) {
throw std::invalid_argument(
"[metal::set_wired_limit] Setting a wired limit larger than "
"the maximum working set size is not allowed.");

View File

@ -43,6 +43,7 @@ class BufferCache {
void remove_from_list(BufferHolder* to_remove);
MTL::Device* device_;
MTL::Heap* heap_{nullptr};
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_;
@ -78,7 +79,15 @@ class MetalAllocator : public allocator::Allocator {
private:
MTL::Device* device_;
// The size of allocations which go on the heap until it is full. This size
// is chosen because it is the actual minimum size of a buffer allocated from
// the heap, a heap can have at most heap.size() / 256 buffers.
static constexpr int small_size_ = 256;
static constexpr int heap_size_ = 1 << 20;
MTL::Heap* heap_;
MetalAllocator();
~MetalAllocator();
friend MetalAllocator& allocator();
// Caching allocator

View File

@ -692,12 +692,13 @@ void new_stream(Stream stream) {
}
}
std::unordered_map<std::string, std::variant<std::string, size_t>>
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
auto init_device_info = []()
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
auto pool = new_scoped_memory_pool();
auto raw_device = device(default_device()).mtl_device();
auto name = std::string(raw_device->name()->utf8String());
auto arch = std::string(raw_device->architecture()->name()->utf8String());
size_t memsize = 0;
@ -711,6 +712,7 @@ device_info() {
}
return {
{"device_name", name},
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",

View File

@ -82,7 +82,7 @@ void start_capture(std::string path = "");
void stop_capture();
/** Get information about the GPU and system settings. */
std::unordered_map<std::string, std::variant<std::string, size_t>>
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info();
} // namespace mlx::core::metal

View File

@ -54,7 +54,7 @@ void start_capture(std::string) {}
void stop_capture() {}
void clear_cache() {}
std::unordered_map<std::string, std::variant<std::string, size_t>>
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
device_info() {
throw std::runtime_error(
"[metal::device_info] Cannot get device info without metal backend");