Memory allocation (#292)

* try alternative gc

* try no cache

* add forced swap

* remove cache for now

* add cache back

* change fit crtieria

* remove unused function

* nit in comment

* tune / fix allocation

* increase block limit to original
This commit is contained in:
Awni Hannun 2024-01-02 11:59:19 -08:00 committed by GitHub
parent 295ce9db09
commit 99c80a2c8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 34 deletions

View File

@ -9,7 +9,7 @@
namespace mlx::core::allocator { namespace mlx::core::allocator {
Buffer malloc(size_t size) { Buffer malloc(size_t size) {
auto buffer = allocator().malloc(size); auto buffer = allocator().malloc(size, /* allow_swap */ true);
if (size && !buffer.ptr()) { if (size && !buffer.ptr()) {
std::ostringstream msg; std::ostringstream msg;
msg << "[malloc] Unable to allocate " << size << " bytes."; msg << "[malloc] Unable to allocate " << size << " bytes.";
@ -22,7 +22,7 @@ void free(Buffer buffer) {
return allocator().free(buffer); return allocator().free(buffer);
} }
Buffer CommonAllocator::malloc(size_t size) { Buffer CommonAllocator::malloc(size_t size, bool) {
return Buffer{std::malloc(size)}; return Buffer{std::malloc(size)};
} }
@ -38,6 +38,11 @@ Buffer malloc_or_wait(size_t size) {
buffer = allocator().malloc(size); buffer = allocator().malloc(size);
} }
// Try swapping if needed
if (size && !buffer.ptr()) {
buffer = allocator().malloc(size, /* allow_swap = */ true);
}
if (size && !buffer.ptr()) { if (size && !buffer.ptr()) {
std::ostringstream msg; std::ostringstream msg;
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes."; msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";

View File

@ -39,7 +39,7 @@ Buffer malloc_or_wait(size_t size);
class Allocator { class Allocator {
/** Abstract base class for a memory allocator. */ /** Abstract base class for a memory allocator. */
public: public:
virtual Buffer malloc(size_t size) = 0; virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
virtual void free(Buffer buffer) = 0; virtual void free(Buffer buffer) = 0;
Allocator() = default; Allocator() = default;
@ -55,7 +55,7 @@ Allocator& allocator();
class CommonAllocator : public Allocator { class CommonAllocator : public Allocator {
/** A general CPU allocator. */ /** A general CPU allocator. */
public: public:
virtual Buffer malloc(size_t size) override; virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
private: private:

View File

@ -26,11 +26,7 @@ namespace metal {
namespace { namespace {
BufferCache::BufferCache(MTL::Device* device) BufferCache::BufferCache(MTL::Device* device)
: device_(device), : device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
head_(nullptr),
tail_(nullptr),
pool_size_(0),
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
BufferCache::~BufferCache() { BufferCache::~BufferCache() {
clear(); clear();
@ -54,12 +50,16 @@ MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
// Find the closest buffer in pool // Find the closest buffer in pool
MTL::Buffer* pbuf = nullptr; MTL::Buffer* pbuf = nullptr;
// Make sure we use most of the available memory
auto it = buffer_pool_.lower_bound(size); auto it = buffer_pool_.lower_bound(size);
// Make sure we use > 50% of the available memory // Make sure we use most of the available memory
while (!pbuf && it != buffer_pool_.end() && it->first < 2 * size) { while (!pbuf && it != buffer_pool_.end() &&
it->first < std::min(2 * size, size + 2 * vm_page_size)) {
// Collect from the cache // Collect from the cache
pbuf = it->second->buf; pbuf = it->second->buf;
// Remove from cache // Remove from cache
remove_from_list(it->second); remove_from_list(it->second);
delete it->second; delete it->second;
@ -85,13 +85,9 @@ void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
} }
} }
size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) { void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
min_bytes_to_free += device_->currentAllocatedSize() - gc_limit_;
if (min_bytes_to_free >= 0.9 * pool_size_) { if (min_bytes_to_free >= 0.9 * pool_size_) {
size_t old_pool_size = pool_size_;
clear(); clear();
return old_pool_size;
} else { } else {
std::lock_guard<std::mutex> lk(cache_mutex_); std::lock_guard<std::mutex> lk(cache_mutex_);
size_t total_bytes_freed = 0; size_t total_bytes_freed = 0;
@ -104,9 +100,7 @@ size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
} }
remove_from_list(tail_); remove_from_list(tail_);
} }
pool_size_ -= total_bytes_freed; pool_size_ -= total_bytes_freed;
return total_bytes_freed;
} }
} }
@ -125,8 +119,9 @@ void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
} }
void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) { void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
if (!to_remove) if (!to_remove) {
return; return;
}
// If in the middle // If in the middle
if (to_remove->prev && to_remove->next) { if (to_remove->prev && to_remove->next) {
@ -153,26 +148,30 @@ MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()), : device_(device(mlx::core::Device::gpu).mtl_device()),
buffer_cache_(device_), buffer_cache_(device_),
peak_allocated_size_(0), peak_allocated_size_(0),
block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()) {} block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()),
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
Buffer MetalAllocator::malloc(size_t size) { Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Align up memory // Align up memory
if (size > vm_page_size) { if (size > vm_page_size) {
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size); size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
} }
// Try the cache
MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size); MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
// Prepare to allocate new memory as needed
if (!buf) { if (!buf) {
// If we are under very high memory pressure, we don't allocate further // If there is too much memory pressure, fail (likely causes a wait).
if (device_->currentAllocatedSize() >= block_limit_) { if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) {
return Buffer{nullptr}; return Buffer{nullptr};
} }
// If we are still under memory pressure, try cleaning cache // If we have a lot of memory pressure, check if we can reclaim some memory
if (buffer_cache_.can_garbage_collect()) { // from the cache
buffer_cache_.release_cached_buffers(size); if (device_->currentAllocatedSize() + size >= gc_limit_) {
size_t min_bytes_to_free =
size + device_->currentAllocatedSize() - gc_limit_;
buffer_cache_.release_cached_buffers(min_bytes_to_free);
} }
// Allocate new buffer if needed // Allocate new buffer if needed

View File

@ -23,11 +23,7 @@ class BufferCache {
MTL::Buffer* reuse_from_cache(size_t size); MTL::Buffer* reuse_from_cache(size_t size);
void recycle_to_cache(MTL::Buffer* buf); void recycle_to_cache(MTL::Buffer* buf);
size_t release_cached_buffers(size_t min_bytes_to_free); void release_cached_buffers(size_t min_bytes_to_free);
bool can_garbage_collect() {
return pool_size_ > 0 && device_->currentAllocatedSize() > gc_limit_;
}
private: private:
struct BufferHolder { struct BufferHolder {
@ -49,7 +45,6 @@ class BufferCache {
BufferHolder* head_; BufferHolder* head_;
BufferHolder* tail_; BufferHolder* tail_;
size_t pool_size_; size_t pool_size_;
size_t gc_limit_;
}; };
} // namespace } // namespace
@ -57,7 +52,7 @@ class BufferCache {
class MetalAllocator : public allocator::Allocator { class MetalAllocator : public allocator::Allocator {
/** Allocator for Metal GPUs. */ /** Allocator for Metal GPUs. */
public: public:
virtual Buffer malloc(size_t size) override; virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
private: private:
@ -71,6 +66,7 @@ class MetalAllocator : public allocator::Allocator {
// Allocation stats // Allocation stats
size_t peak_allocated_size_; size_t peak_allocated_size_;
size_t block_limit_; size_t block_limit_;
size_t gc_limit_;
}; };
MetalAllocator& allocator(); MetalAllocator& allocator();