mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Memory allocation (#292)
* try alternative gc * try no cache * add forced swap * remove cache for now * add cache back * change fit criteria * remove unused function * nit in comment * tune / fix allocation * increase block limit to original
This commit is contained in:
parent
295ce9db09
commit
99c80a2c8b
@ -9,7 +9,7 @@
|
|||||||
namespace mlx::core::allocator {
|
namespace mlx::core::allocator {
|
||||||
|
|
||||||
Buffer malloc(size_t size) {
|
Buffer malloc(size_t size) {
|
||||||
auto buffer = allocator().malloc(size);
|
auto buffer = allocator().malloc(size, /* allow_swap */ true);
|
||||||
if (size && !buffer.ptr()) {
|
if (size && !buffer.ptr()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||||
@ -22,7 +22,7 @@ void free(Buffer buffer) {
|
|||||||
return allocator().free(buffer);
|
return allocator().free(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
Buffer CommonAllocator::malloc(size_t size) {
|
Buffer CommonAllocator::malloc(size_t size, bool) {
|
||||||
return Buffer{std::malloc(size)};
|
return Buffer{std::malloc(size)};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -38,6 +38,11 @@ Buffer malloc_or_wait(size_t size) {
|
|||||||
buffer = allocator().malloc(size);
|
buffer = allocator().malloc(size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try swapping if needed
|
||||||
|
if (size && !buffer.ptr()) {
|
||||||
|
buffer = allocator().malloc(size, /* allow_swap = */ true);
|
||||||
|
}
|
||||||
|
|
||||||
if (size && !buffer.ptr()) {
|
if (size && !buffer.ptr()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
|
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
|
||||||
|
@ -39,7 +39,7 @@ Buffer malloc_or_wait(size_t size);
|
|||||||
class Allocator {
|
class Allocator {
|
||||||
/** Abstract base class for a memory allocator. */
|
/** Abstract base class for a memory allocator. */
|
||||||
public:
|
public:
|
||||||
virtual Buffer malloc(size_t size) = 0;
|
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
|
||||||
virtual void free(Buffer buffer) = 0;
|
virtual void free(Buffer buffer) = 0;
|
||||||
|
|
||||||
Allocator() = default;
|
Allocator() = default;
|
||||||
@ -55,7 +55,7 @@ Allocator& allocator();
|
|||||||
class CommonAllocator : public Allocator {
|
class CommonAllocator : public Allocator {
|
||||||
/** A general CPU allocator. */
|
/** A general CPU allocator. */
|
||||||
public:
|
public:
|
||||||
virtual Buffer malloc(size_t size) override;
|
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||||
virtual void free(Buffer buffer) override;
|
virtual void free(Buffer buffer) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -26,11 +26,7 @@ namespace metal {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
BufferCache::BufferCache(MTL::Device* device)
|
BufferCache::BufferCache(MTL::Device* device)
|
||||||
: device_(device),
|
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
|
||||||
head_(nullptr),
|
|
||||||
tail_(nullptr),
|
|
||||||
pool_size_(0),
|
|
||||||
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
|
|
||||||
|
|
||||||
BufferCache::~BufferCache() {
|
BufferCache::~BufferCache() {
|
||||||
clear();
|
clear();
|
||||||
@ -54,12 +50,16 @@ MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
|
|||||||
|
|
||||||
// Find the closest buffer in pool
|
// Find the closest buffer in pool
|
||||||
MTL::Buffer* pbuf = nullptr;
|
MTL::Buffer* pbuf = nullptr;
|
||||||
|
|
||||||
|
// Make sure we use most of the available memory
|
||||||
auto it = buffer_pool_.lower_bound(size);
|
auto it = buffer_pool_.lower_bound(size);
|
||||||
|
|
||||||
// Make sure we use > 50% of the available memory
|
// Make sure we use most of the available memory
|
||||||
while (!pbuf && it != buffer_pool_.end() && it->first < 2 * size) {
|
while (!pbuf && it != buffer_pool_.end() &&
|
||||||
|
it->first < std::min(2 * size, size + 2 * vm_page_size)) {
|
||||||
// Collect from the cache
|
// Collect from the cache
|
||||||
pbuf = it->second->buf;
|
pbuf = it->second->buf;
|
||||||
|
|
||||||
// Remove from cache
|
// Remove from cache
|
||||||
remove_from_list(it->second);
|
remove_from_list(it->second);
|
||||||
delete it->second;
|
delete it->second;
|
||||||
@ -85,13 +85,9 @@ void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
||||||
min_bytes_to_free += device_->currentAllocatedSize() - gc_limit_;
|
|
||||||
|
|
||||||
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||||
size_t old_pool_size = pool_size_;
|
|
||||||
clear();
|
clear();
|
||||||
return old_pool_size;
|
|
||||||
} else {
|
} else {
|
||||||
std::lock_guard<std::mutex> lk(cache_mutex_);
|
std::lock_guard<std::mutex> lk(cache_mutex_);
|
||||||
size_t total_bytes_freed = 0;
|
size_t total_bytes_freed = 0;
|
||||||
@ -104,9 +100,7 @@ size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
|
|||||||
}
|
}
|
||||||
remove_from_list(tail_);
|
remove_from_list(tail_);
|
||||||
}
|
}
|
||||||
|
|
||||||
pool_size_ -= total_bytes_freed;
|
pool_size_ -= total_bytes_freed;
|
||||||
return total_bytes_freed;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,8 +119,9 @@ void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
|
void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
|
||||||
if (!to_remove)
|
if (!to_remove) {
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// If in the middle
|
// If in the middle
|
||||||
if (to_remove->prev && to_remove->next) {
|
if (to_remove->prev && to_remove->next) {
|
||||||
@ -153,26 +148,30 @@ MetalAllocator::MetalAllocator()
|
|||||||
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
: device_(device(mlx::core::Device::gpu).mtl_device()),
|
||||||
buffer_cache_(device_),
|
buffer_cache_(device_),
|
||||||
peak_allocated_size_(0),
|
peak_allocated_size_(0),
|
||||||
block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()) {}
|
block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()),
|
||||||
|
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
|
||||||
|
|
||||||
Buffer MetalAllocator::malloc(size_t size) {
|
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
||||||
// Align up memory
|
// Align up memory
|
||||||
if (size > vm_page_size) {
|
if (size > vm_page_size) {
|
||||||
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
|
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try the cache
|
||||||
MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
|
MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||||
|
|
||||||
// Prepare to allocate new memory as needed
|
|
||||||
if (!buf) {
|
if (!buf) {
|
||||||
// If we are under very high memory pressure, we don't allocate further
|
// If there is too much memory pressure, fail (likely causes a wait).
|
||||||
if (device_->currentAllocatedSize() >= block_limit_) {
|
if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) {
|
||||||
return Buffer{nullptr};
|
return Buffer{nullptr};
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we are still under memory pressure, try cleaning cache
|
// If we have a lot of memory pressure, check if we can reclaim some memory
|
||||||
if (buffer_cache_.can_garbage_collect()) {
|
// from the cache
|
||||||
buffer_cache_.release_cached_buffers(size);
|
if (device_->currentAllocatedSize() + size >= gc_limit_) {
|
||||||
|
size_t min_bytes_to_free =
|
||||||
|
size + device_->currentAllocatedSize() - gc_limit_;
|
||||||
|
buffer_cache_.release_cached_buffers(min_bytes_to_free);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate new buffer if needed
|
// Allocate new buffer if needed
|
||||||
|
@ -23,11 +23,7 @@ class BufferCache {
|
|||||||
|
|
||||||
MTL::Buffer* reuse_from_cache(size_t size);
|
MTL::Buffer* reuse_from_cache(size_t size);
|
||||||
void recycle_to_cache(MTL::Buffer* buf);
|
void recycle_to_cache(MTL::Buffer* buf);
|
||||||
size_t release_cached_buffers(size_t min_bytes_to_free);
|
void release_cached_buffers(size_t min_bytes_to_free);
|
||||||
|
|
||||||
bool can_garbage_collect() {
|
|
||||||
return pool_size_ > 0 && device_->currentAllocatedSize() > gc_limit_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
struct BufferHolder {
|
struct BufferHolder {
|
||||||
@ -49,7 +45,6 @@ class BufferCache {
|
|||||||
BufferHolder* head_;
|
BufferHolder* head_;
|
||||||
BufferHolder* tail_;
|
BufferHolder* tail_;
|
||||||
size_t pool_size_;
|
size_t pool_size_;
|
||||||
size_t gc_limit_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -57,7 +52,7 @@ class BufferCache {
|
|||||||
class MetalAllocator : public allocator::Allocator {
|
class MetalAllocator : public allocator::Allocator {
|
||||||
/** Allocator for Metal GPUs. */
|
/** Allocator for Metal GPUs. */
|
||||||
public:
|
public:
|
||||||
virtual Buffer malloc(size_t size) override;
|
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||||
virtual void free(Buffer buffer) override;
|
virtual void free(Buffer buffer) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -71,6 +66,7 @@ class MetalAllocator : public allocator::Allocator {
|
|||||||
// Allocation stats
|
// Allocation stats
|
||||||
size_t peak_allocated_size_;
|
size_t peak_allocated_size_;
|
||||||
size_t block_limit_;
|
size_t block_limit_;
|
||||||
|
size_t gc_limit_;
|
||||||
};
|
};
|
||||||
|
|
||||||
MetalAllocator& allocator();
|
MetalAllocator& allocator();
|
||||||
|
Loading…
Reference in New Issue
Block a user