Some fixes in cache / thread safety (#777)

* some fixes in cache / thread safety

* speed up no cache case

* fix opt test

* optimizer docs

* optimizer docs

* fix adafactor

* fix adafactor
Awni Hannun 2024-03-05 13:30:50 -08:00 committed by GitHub
parent 859ae15a54
commit cbcf44a4ca
4 changed files with 60 additions and 41 deletions

View File

@@ -1,5 +1,4 @@
 // Copyright © 2023-2024 Apple Inc.
 #include "mlx/backend/metal/allocator.h"
 #include "mlx/backend/metal/metal.h"
@@ -34,7 +33,6 @@ BufferCache::~BufferCache() {
 }

 void BufferCache::clear() {
-  std::lock_guard<std::mutex> lk(cache_mutex_);
   for (auto& [size, holder] : buffer_pool_) {
     if (holder->buf)
       holder->buf->release();
@@ -47,12 +45,9 @@ void BufferCache::clear() {
 }

 MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
-  std::lock_guard<std::mutex> lk(cache_mutex_);
   // Find the closest buffer in pool
   MTL::Buffer* pbuf = nullptr;

-  // Make sure we use most of the available memory
   auto it = buffer_pool_.lower_bound(size);

   // Make sure we use most of the available memory
@@ -75,8 +70,6 @@ MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
 }

 void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
-  std::lock_guard<std::mutex> lk(cache_mutex_);
   // Add to cache
   if (buf) {
     BufferHolder* bh = new BufferHolder(buf);
@@ -90,7 +83,6 @@ void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
   if (min_bytes_to_free >= 0.9 * pool_size_) {
     clear();
   } else {
-    std::lock_guard<std::mutex> lk(cache_mutex_);
     size_t total_bytes_freed = 0;

     while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
@@ -178,10 +170,10 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
   }

   // Try the cache
+  std::unique_lock lk(mutex_);
   MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
-  size_t pool_size = get_cache_memory();
   if (!buf) {
-    size_t mem_required = get_active_memory() + pool_size + size;
+    size_t mem_required = get_active_memory() + get_cache_memory() + size;

     // If there is too much memory pressure, fail (likely causes a wait).
     if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
@@ -190,8 +182,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {

     auto thread_pool = metal::new_scoped_memory_pool();

-    // If we have a lot of memory pressure, check if we can reclaim some memory
-    // from the cache
+    // If we have a lot of memory pressure or are over the maximum cache size,
+    // try to reclaim memory from the cache
     if (mem_required >= gc_limit_) {
       buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
     }
@@ -199,27 +191,32 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
     // Allocate new buffer if needed
     size_t res_opt = MTL::ResourceStorageModeShared;
     res_opt |= MTL::ResourceHazardTrackingModeTracked;
+    lk.unlock();
     buf = device_->newBuffer(size, res_opt);
+    lk.lock();
   }

-  // Maintain the cache below the requested limit
-  if (pool_size >= max_pool_size_) {
-    auto thread_pool = metal::new_scoped_memory_pool();
-    buffer_cache_.release_cached_buffers(pool_size - max_pool_size_);
-  }
-
   active_memory_ += buf->length();
   peak_memory_ = std::max(peak_memory_, active_memory_);

+  // Maintain the cache below the requested limit
+  if (get_cache_memory() >= max_pool_size_) {
+    auto thread_pool = metal::new_scoped_memory_pool();
+    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
+  }
+
   return Buffer{static_cast<void*>(buf)};
 }

 void MetalAllocator::free(Buffer buffer) {
   auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
+  std::unique_lock lk(mutex_);
   active_memory_ -= buf->length();
-  if (max_pool_size_ > 0) {
+  if (get_cache_memory() < max_pool_size_) {
     buffer_cache_.recycle_to_cache(buf);
   } else {
+    lk.unlock();
+    auto thread_pool = metal::new_scoped_memory_pool();
     buf->release();
   }
 }

View File

@@ -19,12 +19,11 @@ class BufferCache {
  public:
   BufferCache(MTL::Device* device);
   ~BufferCache();
-  void clear();
   MTL::Buffer* reuse_from_cache(size_t size);
   void recycle_to_cache(MTL::Buffer* buf);
   void release_cached_buffers(size_t min_bytes_to_free);
-  size_t pool_size() {
+  size_t cache_size() {
     return pool_size_;
   }
@@ -38,11 +37,11 @@ class BufferCache {
     MTL::Buffer* buf;
   };

+  void clear();
   void add_at_head(BufferHolder* to_add);
   void remove_from_list(BufferHolder* to_remove);

   MTL::Device* device_;
-  std::mutex cache_mutex_;
   std::multimap<size_t, BufferHolder*> buffer_pool_;
   BufferHolder* head_;
@@ -64,7 +63,7 @@ class MetalAllocator : public allocator::Allocator {
     return peak_memory_;
   };
   size_t get_cache_memory() {
-    return buffer_cache_.pool_size();
+    return buffer_cache_.cache_size();
   };
   size_t set_cache_limit(size_t limit);
   size_t set_memory_limit(size_t limit, bool relaxed);
@@ -84,6 +83,8 @@ class MetalAllocator : public allocator::Allocator {
   size_t peak_memory_{0};
   size_t max_pool_size_;
   bool relaxed_{true};
+  std::mutex mutex_;
 };

 MetalAllocator& allocator();
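
Taken together, the two allocator changes above drop the per-call cache_mutex_ inside BufferCache and guard all cache state with a single mutex_ owned by MetalAllocator, held as a std::unique_lock that is released around the slow device_->newBuffer call so other threads can keep using the cache while Metal allocates. The cache these locks protect is driven by the cache limit; below is a minimal sketch of exercising the no-cache path from Python, assuming the mx.metal.set_cache_limit and mx.metal.get_cache_memory bindings that wrap the set_cache_limit / get_cache_memory methods shown in the header above:

import mlx.core as mx

# A cache limit of 0 disables buffer recycling: with the new
# `get_cache_memory() < max_pool_size_` check in `free`, a freed Metal buffer
# is released immediately instead of being returned to the pool.
old_limit = mx.metal.set_cache_limit(0)

a = mx.random.uniform(shape=(1024, 1024))
mx.eval(a)
del a

# Nothing should accumulate in the cache while the limit is 0.
print(mx.metal.get_cache_memory())  # expected: 0

# Restore the previous limit so freed buffers can be recycled again.
mx.metal.set_cache_limit(old_limit)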

View File

@@ -210,14 +210,19 @@ class RMSprop(Optimizer):
         w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         alpha (float, optional): The smoothing constant :math:`\alpha`.
           Default: ``0.99``
         eps (float, optional): The term :math:`\epsilon` added to the denominator
           to improve numerical stability. Default: ``1e-8``
     """

-    def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        alpha: float = 0.99,
+        eps: float = 1e-8,
+    ):
         super().__init__()

         self._maybe_schedule("learning_rate", learning_rate)
@@ -264,12 +269,16 @@ class Adagrad(Optimizer):
         w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         eps (float, optional): The term :math:`\epsilon` added to the
           denominator to improve numerical stability. Default: ``1e-8``
     """

-    def __init__(self, learning_rate: float, eps: float = 1e-8):
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        eps: float = 1e-8,
+    ):
         super().__init__()

         self._maybe_schedule("learning_rate", learning_rate)
@@ -311,14 +320,19 @@ class AdaDelta(Optimizer):
         w_{t+1} &= w_t - \lambda \Delta w_{t+1}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         rho (float, optional): The coefficient :math:`\rho` used for computing a
           running average of squared gradients. Default: ``0.9``
         eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
           numerical stability. Default: `1e-8`
     """

-    def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        rho: float = 0.9,
+        eps: float = 1e-6,
+    ):
         super().__init__()

         self._maybe_schedule("learning_rate", learning_rate)
@@ -374,7 +388,7 @@ class Adam(Optimizer):
         w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         betas (Tuple[float, float], optional): The coefficients
           :math:`(\beta_1, \beta_2)` used for computing running averages of the
           gradient and its square. Default: ``(0.9, 0.999)``
@@ -383,7 +397,10 @@ class Adam(Optimizer):
     """

     def __init__(
-        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        betas: List[float] = [0.9, 0.999],
+        eps: float = 1e-8,
     ):
         super().__init__()
@@ -430,7 +447,7 @@ class AdamW(Adam):
         w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + \lambda w_t)

     Args:
-        learning_rate (float): The learning rate :math:`\alpha`.
+        learning_rate (float or callable): The learning rate :math:`\alpha`.
         betas (Tuple[float, float], optional): The coefficients
           :math:`(\beta_1, \beta_2)` used for computing running averages of the
           gradient and its square. Default: ``(0.9, 0.999)``
@@ -442,7 +459,7 @@ class AdamW(Adam):
     def __init__(
         self,
-        learning_rate: float,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
         betas: List[float] = [0.9, 0.999],
         eps: float = 1e-8,
         weight_decay: float = 0.01,
@@ -477,7 +494,7 @@ class Adamax(Adam):
         w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         betas (Tuple[float, float], optional): The coefficients
           :math:`(\beta_1, \beta_2)` used for computing running averages of the
           gradient and its square. Default: ``(0.9, 0.999)``
@@ -486,7 +503,10 @@ class Adamax(Adam):
     """

     def __init__(
-        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        betas: List[float] = [0.9, 0.999],
+        eps: float = 1e-8,
     ):
         super().__init__(learning_rate, betas, eps)

         if not 0.0 <= eps:
@@ -537,7 +557,7 @@ class Lion(Optimizer):
         w_{t + 1} &= w_t - \eta (\text{sign}(c_t) + \lambda w_t)

     Args:
-        learning_rate (float): The learning rate :math:`\eta`.
+        learning_rate (float or callable): The learning rate :math:`\eta`.
         betas (Tuple[float, float], optional): The coefficients
           :math:`(\beta_1, \beta_2)` used for computing the gradient
           momentum and update direction. Default: ``(0.9, 0.99)``
@@ -546,7 +566,7 @@ class Lion(Optimizer):
     def __init__(
         self,
-        learning_rate: float,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
         betas: List[float] = [0.9, 0.99],
         weight_decay: float = 0.0,
     ):
@@ -583,7 +603,8 @@ class Adafactor(Optimizer):
     <https://arxiv.org/abs/1804.04235>`_

     Args:
-        learning_rate (float, optional): The learning rate. Default: ``None``.
+        learning_rate (float or callable, optional): The learning rate.
+          Default: ``None``.
         eps (tuple(float, float), optional): The first term :math:`\epsilon_1`
           added to the square of the gradients to improve numerical
           stability and the second term :math:`\epsilon_2` is used for
@@ -610,7 +631,7 @@ class Adafactor(Optimizer):
     def __init__(
         self,
-        learning_rate: Optional[float] = None,
+        learning_rate: Union[float, Callable[[mx.array], mx.array], None] = None,
         eps: Tuple[float, float] = (1e-30, 1e-3),
         clip_threshold: float = 1.0,
         decay_rate: float = -0.8,
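
The signature changes above let learning_rate be either a fixed float or a callable schedule mapping the optimizer's step count to a rate, with _maybe_schedule storing whichever was passed. A minimal usage sketch, built on the same step_decay schedule the tests below exercise:

import mlx.core as mx
import mlx.optimizers as optim

# A schedule is a callable: step (mx.array) -> learning rate (mx.array).
lr_schedule = optim.step_decay(1e-1, 0.9, 1000)

# Any of the optimizers above accept it in place of a float.
optimizer = optim.Adam(learning_rate=lr_schedule)

params = {"w": mx.ones((5, 5))}
grads = {"w": mx.ones_like(params["w"])}
optimizer.apply_gradients(grads, params)

# The rate used for the current step is exposed as a property.
print(optimizer.learning_rate)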

View File

@@ -299,16 +299,16 @@ class TestOptimizers(mlx_tests.MLXTestCase):
 class TestSchedulers(unittest.TestCase):
     def test_decay_lr(self):
         for optim_class in optimizers_dict.values():
-            lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
+            lr_schedule = opt.step_decay(1e-1, 0.9, 1)
             optimizer = optim_class(learning_rate=lr_schedule)

             params = {"w": mx.ones((5, 5))}
             grads = tree_map(lambda x: mx.ones_like(x), params)

             for it in range(10):
-                optimizer.apply_gradients(grads, params)
                 expected_lr = 0.1 * (0.9**it)
                 self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)
-                return
+                optimizer.apply_gradients(grads, params)

     def test_step_decay(self):
         lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
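
The test fix leans on the semantics of step_decay(init, decay_rate, step_size): the rate is multiplied by decay_rate once every step_size updates. With the original step size of 1000 the learning rate never changes during the ten iterations, so the loop now uses a step size of 1 and checks the rate before calling apply_gradients, keeping each assertion aligned with the current step count. A short worked illustration of the values the fixed test expects, assuming that schedule definition:

# step_decay(1e-1, 0.9, 1): after `it` updates the rate is 0.1 * 0.9**it
expected = [0.1 * 0.9**it for it in range(4)]
print(expected)  # approximately [0.1, 0.09, 0.081, 0.0729]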