From cbcf44a4caf3fb504ed29ef78091126134e197a3 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 5 Mar 2024 13:30:50 -0800
Subject: [PATCH] Some fixes in cache / thread safety (#777)

* some fixes in cache / thread safety
* speed up no cache case
* fix opt test
* optimizer docs
* otpimizer docs
* fix adafactor
* fix adafactor
---
 mlx/backend/metal/allocator.cpp     | 35 +++++++++----------
 mlx/backend/metal/allocator.h       |  9 ++---
 python/mlx/optimizers/optimizers.py | 53 ++++++++++++++++++++---------
 python/tests/test_optimizers.py     |  4 +--
 4 files changed, 60 insertions(+), 41 deletions(-)

diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp
index d8e4538ae..286388003 100644
--- a/mlx/backend/metal/allocator.cpp
+++ b/mlx/backend/metal/allocator.cpp
@@ -1,5 +1,4 @@
 // Copyright © 2023-2024 Apple Inc.
-
 #include "mlx/backend/metal/allocator.h"
 #include "mlx/backend/metal/metal.h"
 
@@ -34,7 +33,6 @@ BufferCache::~BufferCache() {
 }
 
 void BufferCache::clear() {
-  std::lock_guard<std::mutex> lk(cache_mutex_);
   for (auto& [size, holder] : buffer_pool_) {
     if (holder->buf)
       holder->buf->release();
@@ -47,12 +45,9 @@
 }
 
 MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
-  std::lock_guard<std::mutex> lk(cache_mutex_);
-
   // Find the closest buffer in pool
   MTL::Buffer* pbuf = nullptr;
 
-  // Make sure we use most of the available memory
   auto it = buffer_pool_.lower_bound(size);
 
   // Make sure we use most of the available memory
@@ -75,8 +70,6 @@
 }
 
 void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
-  std::lock_guard<std::mutex> lk(cache_mutex_);
-
   // Add to cache
   if (buf) {
     BufferHolder* bh = new BufferHolder(buf);
@@ -90,7 +83,6 @@ void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
   if (min_bytes_to_free >= 0.9 * pool_size_) {
     clear();
   } else {
-    std::lock_guard<std::mutex> lk(cache_mutex_);
     size_t total_bytes_freed = 0;
 
     while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
@@ -178,10 +170,10 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
   }
 
   // Try the cache
+  std::unique_lock lk(mutex_);
   MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
-  size_t pool_size = get_cache_memory();
   if (!buf) {
-    size_t mem_required = get_active_memory() + pool_size + size;
+    size_t mem_required = get_active_memory() + get_cache_memory() + size;
 
     // If there is too much memory pressure, fail (likely causes a wait).
     if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
@@ -190,8 +182,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
 
     auto thread_pool = metal::new_scoped_memory_pool();
 
-    // If we have a lot of memory pressure, check if we can reclaim some memory
-    // from the cache
+    // If we have a lot of memory pressure or are over the maximum cache size,
+    // try to reclaim memory from the cache
     if (mem_required >= gc_limit_) {
       buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
     }
@@ -199,27 +191,32 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
     // Allocate new buffer if needed
     size_t res_opt = MTL::ResourceStorageModeShared;
     res_opt |= MTL::ResourceHazardTrackingModeTracked;
+    lk.unlock();
     buf = device_->newBuffer(size, res_opt);
-  }
-
-  // Maintain the cache below the requested limit
-  if (pool_size >= max_pool_size_) {
-    auto thread_pool = metal::new_scoped_memory_pool();
-    buffer_cache_.release_cached_buffers(pool_size - max_pool_size_);
+    lk.lock();
   }
 
   active_memory_ += buf->length();
   peak_memory_ = std::max(peak_memory_, active_memory_);
 
+  // Maintain the cache below the requested limit
+  if (get_cache_memory() >= max_pool_size_) {
+    auto thread_pool = metal::new_scoped_memory_pool();
+    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
+  }
+
   return Buffer{static_cast<void*>(buf)};
 }
 
 void MetalAllocator::free(Buffer buffer) {
   auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
+  std::unique_lock lk(mutex_);
   active_memory_ -= buf->length();
-  if (max_pool_size_ > 0) {
+  if (get_cache_memory() < max_pool_size_) {
     buffer_cache_.recycle_to_cache(buf);
   } else {
+    lk.unlock();
+    auto thread_pool = metal::new_scoped_memory_pool();
     buf->release();
   }
 }
diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h
index a31cb5fb4..9f6c0ec9b 100644
--- a/mlx/backend/metal/allocator.h
+++ b/mlx/backend/metal/allocator.h
@@ -19,12 +19,11 @@ class BufferCache {
  public:
   BufferCache(MTL::Device* device);
   ~BufferCache();
 
-  void clear();
   MTL::Buffer* reuse_from_cache(size_t size);
   void recycle_to_cache(MTL::Buffer* buf);
   void release_cached_buffers(size_t min_bytes_to_free);
 
-  size_t pool_size() {
+  size_t cache_size() {
     return pool_size_;
   }
@@ -38,11 +37,11 @@ class BufferCache {
     MTL::Buffer* buf;
   };
 
+  void clear();
   void add_at_head(BufferHolder* to_add);
   void remove_from_list(BufferHolder* to_remove);
 
   MTL::Device* device_;
-  std::mutex cache_mutex_;
 
   std::multimap<size_t, BufferHolder*> buffer_pool_;
   BufferHolder* head_;
@@ -64,7 +63,7 @@ class MetalAllocator : public allocator::Allocator {
     return peak_memory_;
   };
   size_t get_cache_memory() {
-    return buffer_cache_.pool_size();
+    return buffer_cache_.cache_size();
   };
   size_t set_cache_limit(size_t limit);
   size_t set_memory_limit(size_t limit, bool relaxed);
@@ -84,6 +83,8 @@ class MetalAllocator : public allocator::Allocator {
   size_t peak_memory_{0};
   size_t max_pool_size_;
   bool relaxed_{true};
+
+  std::mutex mutex_;
 };
 
 MetalAllocator& allocator();
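
A usage sketch (not part of the diff) of the limits this allocator enforces, assuming the Python bindings mx.metal.set_cache_limit, mx.metal.set_memory_limit, mx.metal.get_active_memory, and mx.metal.get_cache_memory wrap the corresponding methods declared in allocator.h above:

    import mlx.core as mx

    # Cap the buffer cache at 512 MB. With this patch, malloc() trims the cache
    # back under the limit after each allocation, and free() only recycles a
    # buffer while the cache is below it. Returns the previous limit.
    old_limit = mx.metal.set_cache_limit(512 * 1024 * 1024)

    # A cache limit of 0 disables recycling: free() releases buffers directly
    # instead of calling recycle_to_cache().
    # mx.metal.set_cache_limit(0)

    # The memory limit drives the gc_limit_ / block_limit_ checks in malloc():
    # past it, a cache miss first tries to reclaim cached buffers and, unless
    # relaxed allows swapping, fails outright.
    mx.metal.set_memory_limit(8 << 30, relaxed=True)

    a = mx.ones((4096, 4096))
    mx.eval(a)
    print(mx.metal.get_active_memory(), mx.metal.get_cache_memory())

For context, the diff moves all cache locking up into MetalAllocator behind a single mutex_ and drops that lock around the blocking device_->newBuffer and buf->release() calls, so cache bookkeeping stays consistent across threads without serializing the expensive Metal calls.
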
diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index 16928625f..054466f90 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -210,14 +210,19 @@ class RMSprop(Optimizer):
         w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
 
     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         alpha (float, optional): The smoothing constant :math:`\alpha`.
           Default: ``0.99``
         eps (float, optional): The term :math:`\epsilon` added to the denominator
           to improve numerical stability. Default: ``1e-8``
     """
 
-    def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        alpha: float = 0.99,
+        eps: float = 1e-8,
+    ):
         super().__init__()
 
         self._maybe_schedule("learning_rate", learning_rate)
@@ -264,12 +269,16 @@ class Adagrad(Optimizer):
         w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
 
     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         eps (float, optional): The term :math:`\epsilon` added to the denominator
           to improve numerical stability. Default: ``1e-8``
     """
 
-    def __init__(self, learning_rate: float, eps: float = 1e-8):
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        eps: float = 1e-8,
+    ):
         super().__init__()
 
         self._maybe_schedule("learning_rate", learning_rate)
@@ -311,14 +320,19 @@ class AdaDelta(Optimizer):
         w_{t+1} &= w_t - \lambda \Delta w_{t+1}
 
     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         rho (float, optional): The coefficient :math:`\rho` used for computing a
           running average of squared gradients. Default: ``0.9``
         eps (float, optional): The term :math:`\epsilon` added to the denominator
           to improve numerical stability. Default: `1e-8`
     """
 
-    def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        rho: float = 0.9,
+        eps: float = 1e-6,
+    ):
         super().__init__()
 
         self._maybe_schedule("learning_rate", learning_rate)
@@ -374,7 +388,7 @@ class Adam(Optimizer):
         w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}}
 
     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         betas (Tuple[float, float], optional): The coefficients
           :math:`(\beta_1, \beta_2)` used for computing running averages of the
           gradient and its square. Default: ``(0.9, 0.999)``
@@ -383,7 +397,10 @@ class Adam(Optimizer):
     """
 
     def __init__(
-        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        betas: List[float] = [0.9, 0.999],
+        eps: float = 1e-8,
     ):
         super().__init__()
 
@@ -430,7 +447,7 @@ class AdamW(Adam):
         w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + \lambda w_t)
 
     Args:
-        learning_rate (float): The learning rate :math:`\alpha`.
+        learning_rate (float or callable): The learning rate :math:`\alpha`.
         betas (Tuple[float, float], optional): The coefficients
           :math:`(\beta_1, \beta_2)` used for computing running averages of the
           gradient and its square. Default: ``(0.9, 0.999)``
@@ -442,7 +459,7 @@ class AdamW(Adam):
 
     def __init__(
         self,
-        learning_rate: float,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
         betas: List[float] = [0.9, 0.999],
         eps: float = 1e-8,
         weight_decay: float = 0.01,
@@ -477,7 +494,7 @@ class Adamax(Adam):
         w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}
 
     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         betas (Tuple[float, float], optional): The coefficients
           :math:`(\beta_1, \beta_2)` used for computing running averages of the
           gradient and its square. Default: ``(0.9, 0.999)``
@@ -486,7 +503,10 @@ class Adamax(Adam):
     """
 
     def __init__(
-        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        betas: List[float] = [0.9, 0.999],
+        eps: float = 1e-8,
     ):
         super().__init__(learning_rate, betas, eps)
         if not 0.0 <= eps:
@@ -537,7 +557,7 @@ class Lion(Optimizer):
         w_{t + 1} &= w_t - \eta (\text{sign}(c_t) + \lambda w_t)
 
     Args:
-        learning_rate (float): The learning rate :math:`\eta`.
+        learning_rate (float or callable): The learning rate :math:`\eta`.
         betas (Tuple[float, float], optional): The coefficients
           :math:`(\beta_1, \beta_2)` used for computing the gradient
           momentum and update direction. Default: ``(0.9, 0.99)``
@@ -546,7 +566,7 @@ class Lion(Optimizer):
 
     def __init__(
         self,
-        learning_rate: float,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
         betas: List[float] = [0.9, 0.99],
         weight_decay: float = 0.0,
     ):
@@ -583,7 +603,8 @@ class Adafactor(Optimizer):
     <https://arxiv.org/abs/1804.04235>`_
 
     Args:
-        learning_rate (float, optional): The learning rate. Default: ``None``.
+        learning_rate (float or callable, optional): The learning rate.
+          Default: ``None``.
         eps (tuple(float, float), optional): The first term :math:`\epsilon_1`
           added to the square of the gradients to improve numerical stability
           and the second term :math:`\epsilon_2` is used for
@@ -610,7 +631,7 @@ class Adafactor(Optimizer):
 
     def __init__(
         self,
-        learning_rate: Optional[float] = None,
+        learning_rate: Union[float, Callable[[mx.array], mx.array], None] = None,
        eps: Tuple[float, float] = (1e-30, 1e-3),
        clip_threshold: float = 1.0,
        decay_rate: float = -0.8,
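
A sketch (not part of the diff) of what the widened signatures above allow: learning_rate now accepts either a plain float or a callable schedule mapping the optimizer's step counter (an mx.array) to a rate, and Adafactor additionally keeps None. It assumes mlx.optimizers exposes step_decay as used by the test below:

    import mlx.optimizers as optim

    # Fixed rate: a plain float, exactly as before.
    adam = optim.Adam(learning_rate=3e-4)

    # Scheduled rate: any callable step -> lr works. step_decay(init,
    # decay_rate, step_size) multiplies the rate by decay_rate every
    # step_size updates.
    adam_sched = optim.Adam(learning_rate=optim.step_decay(1e-1, 0.9, 1000))

    # Adafactor may still be constructed without a learning rate, in which
    # case it falls back to its internal relative step size.
    adafactor = optim.Adafactor(learning_rate=None)
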
diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py
index 5c28938dc..2a20f4b1a 100644
--- a/python/tests/test_optimizers.py
+++ b/python/tests/test_optimizers.py
@@ -299,16 +299,16 @@ class TestOptimizers(mlx_tests.MLXTestCase):
 class TestSchedulers(unittest.TestCase):
     def test_decay_lr(self):
         for optim_class in optimizers_dict.values():
-            lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
+            lr_schedule = opt.step_decay(1e-1, 0.9, 1)
             optimizer = optim_class(learning_rate=lr_schedule)
 
             params = {"w": mx.ones((5, 5))}
             grads = tree_map(lambda x: mx.ones_like(x), params)
 
             for it in range(10):
+                optimizer.apply_gradients(grads, params)
                 expected_lr = 0.1 * (0.9**it)
                 self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)
-                return optimizer.apply_gradients(grads, params)
 
     def test_step_decay(self):
         lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
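
A standalone sketch (not part of the diff) of the corrected pattern in test_decay_lr: the schedule only advances when gradients are actually applied, so the update has to happen before the learning rate is checked, and a step size of 1 makes the rate decay on every iteration (with the old step size of 1000 it would have stayed at 0.1 for all ten checks, and the early return exited after the first one). It assumes optimizer.learning_rate is readable as an mx.array after each update, as the test does:

    import mlx.core as mx
    import mlx.optimizers as optim
    from mlx.utils import tree_map

    # Decay by 0.9 after every single optimizer step.
    optimizer = optim.Adam(learning_rate=optim.step_decay(1e-1, 0.9, 1))

    params = {"w": mx.ones((5, 5))}
    grads = tree_map(lambda x: mx.ones_like(x), params)

    for it in range(10):
        # Apply first: the schedule is evaluated from the optimizer's internal
        # step count, which only moves forward inside apply_gradients.
        params = optimizer.apply_gradients(grads, params)
        assert abs(optimizer.learning_rate.item() - 0.1 * 0.9**it) < 1e-6
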