try no cache

Awni Hannun 2023-12-25 12:50:38 -08:00
parent b900e60972
commit ad5036072c
2 changed files with 21 additions and 180 deletions

mlx/backend/metal/allocator.cpp

@@ -1,5 +1,6 @@
 // Copyright © 2023 Apple Inc.
+#include <iostream>
 #include "mlx/backend/metal/allocator.h"
 #include "mlx/backend/metal/metal.h"
@@ -23,155 +24,40 @@ void* Buffer::raw_ptr() {
 namespace metal {
-namespace {
-
-BufferCache::BufferCache(MTL::Device* device)
-    : device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
-
-BufferCache::~BufferCache() {
-  clear();
-}
-
-void BufferCache::clear() {
-  std::lock_guard<std::mutex> lk(cache_mutex_);
-  for (auto& [size, holder] : buffer_pool_) {
-    if (holder->buf)
-      holder->buf->release();
-    delete holder;
-  }
-  buffer_pool_.clear();
-  pool_size_ = 0;
-  head_ = nullptr;
-  tail_ = nullptr;
-}
-
-MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
-  std::lock_guard<std::mutex> lk(cache_mutex_);
-
-  // Find the closest buffer in pool
-  MTL::Buffer* pbuf = nullptr;
-
-  // Make sure we use > 50% of the available memory
-  if (auto it = buffer_pool_.lower_bound(size);
-      it != buffer_pool_.end() && it->first < 2 * size) {
-    // Collect from the cache
-    pbuf = it->second->buf;
-
-    // Remove from cache
-    remove_from_list(it->second);
-    delete it->second;
-    it = buffer_pool_.erase(it);
-  }
-
-  if (pbuf) {
-    pool_size_ -= pbuf->length();
-  }
-  return pbuf;
-}
-
-void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
-  std::lock_guard<std::mutex> lk(cache_mutex_);
-
-  // Add to cache
-  if (buf) {
-    BufferHolder* bh = new BufferHolder(buf);
-    add_at_head(bh);
-    pool_size_ += buf->length();
-    buffer_pool_.insert({buf->length(), bh});
-  }
-}
-
-void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
-  if (min_bytes_to_free >= 0.9 * pool_size_) {
-    clear();
-  } else {
-    std::lock_guard<std::mutex> lk(cache_mutex_);
-    size_t total_bytes_freed = 0;
-
-    while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
-      if (tail_->buf) {
-        total_bytes_freed += tail_->buf->length();
-        tail_->buf->release();
-        tail_->buf = nullptr;
-      }
-      remove_from_list(tail_);
-    }
-    pool_size_ -= total_bytes_freed;
-  }
-}
-
-void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
-  if (!to_add)
-    return;
-
-  if (!head_) {
-    head_ = to_add;
-    tail_ = to_add;
-  } else {
-    head_->prev = to_add;
-    to_add->next = head_;
-    head_ = to_add;
-  }
-}
-
-void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
-  if (!to_remove)
-    return;
-
-  // If in the middle
-  if (to_remove->prev && to_remove->next) {
-    to_remove->prev->next = to_remove->next;
-    to_remove->next->prev = to_remove->prev;
-  } else if (to_remove->prev && to_remove == tail_) { // If tail
-    tail_ = to_remove->prev;
-    tail_->next = nullptr;
-  } else if (to_remove == head_ && to_remove->next) { // If head
-    head_ = to_remove->next;
-    head_->prev = nullptr;
-  } else if (to_remove == head_ && to_remove == tail_) { // If only element
-    head_ = nullptr;
-    tail_ = nullptr;
-  }
-  to_remove->prev = nullptr;
-  to_remove->next = nullptr;
-}
-
-} // namespace
 MetalAllocator::MetalAllocator()
     : device_(device(mlx::core::Device::gpu).mtl_device()),
-      buffer_cache_(device_),
       peak_allocated_size_(0),
-      block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()) {}
+      block_limit_(device_->recommendedMaxWorkingSetSize()) {}
 
-Buffer MetalAllocator::malloc(size_t size) {
+Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
   // Align up memory
-  if (size > vm_page_size) {
-    size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
-  }
+  ///if (size > vm_page_size) {
+  //  size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
+  //}
 
-  MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
+  // MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
 
   // Prepare to allocate new memory as needed
-  if (!buf) {
-    // First check if the cache is big but nothing fits, garbage collect
-    // if so
-    // TODO maybe block limit and gc limit should be different
-    if (buffer_cache_.size() >= block_limit_) {
-      buffer_cache_.release_cached_buffers(
-          std::max(buffer_cache_.size() - block_limit_, size));
-    }
+  // if (!buf) {
+  // If we have memory pressure, first check if we can reclaim some memory
+  // from the cache
+  // if (auto new_size = device_->currentAllocatedSize() + size; new_size >= block_limit_) {
+  //   buffer_cache_.clear();
+  //   buffer_cache_.release_cached_buffers(
+  //       std::max(new_size - block_limit_, size));
+  // }
 
     // If there is still too much memory pressure, fail (likely causes a wait).
-    if (device_->currentAllocatedSize() >= block_limit_) {
+    // size + allocated (to avoid going over the limit)
+    if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) {
       return Buffer{nullptr};
     }
+  // }
 
     // Allocate new buffer if needed
     size_t res_opt = MTL::ResourceStorageModeShared;
     res_opt |= MTL::ResourceHazardTrackingModeTracked;
-    buf = device_->newBuffer(size, res_opt);
-  }
+    auto buf = device_->newBuffer(size, res_opt);
 
   peak_allocated_size_ =
       std::max(peak_allocated_size_, device_->currentAllocatedSize());
@@ -180,8 +66,7 @@ Buffer MetalAllocator::malloc(size_t size) {
 }
 
 void MetalAllocator::free(Buffer buffer) {
-  auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
-  buffer_cache_.recycle_to_cache(buf);
+  static_cast<MTL::Buffer*>(buffer.ptr())->release();
 }
 
 MetalAllocator& allocator() {
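
For context, the BufferCache deleted above pairs a size-keyed std::multimap with an intrusive LRU list: reuse_from_cache takes the smallest cached buffer at least as large as the request and accepts it only when it is under twice the request, so a reused buffer is always more than half utilized (the intent behind the "> 50% of the available memory" comment). Below is a minimal standalone sketch of that best-fit policy; FakeBuffer, the global pool, and the function names are illustrative stand-ins rather than the MLX API, and the mutex and LRU list are omitted.

#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <vector>

// Stand-in for MTL::Buffer: an allocation that knows its length.
using FakeBuffer = std::vector<char>;

// Size-keyed pool, as in the removed BufferCache (lock and LRU list omitted).
std::multimap<size_t, std::unique_ptr<FakeBuffer>> pool;

// Best-fit lookup: the smallest cached buffer >= size, accepted only if it is
// < 2 * size, so at least half of the returned buffer gets used.
std::unique_ptr<FakeBuffer> reuse_from_cache(size_t size) {
  if (auto it = pool.lower_bound(size);
      it != pool.end() && it->first < 2 * size) {
    auto buf = std::move(it->second);
    pool.erase(it);
    return buf;
  }
  return nullptr; // Miss: the caller allocates a fresh buffer.
}

void recycle_to_cache(std::unique_ptr<FakeBuffer> buf) {
  size_t length = buf->size();
  pool.insert({length, std::move(buf)});
}

int main() {
  recycle_to_cache(std::make_unique<FakeBuffer>(4096));
  recycle_to_cache(std::make_unique<FakeBuffer>(16384));
  std::cout << (reuse_from_cache(3000) != nullptr) << "\n"; // 1: 4096 < 2 * 3000
  std::cout << (reuse_from_cache(5000) != nullptr) << "\n"; // 0: 16384 >= 2 * 5000
}

Keying the pool by buffer length keeps insertion and best-fit lookup at O(log n); the real cache additionally threads each holder onto a doubly-linked list so release_cached_buffers can evict least-recently-used buffers from the tail.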

mlx/backend/metal/allocator.h

@@ -13,51 +13,10 @@ namespace mlx::core::metal {
 using allocator::Buffer;
-namespace {
-
-class BufferCache {
- public:
-  BufferCache(MTL::Device* device);
-  ~BufferCache();
-
-  void clear();
-  MTL::Buffer* reuse_from_cache(size_t size);
-  void recycle_to_cache(MTL::Buffer* buf);
-  void release_cached_buffers(size_t min_bytes_to_free);
-
-  // Return the size in bytes of cached memory
-  size_t size() {
-    return pool_size_;
-  }
-
- private:
-  struct BufferHolder {
-   public:
-    BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
-
-    BufferHolder* prev;
-    BufferHolder* next;
-    MTL::Buffer* buf;
-  };
-
-  void add_at_head(BufferHolder* to_add);
-  void remove_from_list(BufferHolder* to_remove);
-
-  MTL::Device* device_;
-  std::mutex cache_mutex_;
-
-  std::multimap<size_t, BufferHolder*> buffer_pool_;
-  BufferHolder* head_;
-  BufferHolder* tail_;
-  size_t pool_size_;
-};
-
-} // namespace
 class MetalAllocator : public allocator::Allocator {
   /** Allocator for Metal GPUs. */
  public:
-  virtual Buffer malloc(size_t size) override;
+  virtual Buffer malloc(size_t size, bool allow_swap = false) override;
   virtual void free(Buffer buffer) override;
 
  private:
@@ -65,9 +24,6 @@ class MetalAllocator : public allocator::Allocator {
   MetalAllocator();
   friend MetalAllocator& allocator();
 
-  // Caching allocator
-  BufferCache buffer_cache_;
-
   // Allocation stats
   size_t peak_allocated_size_;
   size_t block_limit_;
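
Two details of the new malloc path are worth noting. The commented-out alignment rounded each request up to a whole number of pages, e.g. with vm_page_size = 4096 a 5000-byte request becomes (5000 + 4095) / 4096 = 2 pages, i.e. 8192 bytes. The behavioral core of the commit is the admission check: block_limit_ drops from 1.5x to exactly recommendedMaxWorkingSetSize(), and malloc now fails up front when currentAllocatedSize() + size would cross that limit, unless the caller passes allow_swap. Below is a toy model of that check, with a hypothetical FakeDevice standing in for the MTL::Device queries (illustrative only, not the MLX API).

#include <cstddef>
#include <iostream>

// Hypothetical stand-in for the MTL::Device queries the allocator uses.
struct FakeDevice {
  size_t allocated = 0;                        // currentAllocatedSize()
  size_t recommended_working_set = 8ull << 30; // recommendedMaxWorkingSetSize()
};

// Mirrors the commit's logic: refuse an allocation that would push the device
// past block_limit_, unless the caller opts into swapping.
bool try_malloc(FakeDevice& dev, size_t size, bool allow_swap = false) {
  size_t block_limit = dev.recommended_working_set;
  if (!allow_swap && dev.allocated + size >= block_limit) {
    return false; // The real allocator returns Buffer{nullptr} here.
  }
  dev.allocated += size;
  return true;
}

int main() {
  FakeDevice dev;
  dev.allocated = (8ull << 30) - (1ull << 20);         // 1 MiB of headroom
  std::cout << try_malloc(dev, 512 << 10) << "\n";     // 1: fits under the limit
  std::cout << try_malloc(dev, 4 << 20) << "\n";       // 0: would cross the limit
  std::cout << try_malloc(dev, 4 << 20, true) << "\n"; // 1: allow_swap overrides
}

Including the request size in the comparison means a single large allocation can no longer overshoot the limit that the old allocated-only check would have let through.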