feat: implement clip_grad_norm (#1043)

* feat: implement `clip_grad_norm`

* pre-commit

* Add test for clip_grad_norm function in test_optimizers.py

* small fixes

* fix

* lint

* Update tree_reduce

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Refactor clip_grad_norm function to include documentation and improve readability

* format docstring

* Add acknowlegements

* text wrap

* pre-commit

* nits in docs

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan
2024-05-03 20:07:02 +04:00
committed by GitHub
parent b00ac960b4
commit 79c859e2e0
7 changed files with 127 additions and 3 deletions

View File

@@ -4,7 +4,7 @@ import math
from typing import Callable, List, Optional, Tuple, Union
import mlx.core as mx
from mlx.utils import tree_map
from mlx.utils import tree_map, tree_reduce
class Optimizer:
@@ -736,3 +736,35 @@ class Adafactor(Optimizer):
if self.weight_decay != 0:
parameter += parameter * (-self.weight_decay * learning_rate)
return parameter - update
def clip_grad_norm(grads, max_norm):
"""Clips the global norm of the gradients.
This function ensures that the global norm of the gradients does not exceed
``max_norm``. It scales down the gradients proportionally if their norm is
greater than ``max_norm``.
Example:
>>> grads = {"w1": mx.array([2, 3]), "w2": mx.array([1])}
>>> clipped_grads, total_norm = clip_grad_norm(grads, max_norm=2.0)
>>> print(clipped_grads)
{"w1": mx.array([...]), "w2": mx.array([...])}
Args:
grads (dict): A dictionary containing the gradient arrays.
max_norm (float): The maximum allowed global norm of the gradients.
Returns:
(dict, float): The possibly rescaled gradients and the original
gradient norm.
"""
norm_squared = tree_reduce(lambda acc, g: acc + g.square().sum(), grads, 0.0)
total_norm = mx.sqrt(norm_squared)
normalizer = max_norm / (total_norm + 1e-6)
def clipper(g):
return mx.where(total_norm < max_norm, g, g * normalizer)
clipped_grads = tree_map(clipper, grads)
return clipped_grads, total_norm

View File

@@ -191,3 +191,45 @@ def tree_unflatten(tree):
return l
else:
return {k: tree_unflatten(v) for k, v in children.items()}
def tree_reduce(fn, tree, initializer=None, is_leaf=None):
"""Applies a reduction to the leaves of a Python tree.
This function reduces Python trees into an accumulated result by applying
the provided function ``fn`` to the leaves of the tree.
Example:
>>> from mlx.utils import tree_reduce
>>> tree = {"a": [1, 2, 3], "b": [4, 5]}
>>> tree_reduce(lambda acc, x: acc + x, tree, 0)
15
Args:
fn (callable): The reducer function that takes two arguments (accumulator,
current value) and returns the updated accumulator.
tree (Any): The Python tree to reduce. It can be any nested combination of
lists, tuples, or dictionaries.
initializer (Any, optional): The initial value to start the reduction. If
not provided, the first leaf value is used.
is_leaf (callable, optional): A function to determine if an object is a
leaf, returning ``True`` for leaf nodes and ``False`` otherwise.
Returns:
Any: The accumulated value.
"""
if is_leaf is not None and is_leaf(tree):
return tree if initializer is None else fn(initializer, tree)
accumulator = initializer
if isinstance(tree, (list, tuple)):
for item in tree:
accumulator = tree_reduce(fn, item, accumulator, is_leaf)
elif isinstance(tree, dict):
for item in tree.values():
accumulator = tree_reduce(fn, item, accumulator, is_leaf)
else:
return tree if accumulator is None else fn(accumulator, tree)
return accumulator