mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Use a heap for small sizes (#1911)
* use a heap for small sizes
* check if VM
This commit is contained in:
parent
4e7cd31d12
commit
ba12e4999a
@ -10,6 +10,9 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
constexpr size_t resource_options =
|
||||||
|
MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeUntracked;
|
||||||
|
|
||||||
namespace allocator {
|
namespace allocator {
|
||||||
|
|
||||||
Allocator& allocator() {
|
Allocator& allocator() {
|
||||||
@ -150,15 +153,34 @@ MetalAllocator::MetalAllocator()
|
|||||||
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
||||||
residency_set_(device_),
|
residency_set_(device_),
|
||||||
buffer_cache_(device_) {
|
buffer_cache_(device_) {
|
||||||
auto memsize = std::get<size_t>(device_info()["memory_size"]);
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
|
auto memsize = std::get<size_t>(device_info().at("memory_size"));
|
||||||
auto max_rec_size =
|
auto max_rec_size =
|
||||||
std::get<size_t>(device_info()["max_recommended_working_set_size"]);
|
std::get<size_t>(device_info().at("max_recommended_working_set_size"));
|
||||||
resource_limit_ = std::get<size_t>(device_info()["resource_limit"]);
|
resource_limit_ = std::get<size_t>(device_info().at("resource_limit"));
|
||||||
block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
|
block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize);
|
||||||
gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);
|
gc_limit_ = std::min(static_cast<size_t>(0.95 * max_rec_size), block_limit_);
|
||||||
max_pool_size_ = block_limit_;
|
max_pool_size_ = block_limit_;
|
||||||
device(mlx::core::Device::gpu)
|
device(mlx::core::Device::gpu)
|
||||||
.set_residency_set(residency_set_.mtl_residency_set());
|
.set_residency_set(residency_set_.mtl_residency_set());
|
||||||
|
bool is_vm = std::get<std::string>(device_info().at("device_name")) ==
|
||||||
|
"Apple Paravirtual device";
|
||||||
|
if (is_vm) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto heap_desc = MTL::HeapDescriptor::alloc()->init();
|
||||||
|
heap_desc->setResourceOptions(resource_options);
|
||||||
|
heap_desc->setSize(heap_size_);
|
||||||
|
heap_ = device_->newHeap(heap_desc);
|
||||||
|
heap_desc->release();
|
||||||
|
residency_set_.insert(heap_);
|
||||||
|
}
|
||||||
|
|
||||||
|
MetalAllocator::~MetalAllocator() {
|
||||||
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
|
if (heap_) {
|
||||||
|
heap_->release();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t MetalAllocator::set_cache_limit(size_t limit) {
|
size_t MetalAllocator::set_cache_limit(size_t limit) {
|
||||||
@ -226,8 +248,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Allocate new buffer if needed
|
// Allocate new buffer if needed
|
||||||
size_t res_opt = MTL::ResourceStorageModeShared;
|
|
||||||
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
|
|
||||||
if (num_resources_ >= resource_limit_) {
|
if (num_resources_ >= resource_limit_) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[metal::malloc] Resource limit (" << resource_limit_
|
msg << "[metal::malloc] Resource limit (" << resource_limit_
|
||||||
@ -235,7 +255,12 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
|||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
lk.unlock();
|
lk.unlock();
|
||||||
buf = device_->newBuffer(size, res_opt);
|
if (size < small_size_ && heap_) {
|
||||||
|
buf = heap_->newBuffer(size, resource_options);
|
||||||
|
}
|
||||||
|
if (!buf) {
|
||||||
|
buf = device_->newBuffer(size, resource_options);
|
||||||
|
}
|
||||||
lk.lock();
|
lk.lock();
|
||||||
if (buf) {
|
if (buf) {
|
||||||
num_resources_++;
|
num_resources_++;
|
||||||
@ -246,13 +271,15 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
|||||||
peak_memory_ = std::max(peak_memory_, active_memory_);
|
peak_memory_ = std::max(peak_memory_, active_memory_);
|
||||||
|
|
||||||
// Maintain the cache below the requested limit
|
// Maintain the cache below the requested limit
|
||||||
if (get_cache_memory() >= max_pool_size_) {
|
if (get_cache_memory() > max_pool_size_) {
|
||||||
auto pool = metal::new_scoped_memory_pool();
|
auto pool = metal::new_scoped_memory_pool();
|
||||||
num_resources_ -= buffer_cache_.release_cached_buffers(
|
num_resources_ -= buffer_cache_.release_cached_buffers(
|
||||||
get_cache_memory() - max_pool_size_);
|
get_cache_memory() - max_pool_size_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!buf->heap()) {
|
||||||
residency_set_.insert(buf);
|
residency_set_.insert(buf);
|
||||||
|
}
|
||||||
|
|
||||||
return Buffer{static_cast<void*>(buf)};
|
return Buffer{static_cast<void*>(buf)};
|
||||||
}
|
}
|
||||||
@ -269,7 +296,9 @@ void MetalAllocator::free(Buffer buffer) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
std::unique_lock lk(mutex_);
|
std::unique_lock lk(mutex_);
|
||||||
|
if (!buf->heap()) {
|
||||||
residency_set_.erase(buf);
|
residency_set_.erase(buf);
|
||||||
|
}
|
||||||
active_memory_ -= buf->length();
|
active_memory_ -= buf->length();
|
||||||
if (get_cache_memory() < max_pool_size_) {
|
if (get_cache_memory() < max_pool_size_) {
|
||||||
buffer_cache_.recycle_to_cache(buf);
|
buffer_cache_.recycle_to_cache(buf);
|
||||||
@ -301,7 +330,7 @@ size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
|
|||||||
}
|
}
|
||||||
size_t set_wired_limit(size_t limit) {
|
size_t set_wired_limit(size_t limit) {
|
||||||
if (limit >
|
if (limit >
|
||||||
std::get<size_t>(device_info()["max_recommended_working_set_size"])) {
|
std::get<size_t>(device_info().at("max_recommended_working_set_size"))) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[metal::set_wired_limit] Setting a wired limit larger than "
|
"[metal::set_wired_limit] Setting a wired limit larger than "
|
||||||
"the maximum working set size is not allowed.");
|
"the maximum working set size is not allowed.");
|
||||||
|
@ -43,6 +43,7 @@ class BufferCache {
|
|||||||
void remove_from_list(BufferHolder* to_remove);
|
void remove_from_list(BufferHolder* to_remove);
|
||||||
|
|
||||||
MTL::Device* device_;
|
MTL::Device* device_;
|
||||||
|
MTL::Heap* heap_{nullptr};
|
||||||
|
|
||||||
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||||
BufferHolder* head_;
|
BufferHolder* head_;
|
||||||
@ -78,7 +79,15 @@ class MetalAllocator : public allocator::Allocator {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
MTL::Device* device_;
|
MTL::Device* device_;
|
||||||
|
|
||||||
|
// The size of allocations which go on the heap until it is full. This size
|
||||||
|
// is chosen because it is the actual minimum size of a buffer allocated from
|
||||||
|
// the heap; a heap can therefore hold at most heap.size() / 256 buffers.
|
||||||
|
static constexpr int small_size_ = 256;
|
||||||
|
static constexpr int heap_size_ = 1 << 20;
|
||||||
|
// Heap backing small allocations. Must default to nullptr: the constructor
// returns early on the "Apple Paravirtual device" before any heap is created,
// and both malloc() and the destructor test this pointer for null.
MTL::Heap* heap_{nullptr};
|
||||||
MetalAllocator();
|
MetalAllocator();
|
||||||
|
~MetalAllocator();
|
||||||
friend MetalAllocator& allocator();
|
friend MetalAllocator& allocator();
|
||||||
|
|
||||||
// Caching allocator
|
// Caching allocator
|
||||||
|
@ -692,12 +692,13 @@ void new_stream(Stream stream) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<std::string, std::variant<std::string, size_t>>
|
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
||||||
device_info() {
|
device_info() {
|
||||||
auto init_device_info = []()
|
auto init_device_info = []()
|
||||||
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
|
-> std::unordered_map<std::string, std::variant<std::string, size_t>> {
|
||||||
auto pool = new_scoped_memory_pool();
|
auto pool = new_scoped_memory_pool();
|
||||||
auto raw_device = device(default_device()).mtl_device();
|
auto raw_device = device(default_device()).mtl_device();
|
||||||
|
auto name = std::string(raw_device->name()->utf8String());
|
||||||
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
auto arch = std::string(raw_device->architecture()->name()->utf8String());
|
||||||
|
|
||||||
size_t memsize = 0;
|
size_t memsize = 0;
|
||||||
@ -711,6 +712,7 @@ device_info() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
{"device_name", name},
|
||||||
{"architecture", arch},
|
{"architecture", arch},
|
||||||
{"max_buffer_length", raw_device->maxBufferLength()},
|
{"max_buffer_length", raw_device->maxBufferLength()},
|
||||||
{"max_recommended_working_set_size",
|
{"max_recommended_working_set_size",
|
||||||
|
@ -82,7 +82,7 @@ void start_capture(std::string path = "");
|
|||||||
void stop_capture();
|
void stop_capture();
|
||||||
|
|
||||||
/** Get information about the GPU and system settings. */
|
/** Get information about the GPU and system settings. */
|
||||||
std::unordered_map<std::string, std::variant<std::string, size_t>>
|
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
||||||
device_info();
|
device_info();
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
@ -54,7 +54,7 @@ void start_capture(std::string) {}
|
|||||||
void stop_capture() {}
|
void stop_capture() {}
|
||||||
void clear_cache() {}
|
void clear_cache() {}
|
||||||
|
|
||||||
std::unordered_map<std::string, std::variant<std::string, size_t>>
|
const std::unordered_map<std::string, std::variant<std::string, size_t>>&
|
||||||
device_info() {
|
device_info() {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[metal::device_info] Cannot get device info without metal backend");
|
"[metal::device_info] Cannot get device info without metal backend");
|
||||||
|
Loading…
Reference in New Issue
Block a user