Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-29 22:01:17 +08:00)
Some fixes in cache / thread safety (#777)
* some fixes in cache / thread safety
* speed up no cache case
* fix opt test
* optimizer docs
* fix adafactor
Parent: 859ae15a54
Commit: cbcf44a4ca
Metal allocator implementation:

@@ -1,5 +1,4 @@
 // Copyright © 2023-2024 Apple Inc.

 #include "mlx/backend/metal/allocator.h"
 #include "mlx/backend/metal/metal.h"
-
@@ -34,7 +33,6 @@ BufferCache::~BufferCache() {
 }

 void BufferCache::clear() {
-  std::lock_guard<std::mutex> lk(cache_mutex_);
   for (auto& [size, holder] : buffer_pool_) {
     if (holder->buf)
       holder->buf->release();
@@ -47,12 +45,9 @@ void BufferCache::clear() {
 }

 MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
-  std::lock_guard<std::mutex> lk(cache_mutex_);
-
   // Find the closest buffer in pool
   MTL::Buffer* pbuf = nullptr;

-  // Make sure we use most of the available memory
   auto it = buffer_pool_.lower_bound(size);

   // Make sure we use most of the available memory
@@ -75,8 +70,6 @@ MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
 }

 void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
-  std::lock_guard<std::mutex> lk(cache_mutex_);
-
   // Add to cache
   if (buf) {
     BufferHolder* bh = new BufferHolder(buf);
@@ -90,7 +83,6 @@ void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
   if (min_bytes_to_free >= 0.9 * pool_size_) {
     clear();
   } else {
-    std::lock_guard<std::mutex> lk(cache_mutex_);
     size_t total_bytes_freed = 0;

     while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
@@ -178,10 +170,10 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
   }

   // Try the cache
+  std::unique_lock lk(mutex_);
   MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
-  size_t pool_size = get_cache_memory();
   if (!buf) {
-    size_t mem_required = get_active_memory() + pool_size + size;
+    size_t mem_required = get_active_memory() + get_cache_memory() + size;

     // If there is too much memory pressure, fail (likely causes a wait).
     if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
@@ -190,8 +182,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {

     auto thread_pool = metal::new_scoped_memory_pool();

-    // If we have a lot of memory pressure, check if we can reclaim some memory
-    // from the cache
+    // If we have a lot of memory pressure or are over the maximum cache size,
+    // try to reclaim memory from the cache
     if (mem_required >= gc_limit_) {
       buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
     }
@@ -199,27 +191,32 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
     // Allocate new buffer if needed
     size_t res_opt = MTL::ResourceStorageModeShared;
     res_opt |= MTL::ResourceHazardTrackingModeTracked;
+    lk.unlock();
     buf = device_->newBuffer(size, res_opt);
-  }
-
-  // Maintain the cache below the requested limit
-  if (pool_size >= max_pool_size_) {
-    auto thread_pool = metal::new_scoped_memory_pool();
-    buffer_cache_.release_cached_buffers(pool_size - max_pool_size_);
+    lk.lock();
   }

   active_memory_ += buf->length();
   peak_memory_ = std::max(peak_memory_, active_memory_);

+  // Maintain the cache below the requested limit
+  if (get_cache_memory() >= max_pool_size_) {
+    auto thread_pool = metal::new_scoped_memory_pool();
+    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
+  }
+
   return Buffer{static_cast<void*>(buf)};
 }

 void MetalAllocator::free(Buffer buffer) {
   auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
+  std::unique_lock lk(mutex_);
   active_memory_ -= buf->length();
-  if (max_pool_size_ > 0) {
+  if (get_cache_memory() < max_pool_size_) {
     buffer_cache_.recycle_to_cache(buf);
   } else {
+    lk.unlock();
+    auto thread_pool = metal::new_scoped_memory_pool();
     buf->release();
   }
 }
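Taken together, `malloc` now drops the allocator lock around the expensive `newBuffer` call, trims the cache against `max_pool_size_` using the current cache size after the new buffer is accounted for, and `free` releases buffers directly once the cache is already at its limit. A minimal sketch of how this is visible from Python, assuming the `mx.metal.set_cache_limit` / `get_cache_memory` / `get_active_memory` bindings that front this allocator (the binding names are not part of this diff; check your MLX version):

```python
import mlx.core as mx

# Sketch: cap the buffer cache at 64 MB. After this change malloc() trims the
# cache back under the limit right after accounting for the new buffer, so the
# cache should not drift far above the configured ceiling.
mx.metal.set_cache_limit(64 * 1024 * 1024)

a = mx.random.normal((4096, 4096))
mx.eval(a)
del a  # dropping the reference returns the buffer to the cache, or releases it if the cache is full

b = mx.random.normal((2048, 2048))
mx.eval(b)

print("active bytes:", mx.metal.get_active_memory())
print("cached bytes:", mx.metal.get_cache_memory())  # expected to stay near or below the 64 MB cap
```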
Metal allocator header:

@@ -19,12 +19,11 @@ class BufferCache {
  public:
   BufferCache(MTL::Device* device);
   ~BufferCache();
-  void clear();

   MTL::Buffer* reuse_from_cache(size_t size);
   void recycle_to_cache(MTL::Buffer* buf);
   void release_cached_buffers(size_t min_bytes_to_free);
-  size_t pool_size() {
+  size_t cache_size() {
     return pool_size_;
   }

@@ -38,11 +37,11 @@ class BufferCache {
     MTL::Buffer* buf;
   };

+  void clear();
   void add_at_head(BufferHolder* to_add);
   void remove_from_list(BufferHolder* to_remove);

   MTL::Device* device_;
-  std::mutex cache_mutex_;

   std::multimap<size_t, BufferHolder*> buffer_pool_;
   BufferHolder* head_;
@@ -64,7 +63,7 @@ class MetalAllocator : public allocator::Allocator {
     return peak_memory_;
   };
   size_t get_cache_memory() {
-    return buffer_cache_.pool_size();
+    return buffer_cache_.cache_size();
   };
   size_t set_cache_limit(size_t limit);
   size_t set_memory_limit(size_t limit, bool relaxed);
@@ -84,6 +83,8 @@ class MetalAllocator : public allocator::Allocator {
   size_t peak_memory_{0};
   size_t max_pool_size_;
   bool relaxed_{true};
+
+  std::mutex mutex_;
 };

 MetalAllocator& allocator();
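The header changes move locking up a level: the per-`BufferCache` `cache_mutex_` is gone, `clear()` becomes a private helper, and `MetalAllocator` gains a single `mutex_` that guards both the cache and the memory counters. A rough sketch of the kind of concurrent workload this allocator-level lock is meant to keep safe (whether a given MLX version supports heavier multi-threaded use is not established by this diff):

```python
import threading
import mlx.core as mx

def worker(seed: int) -> None:
    # Each thread drives its own stream of Metal buffer allocations and frees,
    # all of which funnel through the single MetalAllocator instance.
    x = mx.random.normal((1024, 1024), key=mx.random.key(seed))
    for _ in range(10):
        x = x @ x / 1024.0
        mx.eval(x)

threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```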
Python optimizers (mlx.optimizers):

@@ -210,14 +210,19 @@ class RMSprop(Optimizer):
         w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         alpha (float, optional): The smoothing constant :math:`\alpha`.
            Default: ``0.99``
         eps (float, optional): The term :math:`\epsilon` added to the denominator
            to improve numerical stability. Default: ``1e-8``
     """

-    def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        alpha: float = 0.99,
+        eps: float = 1e-8,
+    ):
         super().__init__()

         self._maybe_schedule("learning_rate", learning_rate)
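With this signature change each built-in optimizer accepts either a fixed float or a callable schedule that maps the step counter (an `mx.array`) to a learning rate. A small usage sketch, using the `step_decay` schedule that the tests below already rely on:

```python
import mlx.optimizers as optim

# Fixed learning rate, as before.
rmsprop = optim.RMSprop(learning_rate=1e-3)

# Callable schedule: decay the rate by 0.9 every 1000 optimizer steps.
schedule = optim.step_decay(1e-1, 0.9, 1000)
rmsprop_sched = optim.RMSprop(learning_rate=schedule)
```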
@@ -264,12 +269,16 @@ class Adagrad(Optimizer):
         w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         eps (float, optional): The term :math:`\epsilon` added to the
            denominator to improve numerical stability. Default: ``1e-8``
     """

-    def __init__(self, learning_rate: float, eps: float = 1e-8):
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        eps: float = 1e-8,
+    ):
         super().__init__()

         self._maybe_schedule("learning_rate", learning_rate)
@@ -311,14 +320,19 @@ class AdaDelta(Optimizer):
         w_{t+1} &= w_t - \lambda \Delta w_{t+1}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         rho (float, optional): The coefficient :math:`\rho` used for computing a
            running average of squared gradients. Default: ``0.9``
         eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
            numerical stability. Default: `1e-8`
     """

-    def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        rho: float = 0.9,
+        eps: float = 1e-6,
+    ):
         super().__init__()

         self._maybe_schedule("learning_rate", learning_rate)
@@ -374,7 +388,7 @@ class Adam(Optimizer):
         w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         betas (Tuple[float, float], optional): The coefficients
            :math:`(\beta_1, \beta_2)` used for computing running averages of the
            gradient and its square. Default: ``(0.9, 0.999)``
@@ -383,7 +397,10 @@ class Adam(Optimizer):
     """

     def __init__(
-        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        betas: List[float] = [0.9, 0.999],
+        eps: float = 1e-8,
     ):
         super().__init__()

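The schedule is just a function from the step counter (an `mx.array`) to an `mx.array`, so a hand-written callable works too. A sketch with a hypothetical linear warmup (not part of this diff):

```python
import mlx.core as mx
import mlx.optimizers as optim

def warmup(step: mx.array) -> mx.array:
    # Hypothetical schedule: ramp from 0 up to 1e-3 over the first 100 steps.
    return 1e-3 * mx.minimum(step / 100, 1.0)

adam = optim.Adam(learning_rate=warmup, betas=[0.9, 0.999], eps=1e-8)
```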
@@ -430,7 +447,7 @@ class AdamW(Adam):
         w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + \lambda w_t)

     Args:
-        learning_rate (float): The learning rate :math:`\alpha`.
+        learning_rate (float or callable): The learning rate :math:`\alpha`.
         betas (Tuple[float, float], optional): The coefficients
            :math:`(\beta_1, \beta_2)` used for computing running averages of the
            gradient and its square. Default: ``(0.9, 0.999)``
@@ -442,7 +459,7 @@ class AdamW(Adam):

     def __init__(
         self,
-        learning_rate: float,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
         betas: List[float] = [0.9, 0.999],
         eps: float = 1e-8,
         weight_decay: float = 0.01,
@@ -477,7 +494,7 @@ class Adamax(Adam):
         w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         betas (Tuple[float, float], optional): The coefficients
            :math:`(\beta_1, \beta_2)` used for computing running averages of the
            gradient and its square. Default: ``(0.9, 0.999)``
@@ -486,7 +503,10 @@ class Adamax(Adam):
     """

     def __init__(
-        self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        betas: List[float] = [0.9, 0.999],
+        eps: float = 1e-8,
     ):
         super().__init__(learning_rate, betas, eps)
         if not 0.0 <= eps:
@@ -537,7 +557,7 @@ class Lion(Optimizer):
         w_{t + 1} &= w_t - \eta (\text{sign}(c_t) + \lambda w_t)

     Args:
-        learning_rate (float): The learning rate :math:`\eta`.
+        learning_rate (float or callable): The learning rate :math:`\eta`.
         betas (Tuple[float, float], optional): The coefficients
            :math:`(\beta_1, \beta_2)` used for computing the gradient
            momentum and update direction. Default: ``(0.9, 0.99)``
@@ -546,7 +566,7 @@ class Lion(Optimizer):

     def __init__(
         self,
-        learning_rate: float,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
         betas: List[float] = [0.9, 0.99],
         weight_decay: float = 0.0,
     ):
@@ -583,7 +603,8 @@ class Adafactor(Optimizer):
     <https://arxiv.org/abs/1804.04235>`_

     Args:
-        learning_rate (float, optional): The learning rate. Default: ``None``.
+        learning_rate (float or callable, optional): The learning rate.
+            Default: ``None``.
         eps (tuple(float, float), optional): The first term :math:`\epsilon_1`
            added to the square of the gradients to improve numerical
            stability and the second term :math:`\epsilon_2` is used for
@@ -610,7 +631,7 @@ class Adafactor(Optimizer):

     def __init__(
         self,
-        learning_rate: Optional[float] = None,
+        learning_rate: Union[float, Callable[[mx.array], mx.array], None] = None,
         eps: Tuple[float, float] = (1e-30, 1e-3),
         clip_threshold: float = 1.0,
         decay_rate: float = -0.8,
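Adafactor keeps its optional learning rate: `None` is still the default, in which case the optimizer falls back to the internally computed relative step size of the linked paper (unchanged by this diff), while a float or a callable schedule now also works. A brief sketch under those assumptions:

```python
import mlx.optimizers as optim

# No learning rate given: Adafactor uses its internally computed step size.
ada_default = optim.Adafactor()

# Or drive it with the same kind of callable schedule as the other optimizers.
ada_sched = optim.Adafactor(learning_rate=optim.step_decay(1e-2, 0.9, 1000))
```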
Optimizer/scheduler tests:

@@ -299,16 +299,16 @@ class TestOptimizers(mlx_tests.MLXTestCase):
 class TestSchedulers(unittest.TestCase):
     def test_decay_lr(self):
         for optim_class in optimizers_dict.values():
-            lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
+            lr_schedule = opt.step_decay(1e-1, 0.9, 1)
             optimizer = optim_class(learning_rate=lr_schedule)

             params = {"w": mx.ones((5, 5))}
             grads = tree_map(lambda x: mx.ones_like(x), params)

             for it in range(10):
+                optimizer.apply_gradients(grads, params)
                 expected_lr = 0.1 * (0.9**it)
                 self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)
-                return optimizer.apply_gradients(grads, params)

     def test_step_decay(self):
         lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
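The test fix addresses two problems: the old `return` inside the loop ended the test after a single iteration, and with a step size of 1000 the rate would never have decayed within the 10 steps being checked. With `step_decay(1e-1, 0.9, 1)` the rate is multiplied by 0.9 once per step, which is what the loop now asserts. A quick check sketch of the expected values:

```python
import mlx.core as mx
import mlx.optimizers as opt

# Values the fixed test expects: 0.1 * 0.9**step for step = 0, 1, 2, ...
lr_schedule = opt.step_decay(1e-1, 0.9, 1)
for step in range(4):
    lr = lr_schedule(mx.array(step))
    print(float(lr))  # approximately 0.1, 0.09, 0.081, 0.0729
```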