From 79c859e2e0984a6684ea5598cdc15ee77cd3d5a6 Mon Sep 17 00:00:00 2001
From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com>
Date: Fri, 3 May 2024 20:07:02 +0400
Subject: [PATCH] feat: implement `clip_grad_norm` (#1043)

* feat: implement `clip_grad_norm`

* pre-commit

* Add test for clip_grad_norm function in test_optimizers.py

* small fixes

* fix

* lint

* Update tree_reduce

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun

* Refactor clip_grad_norm function to include documentation and improve readability

* format docstring

* Add acknowlegements

* text wrap

* pre-commit

* nits in docs

---------

Co-authored-by: Awni Hannun
Co-authored-by: Awni Hannun
---
 ACKNOWLEDGMENTS.md                  |  2 +-
 CMakeLists.txt                      |  2 +-
 docs/src/python/optimizers.rst      |  7 +++++
 docs/src/python/tree_utils.rst      |  1 +
 python/mlx/optimizers/optimizers.py | 34 ++++++++++++++++++++++-
 python/mlx/utils.py                 | 42 +++++++++++++++++++++++++++++
 python/tests/test_optimizers.py     | 42 +++++++++++++++++++++++++++++
 7 files changed, 127 insertions(+), 3 deletions(-)

diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 6c151117c..05dca2768 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
 
 MLX was developed with contributions from the following individuals:
 
-- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8781cc4bd..151017b9a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -94,7 +94,7 @@ elseif (MLX_BUILD_METAL)
   FetchContent_Declare(
     metal_cpp
     URL ${METAL_CPP_URL}
-    PATCH_COMMAND patch -N -i ${METAL_CPP_PATCH} || true
+    PATCH_COMMAND /usr/bin/patch -N -i ${METAL_CPP_PATCH} || true
   )
 
   FetchContent_MakeAvailable(metal_cpp)
diff --git a/docs/src/python/optimizers.rst b/docs/src/python/optimizers.rst
index f437ddc15..84ab933ac 100644
--- a/docs/src/python/optimizers.rst
+++ b/docs/src/python/optimizers.rst
@@ -1,5 +1,7 @@
 .. _optimizers:
 
+.. currentmodule:: mlx.optimizers
+
 Optimizers
 ==========
 
@@ -34,3 +36,8 @@ model's parameters and the **optimizer state**.
    optimizers/optimizer
    optimizers/common_optimizers
    optimizers/schedulers
+
+.. autosummary::
+   :toctree: _autosummary
+
+   clip_grad_norm
diff --git a/docs/src/python/tree_utils.rst b/docs/src/python/tree_utils.rst
index dbd0ebce9..6dc60b47d 100644
--- a/docs/src/python/tree_utils.rst
+++ b/docs/src/python/tree_utils.rst
@@ -20,3 +20,4 @@ return python trees will be using the default python ``dict``, ``list`` and
    tree_unflatten
    tree_map
    tree_map_with_path
+   tree_reduce
diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index 054466f90..58198f1d4 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -4,7 +4,7 @@ import math
 from typing import Callable, List, Optional, Tuple, Union
 
 import mlx.core as mx
-from mlx.utils import tree_map
+from mlx.utils import tree_map, tree_reduce
 
 
 class Optimizer:
@@ -736,3 +736,35 @@ class Adafactor(Optimizer):
         if self.weight_decay != 0:
             parameter += parameter * (-self.weight_decay * learning_rate)
         return parameter - update
+
+
+def clip_grad_norm(grads, max_norm):
+    """Clips the global norm of the gradients.
+
+    This function ensures that the global norm of the gradients does not exceed
+    ``max_norm``. It scales down the gradients proportionally if their norm is
+    greater than ``max_norm``.
+
+    Example:
+        >>> grads = {"w1": mx.array([2, 3]), "w2": mx.array([1])}
+        >>> clipped_grads, total_norm = clip_grad_norm(grads, max_norm=2.0)
+        >>> print(clipped_grads)
+        {"w1": mx.array([...]), "w2": mx.array([...])}
+
+    Args:
+        grads (dict): A dictionary containing the gradient arrays.
+        max_norm (float): The maximum allowed global norm of the gradients.
+
+    Returns:
+        (dict, float): The possibly rescaled gradients and the original
+          gradient norm.
+    """
+    norm_squared = tree_reduce(lambda acc, g: acc + g.square().sum(), grads, 0.0)
+    total_norm = mx.sqrt(norm_squared)
+    normalizer = max_norm / (total_norm + 1e-6)
+
+    def clipper(g):
+        return mx.where(total_norm < max_norm, g, g * normalizer)
+
+    clipped_grads = tree_map(clipper, grads)
+    return clipped_grads, total_norm
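For illustration, a minimal sketch of calling the new function (this snippet is not part of the diff; the parameter names and values are made up, and it assumes `clip_grad_norm` is reachable from the `mlx.optimizers` namespace, which is how the test file below imports it):

    import mlx.core as mx
    import mlx.optimizers as optim

    # Hypothetical gradients for a tiny model; the global norm here is
    # sqrt(3^2 + 4^2 + 0^2) = 5.
    grads = {"w": mx.array([3.0, 4.0]), "b": mx.array([0.0])}

    # Rescale so the global norm does not exceed 1.0.
    clipped, total_norm = optim.clip_grad_norm(grads, max_norm=1.0)

    print(total_norm)    # 5.0 -- the norm *before* clipping
    print(clipped["w"])  # roughly [0.6, 0.8], i.e. the gradients scaled by 1/5

In a training step the same call would typically sit between computing the gradients and `optimizer.update(model, clipped)`.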
diff --git a/python/mlx/utils.py b/python/mlx/utils.py
index 31e94a2a1..e7b61373a 100644
--- a/python/mlx/utils.py
+++ b/python/mlx/utils.py
@@ -191,3 +191,45 @@ def tree_unflatten(tree):
         return l
     else:
         return {k: tree_unflatten(v) for k, v in children.items()}
+
+
+def tree_reduce(fn, tree, initializer=None, is_leaf=None):
+    """Applies a reduction to the leaves of a Python tree.
+
+    This function reduces Python trees into an accumulated result by applying
+    the provided function ``fn`` to the leaves of the tree.
+
+    Example:
+        >>> from mlx.utils import tree_reduce
+        >>> tree = {"a": [1, 2, 3], "b": [4, 5]}
+        >>> tree_reduce(lambda acc, x: acc + x, tree, 0)
+        15
+
+    Args:
+        fn (callable): The reducer function that takes two arguments (accumulator,
+            current value) and returns the updated accumulator.
+        tree (Any): The Python tree to reduce. It can be any nested combination of
+            lists, tuples, or dictionaries.
+        initializer (Any, optional): The initial value to start the reduction. If
+            not provided, the first leaf value is used.
+        is_leaf (callable, optional): A function to determine if an object is a
+            leaf, returning ``True`` for leaf nodes and ``False`` otherwise.
+
+    Returns:
+        Any: The accumulated value.
+    """
+    if is_leaf is not None and is_leaf(tree):
+        return tree if initializer is None else fn(initializer, tree)
+
+    accumulator = initializer
+
+    if isinstance(tree, (list, tuple)):
+        for item in tree:
+            accumulator = tree_reduce(fn, item, accumulator, is_leaf)
+    elif isinstance(tree, dict):
+        for item in tree.values():
+            accumulator = tree_reduce(fn, item, accumulator, is_leaf)
+    else:
+        return tree if accumulator is None else fn(accumulator, tree)
+
+    return accumulator
diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py
index 41538dd49..950850b1f 100644
--- a/python/tests/test_optimizers.py
+++ b/python/tests/test_optimizers.py
@@ -376,6 +376,48 @@ class TestSchedulers(unittest.TestCase):
             update()
             self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())
 
+    def test_clip_grad_norm(self):
+        # Test with small gradients that do not require clipping
+        small_grads = {
+            "first": [mx.array([0.1, 0.2]), mx.array([0.1])],
+            "second": mx.array([0.3]),
+        }
+        max_norm = 10.0  # A large max_norm that shouldn't trigger clipping
+        clipped_grads, total_norm = opt.clip_grad_norm(small_grads, max_norm)
+        self.assertTrue(
+            tree_equal(lambda x, y: mx.array_equal(x, y), small_grads, clipped_grads),
+            "Gradients should not be modified when clipping is not necessary.",
+        )
+
+        # Test with large gradients that require clipping
+        large_grads = {
+            "first": [mx.array([10, 20]), mx.array([10])],
+            "second": mx.array([30]),
+        }
+        max_norm = 1.0  # A small max_norm that should trigger clipping
+        clipped_grads, total_norm = opt.clip_grad_norm(large_grads, max_norm)
+        # Correctly extract only the gradient values for norm calculation
+        clipped_values = [value for _, value in tree_flatten(clipped_grads)]
+        norm_of_clipped = mx.sqrt(
+            sum(mx.square(g).sum() for g in clipped_values)
+        ).item()
+        self.assertAlmostEqual(
+            norm_of_clipped,
+            max_norm,
+            places=6,
+            msg="Clipped gradients norm should be close to the specified max_norm.",
+        )
+
+        # Ensures that the scaling was done correctly
+        scale = max_norm / total_norm
+        expected_grads = tree_map(lambda g: g * scale, large_grads)
+        self.assertTrue(
+            tree_equal(
+                lambda x, y: mx.allclose(x, y, atol=1e-6), expected_grads, clipped_grads
+            ),
+            "Gradients were not scaled correctly during clipping.",
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
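The `tree_reduce` helper added in `python/mlx/utils.py` is also useful on its own. A small sketch of its behavior (again not part of the diff; the nested values are arbitrary, and plain Python numbers are used as leaves so it runs without arrays):

    from mlx.utils import tree_reduce

    # Any nested combination of lists, tuples and dicts is a valid tree;
    # the leaves here are plain ints purely for illustration.
    tree = {"encoder": [1, 2, 3], "decoder": {"w": 4, "b": 5}}

    # With an initializer: sum every leaf.
    print(tree_reduce(lambda acc, x: acc + x, tree, 0))   # 15

    # Without an initializer the first leaf seeds the accumulator,
    # so a running maximum needs no starting value.
    print(tree_reduce(lambda acc, x: max(acc, x), tree))  # 5

This is the same mechanism `clip_grad_norm` uses to accumulate the squared norms of the gradient leaves before taking the square root.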