mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
feat: implement clip_grad_norm (#1043)
* feat: implement `clip_grad_norm` * pre-commit * Add test for clip_grad_norm function in test_optimizers.py * small fixes * fix * lint * Update tree_reduce * Update python/mlx/utils.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/mlx/utils.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/mlx/utils.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/mlx/utils.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/mlx/utils.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/mlx/utils.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Refactor clip_grad_norm function to include documentation and improve readability * format docstring * Add acknowlegements * text wrap * pre-commit * nits in docs --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -191,3 +191,45 @@ def tree_unflatten(tree):
|
||||
return l
|
||||
else:
|
||||
return {k: tree_unflatten(v) for k, v in children.items()}
|
||||
|
||||
|
||||
def tree_reduce(fn, tree, initializer=None, is_leaf=None):
|
||||
"""Applies a reduction to the leaves of a Python tree.
|
||||
|
||||
This function reduces Python trees into an accumulated result by applying
|
||||
the provided function ``fn`` to the leaves of the tree.
|
||||
|
||||
Example:
|
||||
>>> from mlx.utils import tree_reduce
|
||||
>>> tree = {"a": [1, 2, 3], "b": [4, 5]}
|
||||
>>> tree_reduce(lambda acc, x: acc + x, tree, 0)
|
||||
15
|
||||
|
||||
Args:
|
||||
fn (callable): The reducer function that takes two arguments (accumulator,
|
||||
current value) and returns the updated accumulator.
|
||||
tree (Any): The Python tree to reduce. It can be any nested combination of
|
||||
lists, tuples, or dictionaries.
|
||||
initializer (Any, optional): The initial value to start the reduction. If
|
||||
not provided, the first leaf value is used.
|
||||
is_leaf (callable, optional): A function to determine if an object is a
|
||||
leaf, returning ``True`` for leaf nodes and ``False`` otherwise.
|
||||
|
||||
Returns:
|
||||
Any: The accumulated value.
|
||||
"""
|
||||
if is_leaf is not None and is_leaf(tree):
|
||||
return tree if initializer is None else fn(initializer, tree)
|
||||
|
||||
accumulator = initializer
|
||||
|
||||
if isinstance(tree, (list, tuple)):
|
||||
for item in tree:
|
||||
accumulator = tree_reduce(fn, item, accumulator, is_leaf)
|
||||
elif isinstance(tree, dict):
|
||||
for item in tree.values():
|
||||
accumulator = tree_reduce(fn, item, accumulator, is_leaf)
|
||||
else:
|
||||
return tree if accumulator is None else fn(accumulator, tree)
|
||||
|
||||
return accumulator
|
||||
|
||||
Reference in New Issue
Block a user