Some fixes in cache / thread safety (#777)

* some fixes in cache / thread safety

* speed up no cache case

* fix opt test

* optimizer docs

* otpimizer docs

* fix adafactor

* fix adafactor
This commit is contained in:
Awni Hannun
2024-03-05 13:30:50 -08:00
committed by GitHub
parent 859ae15a54
commit cbcf44a4ca
4 changed files with 60 additions and 41 deletions

View File

@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
@@ -34,7 +33,6 @@ BufferCache::~BufferCache() {
}
void BufferCache::clear() {
std::lock_guard<std::mutex> lk(cache_mutex_);
for (auto& [size, holder] : buffer_pool_) {
if (holder->buf)
holder->buf->release();
@@ -47,12 +45,9 @@ void BufferCache::clear() {
}
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
std::lock_guard<std::mutex> lk(cache_mutex_);
// Find the closest buffer in pool
MTL::Buffer* pbuf = nullptr;
// Make sure we use most of the available memory
auto it = buffer_pool_.lower_bound(size);
// Make sure we use most of the available memory
@@ -75,8 +70,6 @@ MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
}
void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
std::lock_guard<std::mutex> lk(cache_mutex_);
// Add to cache
if (buf) {
BufferHolder* bh = new BufferHolder(buf);
@@ -90,7 +83,6 @@ void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
if (min_bytes_to_free >= 0.9 * pool_size_) {
clear();
} else {
std::lock_guard<std::mutex> lk(cache_mutex_);
size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
@@ -178,10 +170,10 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
}
// Try the cache
std::unique_lock lk(mutex_);
MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
size_t pool_size = get_cache_memory();
if (!buf) {
size_t mem_required = get_active_memory() + pool_size + size;
size_t mem_required = get_active_memory() + get_cache_memory() + size;
// If there is too much memory pressure, fail (likely causes a wait).
if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
@@ -190,8 +182,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
auto thread_pool = metal::new_scoped_memory_pool();
// If we have a lot of memory pressure, check if we can reclaim some memory
// from the cache
// If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache
if (mem_required >= gc_limit_) {
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
}
@@ -199,27 +191,32 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeTracked;
lk.unlock();
buf = device_->newBuffer(size, res_opt);
}
// Maintain the cache below the requested limit
if (pool_size >= max_pool_size_) {
auto thread_pool = metal::new_scoped_memory_pool();
buffer_cache_.release_cached_buffers(pool_size - max_pool_size_);
lk.lock();
}
active_memory_ += buf->length();
peak_memory_ = std::max(peak_memory_, active_memory_);
// Maintain the cache below the requested limit
if (get_cache_memory() >= max_pool_size_) {
auto thread_pool = metal::new_scoped_memory_pool();
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
return Buffer{static_cast<void*>(buf)};
}
void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
std::unique_lock lk(mutex_);
active_memory_ -= buf->length();
if (max_pool_size_ > 0) {
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
lk.unlock();
auto thread_pool = metal::new_scoped_memory_pool();
buf->release();
}
}

View File

@@ -19,12 +19,11 @@ class BufferCache {
public:
BufferCache(MTL::Device* device);
~BufferCache();
void clear();
MTL::Buffer* reuse_from_cache(size_t size);
void recycle_to_cache(MTL::Buffer* buf);
void release_cached_buffers(size_t min_bytes_to_free);
size_t pool_size() {
size_t cache_size() {
return pool_size_;
}
@@ -38,11 +37,11 @@ class BufferCache {
MTL::Buffer* buf;
};
void clear();
void add_at_head(BufferHolder* to_add);
void remove_from_list(BufferHolder* to_remove);
MTL::Device* device_;
std::mutex cache_mutex_;
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_;
@@ -64,7 +63,7 @@ class MetalAllocator : public allocator::Allocator {
return peak_memory_;
};
size_t get_cache_memory() {
return buffer_cache_.pool_size();
return buffer_cache_.cache_size();
};
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed);
@@ -84,6 +83,8 @@ class MetalAllocator : public allocator::Allocator {
size_t peak_memory_{0};
size_t max_pool_size_;
bool relaxed_{true};
std::mutex mutex_;
};
MetalAllocator& allocator();