mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00

Support LR schedulers (#334)

* Add a few LR schedulers
* Move parent's constructor call to the top
* Fix docstring
* Refactor optimizers into two files
* Add docs
* nit
* Fix Callable type annotation for Python 3.8

Co-authored-by: Awni Hannun <awni@apple.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

This commit is contained in:
parent 85143fecdd
commit 818cda16bc
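In short, the commit lets `learning_rate` be either a float or a callable that maps the optimizer's step counter to a value, adds `step_decay`, `exponential_decay`, and `cosine_decay` factories in a new python/mlx/optimizers/schedulers.py, splits the optimizers into a package, updates the docs, and adds an array `__rpow__` binding plus tests. A minimal usage sketch, assembled from the docstring examples added below (output values as given in those doctests):

import mlx.optimizers as optim

# A schedule is just a callable mapping the step counter to a value.
lr_schedule = optim.exponential_decay(1e-1, 0.9)

# Pass it wherever a float learning rate used to go.
optimizer = optim.SGD(learning_rate=lr_schedule)
print(optimizer.learning_rate)  # array(0.1, dtype=float32)

# Each update re-evaluates the schedule at the current step, then increments it.
for _ in range(5):
    optimizer.update({}, {})
print(optimizer.learning_rate)  # array(0.06561, dtype=float32)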
docs/.gitignore (vendored), 1 line added:

@@ -1,2 +1,3 @@
 src/python/_autosummary*/
 src/python/nn/_autosummary*/
+src/python/optimizers/_autosummary*/
@@ -31,20 +31,6 @@ model's parameters and the **optimizer state**.

 .. toctree::

-   optimizer
-
-.. currentmodule:: mlx.optimizers
-
-.. autosummary::
-   :toctree: _autosummary
-   :template: optimizers-template.rst
-
-   SGD
-   RMSprop
-   Adagrad
-   Adafactor
-   AdaDelta
-   Adam
-   AdamW
-   Adamax
-   Lion
+   optimizers/optimizer
+   optimizers/common_optimizers
+   optimizers/schedulers
docs/src/python/optimizers/common_optimizers.rst (new file, 20 lines):

@@ -0,0 +1,20 @@
+.. _common_optimizers:
+
+Common Optimizers
+=================
+
+.. currentmodule:: mlx.optimizers
+
+.. autosummary::
+   :toctree: _autosummary
+   :template: optimizers-template.rst
+
+   SGD
+   RMSprop
+   Adagrad
+   Adafactor
+   AdaDelta
+   Adam
+   AdamW
+   Adamax
+   Lion
docs/src/python/optimizers/schedulers.rst (new file, 13 lines):

@@ -0,0 +1,13 @@
+.. _schedulers:
+
+Schedulers
+==========
+
+.. currentmodule:: mlx.optimizers
+
+.. autosummary::
+   :toctree: _autosummary
+
+   step_decay
+   exponential_decay
+   cosine_decay
python/mlx/optimizers/__init__.py (new file, 4 lines):

@@ -0,0 +1,4 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from mlx.optimizers.optimizers import *
+from mlx.optimizers.schedulers import *
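The flat optimizers module becomes a package whose __init__ re-exports both submodules, so existing `import mlx.optimizers as optim` code keeps working and now also sees the scheduler factories. A hypothetical smoke check of that layout, for illustration only:

import mlx.optimizers as optim

# Both the optimizer classes and the new scheduler factories are reachable
# from the package root via the star imports above.
assert hasattr(optim, "SGD")
assert hasattr(optim, "cosine_decay")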
@@ -1,7 +1,7 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.

 import math
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union

 import mlx.core as mx
 from mlx.utils import tree_map
@@ -12,9 +12,10 @@ class Optimizer:
     optimizer on a per-parameter basis and apply it to a parameter tree.
     """

-    def __init__(self):
+    def __init__(self, schedulers=None):
         self._initialized = False
-        self._state = {}
+        self._state = {"step": mx.array(0, mx.uint64)}
+        self._schedulers = {k: v for k, v in (schedulers or {}).items()}

     def update(self, model: "mlx.nn.Module", gradients: dict):
         """Apply the gradients to the parameters of the model and update the
@@ -44,9 +45,8 @@ class Optimizer:
         >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
         >>> model = nn.Linear(2, 2)
         >>> optimizer.init(model.trainable_parameters())
-        >>> optimizer.state
-        {'learning_rate': array(0.1, dtype=float32), 'weight': {'v': array([[0, 0],
-        [0, 0]], dtype=float32)}, 'bias': {'v': array([0, 0], dtype=float32)}}
+        >>> optimizer.state.keys()
+        dict_keys(['step', 'learning_rate', 'weight', 'bias'])
         """
         self._state.update(tree_map(lambda x: {}, parameters))
         tree_map(self.init_single, parameters, self._state)
@@ -76,6 +76,15 @@ class Optimizer:
         """
         if not self._initialized:
             self.init(gradients)

+        # Update any scheduled variables
+        for param, scheduler in self._schedulers.items():
+            self.state[param] = scheduler(self.step)
+
+        # Increment the step
+        self.state["step"] = self.step + 1
+
+        # Apply the update
         return tree_map(self.apply_single, gradients, parameters, self.state)

     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
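Any callable from step to value goes through this same path, so a hand-rolled schedule can be passed exactly like the built-in factories. A hypothetical warmup schedule, purely for illustration (the names below are not part of the commit):

import mlx.core as mx
import mlx.optimizers as optim

def warmup(step):
    # Hypothetical schedule: ramp from 0 to 1e-1 over the first 100 steps,
    # then hold. `step` arrives as the optimizer's array-valued step counter.
    return 1e-1 * mx.minimum(step, 100) / 100

optimizer = optim.SGD(learning_rate=warmup)  # handled like the built-in schedules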
@@ -97,14 +106,31 @@ class Optimizer:
     def state(self, state: dict):
         self._state = state

+    @property
+    def step(self):
+        return self.state["step"]
+
     @property
     def learning_rate(self):
         return self.state["learning_rate"]

     @learning_rate.setter
-    def learning_rate(self, learning_rate: mx.array):
+    def learning_rate(self, learning_rate: Union[float, mx.array]):
         self.state["learning_rate"] = mx.array(learning_rate)

+    def _maybe_schedule(
+        self, name: str, param: Union[float, Callable[[mx.array], mx.array]]
+    ):
+        """
+        To be used by derived classes to optionally put a parameter on a schedule.
+        """
+        if isinstance(param, Callable):
+            self._schedulers[name] = param
+            param = param(self.step)
+        else:
+            param = mx.array(param)
+        self.state[name] = param
+

 class SGD(Optimizer):
     r"""The stochastic gradient descent optimizer.
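Every optimizer below follows the same pattern: the constructor routes `learning_rate` through `_maybe_schedule`, and the update rule reads the (possibly scheduled) value back out of the state. A stripped-down hypothetical subclass sketching that contract (not part of the commit):

import mlx.core as mx
from mlx.optimizers import Optimizer

class PlainSGD(Optimizer):
    """Hypothetical minimal optimizer: parameter <- parameter - lr * gradient."""

    def __init__(self, learning_rate):
        super().__init__()
        # A float is stored as an mx.array; a callable is registered as a scheduler.
        self._maybe_schedule("learning_rate", learning_rate)

    def init_single(self, parameter, state):
        pass  # no per-parameter state needed

    def apply_single(self, gradient, parameter, state):
        # self.learning_rate reflects the schedule at the current step.
        return parameter - self.learning_rate * gradient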
@@ -117,7 +143,7 @@ class SGD(Optimizer):
         w_{t+1} &= w_t - \lambda v_{t+1}

     Args:
-        learning_rate (float): The learning rate :math:`\lambda`.
+        learning_rate (float or callable): The learning rate :math:`\lambda`.
         momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
         weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
         dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
@@ -126,7 +152,7 @@ class SGD(Optimizer):

     def __init__(
         self,
-        learning_rate: float,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
         momentum: float = 0.0,
         weight_decay: float = 0.0,
         dampening: float = 0.0,
@@ -138,7 +164,7 @@ class SGD(Optimizer):
         )
         super().__init__()

-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.momentum = momentum
         self.weight_decay = weight_decay
         self.dampening = dampening
@@ -194,7 +220,7 @@ class RMSprop(Optimizer):
     def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
         super().__init__()

-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.alpha = alpha
         self.eps = eps

@@ -246,7 +272,7 @@ class Adagrad(Optimizer):
     def __init__(self, learning_rate: float, eps: float = 1e-8):
         super().__init__()

-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.eps = eps

         if self.eps < 0.0:
@@ -295,7 +321,7 @@ class AdaDelta(Optimizer):
     def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
         super().__init__()

-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.rho = rho
         self.eps = eps
         if self.rho < 0.0:
@@ -361,7 +387,7 @@ class Adam(Optimizer):
     ):
         super().__init__()

-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.betas = betas
         self.eps = eps

@@ -526,7 +552,7 @@ class Lion(Optimizer):
     ):
         super().__init__()

-        self.learning_rate = learning_rate
+        self._maybe_schedule("learning_rate", learning_rate)
         self.betas = betas
         self.weight_decay = weight_decay

@@ -596,7 +622,7 @@ class Adafactor(Optimizer):
     ):
         super().__init__()
         if learning_rate is not None:
-            self.learning_rate = learning_rate
+            self._maybe_schedule("learning_rate", learning_rate)
         self.eps = eps
         self.clip_threshold = clip_threshold
         self.decay_rate = decay_rate
@@ -608,7 +634,6 @@ class Adafactor(Optimizer):

     def init_single(self, parameter: mx.array, state: dict):
         """Initialize optimizer state"""
-        state["step"] = 0
         if parameter.ndim >= 2:
             shape = parameter.shape
             dtype = parameter.dtype
@@ -626,10 +651,11 @@ class Adafactor(Optimizer):
     def _compute_learning_rate(self, step, parameter_rms):
         if self.relative_step:
             min_step = 1e-6 * step if self.warmup_init else 1e-2
-            relative_step_size = min(min_step, 1 / math.sqrt(step))
+            relative_step_size = mx.minimum(min_step, mx.rsqrt(step))
         else:
-            relative_step_size = self.learning_rate.astype(parameter_rms)
+            relative_step_size = self.learning_rate

+        relative_step_size = relative_step_size.astype(parameter_rms.dtype)
         parameter_scale = 1.0
         if self.scale_parameter:
             parameter_scale = mx.maximum(self.eps[1], parameter_rms)
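For orientation, the relative-step branch keeps the same arithmetic; it is only rewritten with `mx` ops so it works with the array-valued step counter. A plain-Python illustration of the values it produces (with `warmup_init=False`, so `min_step` is 1e-2; this is not mlx code):

import math

for step in (1, 100, 10_000, 1_000_000):
    relative_step_size = min(1e-2, 1 / math.sqrt(step))
    print(step, relative_step_size)
# The 1e-2 cap dominates up to step 10_000; beyond that the step size
# decays like 1/sqrt(step), e.g. 1e-3 at step 1_000_000.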
@@ -648,13 +674,12 @@ class Adafactor(Optimizer):
         """Performs the Adafactor parameter and state update."""
         factored = gradient.ndim >= 2

-        step = state["step"] + 1
-        state["step"] = step
+        step = self.step
         use_first_moment = self.beta_1 is not None

         parameter_rms = self._compute_rms(parameter)
         learning_rate = self._compute_learning_rate(step, parameter_rms)
-        beta_2 = 1.0 - math.pow(step, self.decay_rate)
+        beta_2 = 1.0 - (step**self.decay_rate).astype(parameter_rms.dtype)
         update = mx.square(gradient) + self.eps[0]

         if factored:
python/mlx/optimizers/schedulers.py (new file, 86 lines):

@@ -0,0 +1,86 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import math
+
+import mlx.core as mx
+
+
+def exponential_decay(init: float, decay_rate: float):
+    r"""Make an exponential decay scheduler.
+
+    Args:
+        init (float): Initial value.
+        decay_rate (float): Multiplicative factor to decay by.
+
+    Example:
+        >>> lr_schedule = optim.exponential_decay(1e-1, 0.9)
+        >>> optimizer = optim.SGD(learning_rate=lr_schedule)
+        >>> optimizer.learning_rate
+        array(0.1, dtype=float32)
+        >>>
+        >>> for _ in range(5): optimizer.update({}, {})
+        ...
+        >>> optimizer.learning_rate
+        array(0.06561, dtype=float32)
+    """
+
+    def schedule(step):
+        return init * decay_rate**step
+
+    return schedule
+
+
+def step_decay(init: float, decay_rate: float, step_size: int):
+    r"""Make a step decay scheduler.
+
+    Args:
+        init (float): Initial value.
+        decay_rate (float): Multiplicative factor to decay by.
+        step_size (int): Decay every ``step_size`` steps.
+
+    Example:
+
+        >>> lr_schedule = optim.step_decay(1e-1, 0.9, 10)
+        >>> optimizer = optim.SGD(learning_rate=lr_schedule)
+        >>> optimizer.learning_rate
+        array(0.1, dtype=float32)
+        >>>
+        >>> for _ in range(21): optimizer.update({}, {})
+        ...
+        >>> optimizer.learning_rate
+        array(0.081, dtype=float32)
+    """
+
+    def schedule(step):
+        return init * (decay_rate ** (step // step_size))
+
+    return schedule
+
+
+def cosine_decay(init: float, decay_steps: int):
+    r"""Make a cosine decay scheduler.
+
+    Args:
+        init (float): Initial value.
+        decay_steps (int): Number of steps to decay over. The decayed
+            value is constant for steps beyond ``decay_steps``.
+
+    Example:
+
+        >>> lr_schedule = optim.cosine_decay(1e-1, 1000)
+        >>> optimizer = optim.SGD(learning_rate=lr_schedule)
+        >>> optimizer.learning_rate
+        array(0.1, dtype=float32)
+        >>>
+        >>> for _ in range(5): optimizer.update({}, {})
+        ...
+        >>> optimizer.learning_rate
+        array(0.0999961, dtype=float32)
+    """
+
+    def scheduler(step):
+        s = mx.minimum(step, decay_steps)
+        decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
+        return init * decay
+
+    return scheduler
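Each factory returns a closure over the step, so the schedules can also be evaluated directly, which is how the new tests further below exercise them. The expected values follow from the closures above:

import mlx.optimizers as optim

print(optim.step_decay(1e-1, 0.9, 1000)(2500))   # 0.1 * 0.9**2 = 0.081
print(optim.exponential_decay(1e-1, 0.99)(10))   # 0.1 * 0.99**10 ≈ 0.0904
print(optim.cosine_decay(0.1, 10)(4))            # 0.1 * 0.5 * (1 + cos(0.4 * pi)) ≈ 0.0655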
@@ -971,6 +971,12 @@ void init_array(py::module_& m) {
           return power(a, to_array(v, a.dtype()));
         },
         "other"_a)
+      .def(
+          "__rpow__",
+          [](const array& a, const ScalarOrArray v) {
+            return power(to_array(v, a.dtype()), a);
+          },
+          "other"_a)
       .def(
           "__ipow__",
           [](array& a, const ScalarOrArray v) {
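The new `__rpow__` binding is what lets a Python scalar be raised to an array power, which the schedulers rely on (`decay_rate**step` with the array-valued step counter). A quick check at the Python level:

import mlx.core as mx

step = mx.array(3, mx.uint64)
print(0.9 ** step)  # dispatches to array.__rpow__, i.e. power(0.9, step) ≈ array(0.729, dtype=float32)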
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 import inspect
+import math
 import unittest
 from functools import partial

@@ -15,9 +16,12 @@ from mlx.utils import tree_flatten, tree_map
 def get_all_optimizers():
     classes = dict()
     for name, obj in inspect.getmembers(opt):
-        if inspect.isclass(obj):
-            if obj.__name__ not in ["Optimizer"]:
-                classes[name] = obj
+        if (
+            inspect.isclass(obj)
+            and issubclass(obj, opt.Optimizer)
+            and obj != opt.Optimizer
+        ):
+            classes[name] = obj
     return classes

@@ -204,18 +208,16 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         x = mx.zeros((5, 5))
         grad = mx.ones_like(x)
         optimizer = opt.Adafactor()
-        optimizer.init(x)
         for _ in range(2):
-            xp = optimizer.apply_single(grad, x, optimizer.state)
+            xp = optimizer.apply_gradients(grad, x)
             self.assertEqual(xp.dtype, x.dtype)
             self.assertEqual(xp.shape, x.shape)

         x = mx.zeros((5, 5), mx.float16)
         grad = mx.ones_like(x)
         optimizer = opt.Adafactor()
-        optimizer.init(x)
         for _ in range(2):
-            xp = optimizer.apply_single(grad, x, optimizer.state)
+            xp = optimizer.apply_gradients(grad, x)
             self.assertEqual(xp.dtype, x.dtype)
             self.assertEqual(xp.shape, x.shape)
         self.assertEqual(optimizer.state["step"], 2)
@@ -294,5 +296,50 @@
         self.assertTrue(mx.allclose(result["w"], mx.full((5, 5), 3.0)))


+class TestSchedulers(unittest.TestCase):
+    def test_decay_lr(self):
+        for optim_class in optimizers_dict.values():
+            lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
+            optimizer = optim_class(learning_rate=lr_schedule)
+
+            params = {"w": mx.ones((5, 5))}
+            grads = tree_map(lambda x: mx.ones_like(x), params)
+
+            for it in range(10):
+                expected_lr = 0.1 * (0.9**it)
+                self.assertAlmostEqual(optimizer.learning_rate, expected_lr, delta=1e-7)
+                return optimizer.apply_gradients(grads, params)
+
+    def test_step_decay(self):
+        lr_schedule = opt.step_decay(1e-1, 0.9, 1000)
+        lr = lr_schedule(2500)
+        expected_lr = 0.1 * (0.9**2)
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_exponential_decay(self):
+        lr_schedule = opt.exponential_decay(1e-1, 0.99)
+        lr = lr_schedule(10)
+        expected_lr = 0.1 * (0.99**10)
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_cosine_decay(self):
+        lr_schedule = opt.cosine_decay(0.1, 10)
+        lr = lr_schedule(4)
+        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_compile_with_schedule(self):
+        lr_schedule = opt.exponential_decay(1e-1, 0.9)
+        optimizer = opt.SGD(learning_rate=lr_schedule)
+
+        @partial(mx.compile, inputs=optimizer.state, outputs=optimizer.state)
+        def update():
+            optimizer.update({}, {})
+
+        for step in range(5):
+            update()
+            self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())
+
+
 if __name__ == "__main__":
     unittest.main()