Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)
feat: implement clip_grad_norm (#1043)

Squashed commit history:

* feat: implement `clip_grad_norm`
* pre-commit
* Add test for clip_grad_norm function in test_optimizers.py
* small fixes
* fix
* lint
* Update tree_reduce
* Update python/mlx/utils.py (six review-suggestion commits, co-authored by Awni Hannun)
* Refactor clip_grad_norm function to include documentation and improve readability
* format docstring
* Add acknowledgements
* text wrap
* pre-commit
* nits in docs

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>

Parent: b00ac960b4
Commit: 79c859e2e0
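
Before the file-by-file diff, here is a minimal usage sketch of the new helper in a training step. It assumes only what this commit adds (`clip_grad_norm` exported from `mlx.optimizers`, returning the possibly rescaled gradients and the pre-clipping norm); the model, loss function, optimizer choice, and data below are hypothetical placeholders, not part of the change.

# Hypothetical usage sketch of the helper added in this commit; `model`,
# `loss_fn`, the SGD optimizer, and the batch are placeholders, not part of the diff.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 2)
optimizer = optim.SGD(learning_rate=0.1)


def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)


def step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    # Rescale the whole gradient tree if its global L2 norm exceeds 1.0.
    grads, total_norm = optim.clip_grad_norm(grads, max_norm=1.0)
    optimizer.update(model, grads)
    return loss, total_norm


x, y = mx.random.normal((8, 4)), mx.random.normal((8, 2))
loss, total_norm = step(x, y)
mx.eval(model.parameters(), optimizer.state)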
ACKNOWLEDGMENTS.md
@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
 MLX was developed with contributions from the following individuals:

-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
CMakeLists.txt
@@ -94,7 +94,7 @@ elseif (MLX_BUILD_METAL)
   FetchContent_Declare(
     metal_cpp
     URL ${METAL_CPP_URL}
-    PATCH_COMMAND patch -N -i ${METAL_CPP_PATCH} || true
+    PATCH_COMMAND /usr/bin/patch -N -i ${METAL_CPP_PATCH} || true
   )

   FetchContent_MakeAvailable(metal_cpp)
docs/src/python/optimizers.rst
@@ -1,5 +1,7 @@
 .. _optimizers:

+.. currentmodule:: mlx.optimizers
+
 Optimizers
 ==========
@@ -34,3 +36,8 @@ model's parameters and the **optimizer state**.
    optimizers/optimizer
    optimizers/common_optimizers
    optimizers/schedulers
+
+.. autosummary::
+   :toctree: _autosummary
+
+   clip_grad_norm
docs/src/python/tree_utils.rst
@@ -20,3 +20,4 @@ return python trees will be using the default python ``dict``, ``list`` and
    tree_unflatten
    tree_map
    tree_map_with_path
+   tree_reduce
python/mlx/optimizers/optimizers.py
@@ -4,7 +4,7 @@ import math
 from typing import Callable, List, Optional, Tuple, Union

 import mlx.core as mx
-from mlx.utils import tree_map
+from mlx.utils import tree_map, tree_reduce


 class Optimizer:
@@ -736,3 +736,35 @@ class Adafactor(Optimizer):
         if self.weight_decay != 0:
             parameter += parameter * (-self.weight_decay * learning_rate)
         return parameter - update
+
+
+def clip_grad_norm(grads, max_norm):
+    """Clips the global norm of the gradients.
+
+    This function ensures that the global norm of the gradients does not exceed
+    ``max_norm``. It scales down the gradients proportionally if their norm is
+    greater than ``max_norm``.
+
+    Example:
+        >>> grads = {"w1": mx.array([2, 3]), "w2": mx.array([1])}
+        >>> clipped_grads, total_norm = clip_grad_norm(grads, max_norm=2.0)
+        >>> print(clipped_grads)
+        {"w1": mx.array([...]), "w2": mx.array([...])}
+
+    Args:
+        grads (dict): A dictionary containing the gradient arrays.
+        max_norm (float): The maximum allowed global norm of the gradients.
+
+    Returns:
+        (dict, float): The possibly rescaled gradients and the original
+        gradient norm.
+    """
+    norm_squared = tree_reduce(lambda acc, g: acc + g.square().sum(), grads, 0.0)
+    total_norm = mx.sqrt(norm_squared)
+    normalizer = max_norm / (total_norm + 1e-6)
+
+    def clipper(g):
+        return mx.where(total_norm < max_norm, g, g * normalizer)
+
+    clipped_grads = tree_map(clipper, grads)
+    return clipped_grads, total_norm
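
To make the docstring example above concrete: for grads = {"w1": [2, 3], "w2": [1]} the global norm is sqrt(2^2 + 3^2 + 1^2) = sqrt(14) ≈ 3.742, which exceeds max_norm = 2.0, so every leaf gets scaled by roughly 2.0 / 3.742 ≈ 0.534 (ignoring the 1e-6 stabilizer). A small sketch checking those numbers, assuming the function above is importable from `mlx.optimizers`:

# Sketch verifying the docstring example numerically; assumes the
# clip_grad_norm defined in the hunk above is exported by mlx.optimizers.
import mlx.core as mx
from mlx.optimizers import clip_grad_norm

grads = {"w1": mx.array([2.0, 3.0]), "w2": mx.array([1.0])}
clipped, total_norm = clip_grad_norm(grads, max_norm=2.0)

print(total_norm.item())       # ~3.7417, i.e. sqrt(14)
print(clipped["w1"].tolist())  # ~[1.069, 1.603], the originals times ~0.5345
print(clipped["w2"].tolist())  # ~[0.5345]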
python/mlx/utils.py
@@ -191,3 +191,45 @@ def tree_unflatten(tree):
         return l
     else:
         return {k: tree_unflatten(v) for k, v in children.items()}
+
+
+def tree_reduce(fn, tree, initializer=None, is_leaf=None):
+    """Applies a reduction to the leaves of a Python tree.
+
+    This function reduces Python trees into an accumulated result by applying
+    the provided function ``fn`` to the leaves of the tree.
+
+    Example:
+        >>> from mlx.utils import tree_reduce
+        >>> tree = {"a": [1, 2, 3], "b": [4, 5]}
+        >>> tree_reduce(lambda acc, x: acc + x, tree, 0)
+        15
+
+    Args:
+        fn (callable): The reducer function that takes two arguments (accumulator,
+            current value) and returns the updated accumulator.
+        tree (Any): The Python tree to reduce. It can be any nested combination of
+            lists, tuples, or dictionaries.
+        initializer (Any, optional): The initial value to start the reduction. If
+            not provided, the first leaf value is used.
+        is_leaf (callable, optional): A function to determine if an object is a
+            leaf, returning ``True`` for leaf nodes and ``False`` otherwise.
+
+    Returns:
+        Any: The accumulated value.
+    """
+    if is_leaf is not None and is_leaf(tree):
+        return tree if initializer is None else fn(initializer, tree)
+
+    accumulator = initializer
+
+    if isinstance(tree, (list, tuple)):
+        for item in tree:
+            accumulator = tree_reduce(fn, item, accumulator, is_leaf)
+    elif isinstance(tree, dict):
+        for item in tree.values():
+            accumulator = tree_reduce(fn, item, accumulator, is_leaf)
+    else:
+        return tree if accumulator is None else fn(accumulator, tree)
+
+    return accumulator
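
A short sketch of how the new `tree_reduce` composes with array leaves as well as plain scalars; the gradient tree below is a made-up example, and the first reduction mirrors the global-norm computation that `clip_grad_norm` performs internally:

# Sketch exercising tree_reduce from the hunk above; the gradient tree here is
# an illustrative example, not part of the diff.
import mlx.core as mx
from mlx.utils import tree_reduce

grads = {"w": [mx.array([3.0, 4.0]), mx.array([12.0])], "b": mx.array([0.0])}

# Global L2 norm, the same reduction clip_grad_norm uses internally.
norm = mx.sqrt(tree_reduce(lambda acc, g: acc + g.square().sum(), grads, 0.0))
print(norm.item())  # 13.0, since sqrt(9 + 16 + 144 + 0) = 13

# Total number of gradient entries, reducing into a plain Python int.
count = tree_reduce(lambda acc, g: acc + g.size, grads, 0)
print(count)  # 4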
python/tests/test_optimizers.py
@@ -376,6 +376,48 @@ class TestSchedulers(unittest.TestCase):
             update()
             self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())

+    def test_clip_grad_norm(self):
+        # Test with small gradients that do not require clipping
+        small_grads = {
+            "first": [mx.array([0.1, 0.2]), mx.array([0.1])],
+            "second": mx.array([0.3]),
+        }
+        max_norm = 10.0  # A large max_norm that shouldn't trigger clipping
+        clipped_grads, total_norm = opt.clip_grad_norm(small_grads, max_norm)
+        self.assertTrue(
+            tree_equal(lambda x, y: mx.array_equal(x, y), small_grads, clipped_grads),
+            "Gradients should not be modified when clipping is not necessary.",
+        )
+
+        # Test with large gradients that require clipping
+        large_grads = {
+            "first": [mx.array([10, 20]), mx.array([10])],
+            "second": mx.array([30]),
+        }
+        max_norm = 1.0  # A small max_norm that should trigger clipping
+        clipped_grads, total_norm = opt.clip_grad_norm(large_grads, max_norm)
+        # Extract only the gradient values for the norm calculation
+        clipped_values = [value for _, value in tree_flatten(clipped_grads)]
+        norm_of_clipped = mx.sqrt(
+            sum(mx.square(g).sum() for g in clipped_values)
+        ).item()
+        self.assertAlmostEqual(
+            norm_of_clipped,
+            max_norm,
+            places=6,
+            msg="Clipped gradients norm should be close to the specified max_norm.",
+        )
+
+        # Ensure the scaling was done correctly
+        scale = max_norm / total_norm
+        expected_grads = tree_map(lambda g: g * scale, large_grads)
+        self.assertTrue(
+            tree_equal(
+                lambda x, y: mx.allclose(x, y, atol=1e-6), expected_grads, clipped_grads
+            ),
+            "Gradients were not scaled correctly during clipping.",
+        )
+

 if __name__ == "__main__":
     unittest.main()
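
For the large-gradient case in the test above, the final scaling assertion works out as follows (ignoring the 1e-6 stabilizer): total_norm = sqrt(10^2 + 20^2 + 10^2 + 30^2) = sqrt(1500) ≈ 38.73, so scale = 1.0 / 38.73 ≈ 0.0258, and the clipped tree's norm lands back at max_norm = 1.0. A plain-Python back-of-the-envelope check:

# Back-of-the-envelope check of the test's scaling expectation; the leaf values
# mirror large_grads in the test, flattened into a plain list.
import math

leaves = [10, 20, 10, 30]
total_norm = math.sqrt(sum(g * g for g in leaves))  # sqrt(1500) ~ 38.7298
scale = 1.0 / total_norm                            # ~ 0.02582
clipped = [g * scale for g in leaves]               # ~ [0.258, 0.516, 0.258, 0.775]
print(math.sqrt(sum(g * g for g in clipped)))       # ~ 1.0, i.e. max_norm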