Mirror of https://github.com/ml-explore/mlx.git
Some fixes in cache / thread safety (#777)
* Some fixes in cache / thread safety
* Speed up the no-cache case
* Fix optimizer test
* Optimizer docs
* Fix Adafactor
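The thread-safety change is easiest to read as a locking-scheme swap: instead of `BufferCache` guarding itself with its own `cache_mutex_`, `MetalAllocator` now holds a single `mutex_` around both the cache and the memory counters, releasing it only around the slow Metal allocation. Below is a minimal standalone sketch of that pattern, not the MLX implementation: `std::malloc` stands in for `device_->newBuffer(...)` and the cache lookup is a placeholder.

```cpp
#include <cstddef>
#include <cstdlib>
#include <mutex>

// Sketch: one allocator-wide mutex replaces the per-cache mutex, and is
// dropped around the expensive driver call so other threads can still
// service cache hits in the meantime.
class Allocator {
 public:
  void* alloc(std::size_t size) {
    std::unique_lock<std::mutex> lk(mutex_);
    void* buf = try_cache(size);  // cheap lookup, done under the lock
    if (!buf) {
      lk.unlock();                // don't hold the lock across the slow call
      buf = std::malloc(size);    // stand-in for device_->newBuffer(...)
      lk.lock();                  // re-acquire before updating shared counters
    }
    active_memory_ += size;
    return buf;
  }

 private:
  void* try_cache(std::size_t) { return nullptr; }  // placeholder: always miss
  std::mutex mutex_;
  std::size_t active_memory_{0};
};

int main() {
  Allocator a;
  void* p = a.alloc(1024);
  std::free(p);  // sketch only; the real allocator routes frees to the cache
}
```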
mlx/backend/metal/allocator.cpp:

```diff
@@ -1,5 +1,4 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include "mlx/backend/metal/allocator.h"
 #include "mlx/backend/metal/metal.h"
 
@@ -34,7 +33,6 @@ BufferCache::~BufferCache() {
 }
 
 void BufferCache::clear() {
-  std::lock_guard<std::mutex> lk(cache_mutex_);
   for (auto& [size, holder] : buffer_pool_) {
     if (holder->buf)
       holder->buf->release();
@@ -47,12 +45,9 @@ void BufferCache::clear() {
 }
 
 MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
-  std::lock_guard<std::mutex> lk(cache_mutex_);
-
   // Find the closest buffer in pool
   MTL::Buffer* pbuf = nullptr;
 
-  // Make sure we use most of the available memory
   auto it = buffer_pool_.lower_bound(size);
 
   // Make sure we use most of the available memory
@@ -75,8 +70,6 @@ MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
 }
 
 void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
-  std::lock_guard<std::mutex> lk(cache_mutex_);
-
   // Add to cache
   if (buf) {
     BufferHolder* bh = new BufferHolder(buf);
@@ -90,7 +83,6 @@ void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
   if (min_bytes_to_free >= 0.9 * pool_size_) {
     clear();
   } else {
-    std::lock_guard<std::mutex> lk(cache_mutex_);
    size_t total_bytes_freed = 0;
 
    while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
@@ -178,10 +170,10 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
  }
 
  // Try the cache
+  std::unique_lock lk(mutex_);
  MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
-  size_t pool_size = get_cache_memory();
  if (!buf) {
-    size_t mem_required = get_active_memory() + pool_size + size;
+    size_t mem_required = get_active_memory() + get_cache_memory() + size;
 
    // If there is too much memory pressure, fail (likely causes a wait).
    if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
@@ -190,8 +182,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
 
    auto thread_pool = metal::new_scoped_memory_pool();
 
-    // If we have a lot of memory pressure, check if we can reclaim some memory
-    // from the cache
+    // If we have a lot of memory pressure or are over the maximum cache size,
+    // try to reclaim memory from the cache
    if (mem_required >= gc_limit_) {
      buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
    }
@@ -199,27 +191,32 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
    // Allocate new buffer if needed
    size_t res_opt = MTL::ResourceStorageModeShared;
    res_opt |= MTL::ResourceHazardTrackingModeTracked;
+    lk.unlock();
    buf = device_->newBuffer(size, res_opt);
+    lk.lock();
  }
 
-  // Maintain the cache below the requested limit
-  if (pool_size >= max_pool_size_) {
-    auto thread_pool = metal::new_scoped_memory_pool();
-    buffer_cache_.release_cached_buffers(pool_size - max_pool_size_);
-  }
-
  active_memory_ += buf->length();
  peak_memory_ = std::max(peak_memory_, active_memory_);
 
+  // Maintain the cache below the requested limit
+  if (get_cache_memory() >= max_pool_size_) {
+    auto thread_pool = metal::new_scoped_memory_pool();
+    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
+  }
+
  return Buffer{static_cast<void*>(buf)};
 }
 
 void MetalAllocator::free(Buffer buffer) {
  auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
+  std::unique_lock lk(mutex_);
  active_memory_ -= buf->length();
-  if (max_pool_size_ > 0) {
+  if (get_cache_memory() < max_pool_size_) {
    buffer_cache_.recycle_to_cache(buf);
  } else {
+    lk.unlock();
+    auto thread_pool = metal::new_scoped_memory_pool();
    buf->release();
  }
 }
```
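The pressure checks in the new `malloc` amount to two thresholds on `mem_required = active + cached + requested`: at or above `block_limit_` the allocation fails (unless relaxed swapping is allowed), and at or above `gc_limit_` the cache is asked to give back just enough to get under the limit. A toy walk-through with made-up numbers; these are not MLX's defaults, which are derived from device memory:

```cpp
#include <cstddef>
#include <iostream>

int main() {
  // Made-up example values for illustration only.
  std::size_t active = 600, cached = 300, request = 200;
  std::size_t block_limit = 1200, gc_limit = 1000;

  // Same accounting as the new code: active + cached + requested.
  std::size_t mem_required = active + cached + request;  // = 1100

  if (mem_required >= block_limit) {
    std::cout << "over block_limit: fail the allocation\n";
  } else if (mem_required >= gc_limit) {
    // Ask the cache to give back just enough to get under gc_limit.
    std::cout << "reclaim " << mem_required - gc_limit << " bytes from cache\n";
  } else {
    std::cout << "allocate without reclaiming\n";
  }
}
```

With these numbers the program prints `reclaim 100 bytes from cache`, mirroring the `release_cached_buffers(mem_required - gc_limit_)` call in the diff.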
mlx/backend/metal/allocator.h:

```diff
@@ -19,12 +19,11 @@ class BufferCache {
 public:
  BufferCache(MTL::Device* device);
  ~BufferCache();
-  void clear();
 
  MTL::Buffer* reuse_from_cache(size_t size);
  void recycle_to_cache(MTL::Buffer* buf);
  void release_cached_buffers(size_t min_bytes_to_free);
-  size_t pool_size() {
+  size_t cache_size() {
    return pool_size_;
  }
 
@@ -38,11 +37,11 @@ class BufferCache {
    MTL::Buffer* buf;
  };
 
+  void clear();
  void add_at_head(BufferHolder* to_add);
  void remove_from_list(BufferHolder* to_remove);
 
  MTL::Device* device_;
-  std::mutex cache_mutex_;
 
  std::multimap<size_t, BufferHolder*> buffer_pool_;
  BufferHolder* head_;
@@ -64,7 +63,7 @@ class MetalAllocator : public allocator::Allocator {
    return peak_memory_;
  };
  size_t get_cache_memory() {
-    return buffer_cache_.pool_size();
+    return buffer_cache_.cache_size();
  };
  size_t set_cache_limit(size_t limit);
  size_t set_memory_limit(size_t limit, bool relaxed);
@@ -84,6 +83,8 @@ class MetalAllocator : public allocator::Allocator {
  size_t peak_memory_{0};
  size_t max_pool_size_;
  bool relaxed_{true};
+
+  std::mutex mutex_;
 };
 
 MetalAllocator& allocator();
```
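The header also shows the pool structure behind `BufferCache`: a `std::multimap<size_t, BufferHolder*>` keyed by buffer size, which is what lets `reuse_from_cache` do a best-fit lookup with `lower_bound`. An illustrative standalone sketch of that lookup follows; the 2x size tolerance is an assumption for the example, not taken from this diff, and `FakeBuffer`/`reuse_from_pool` are hypothetical names:

```cpp
#include <cstddef>
#include <map>

struct FakeBuffer {
  std::size_t length;  // stand-in for MTL::Buffer::length()
};

// Best-fit reuse: find the smallest cached buffer that is at least `size`
// bytes, but reject it if it would waste too much memory (here: > 2x).
FakeBuffer* reuse_from_pool(std::multimap<std::size_t, FakeBuffer*>& pool,
                            std::size_t size) {
  auto it = pool.lower_bound(size);  // closest buffer with length >= size
  if (it == pool.end() || it->first > 2 * size)
    return nullptr;                  // no acceptable candidate
  FakeBuffer* buf = it->second;
  pool.erase(it);                    // a reused buffer leaves the cache
  return buf;
}

int main() {
  std::multimap<std::size_t, FakeBuffer*> pool;
  FakeBuffer b{4096};
  pool.insert({b.length, &b});
  FakeBuffer* hit = reuse_from_pool(pool, 3000);   // 4096 <= 2 * 3000: hit
  FakeBuffer* miss = reuse_from_pool(pool, 1000);  // pool now empty: miss
  return (hit != nullptr && miss == nullptr) ? 0 : 1;
}
```

With the commit applied, that multimap is no longer self-locking; every path into it goes through `MetalAllocator`, which holds `mutex_` for the duration of the cache operation.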