feat: implement clip_grad_norm (#1043)

* feat: implement `clip_grad_norm`

* pre-commit

* Add test for clip_grad_norm function in test_optimizers.py

* small fixes

* fix

* lint

* Update tree_reduce

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Refactor clip_grad_norm function to include documentation and improve readability

* format docstring

* Add acknowlegements

* text wrap

* pre-commit

* nits in docs

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan
2024-05-03 20:07:02 +04:00
committed by GitHub
parent b00ac960b4
commit 79c859e2e0
7 changed files with 127 additions and 3 deletions

View File

@@ -1,5 +1,7 @@
.. _optimizers:
.. currentmodule:: mlx.optimizers
Optimizers
==========
@@ -34,3 +36,8 @@ model's parameters and the **optimizer state**.
optimizers/optimizer
optimizers/common_optimizers
optimizers/schedulers
.. autosummary::
:toctree: _autosummary
clip_grad_norm

View File

@@ -20,3 +20,4 @@ return python trees will be using the default python ``dict``, ``list`` and
tree_unflatten
tree_map
tree_map_with_path
tree_reduce