feat: implement clip_grad_norm (#1043)

* feat: implement `clip_grad_norm`

* pre-commit

* Add test for clip_grad_norm function in test_optimizers.py

* small fixes

* fix

* lint

* Update tree_reduce

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Refactor clip_grad_norm function to include documentation and improve readability

* format docstring

* Add acknowledgements

* text wrap

* pre-commit

* nits in docs

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
Author: Nripesh Niketan
Date: 2024-05-03 20:07:02 +04:00
Committed by: GitHub
Parent: b00ac960b4
Commit: 79c859e2e0
7 changed files with 127 additions and 3 deletions
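
For context before the diff: clip_grad_norm takes a tree of gradients and a maximum global norm, and returns both the (possibly rescaled) tree and the norm measured before clipping; that is the call pattern the test below exercises. A minimal usage sketch, assuming mlx.optimizers is imported as opt (as in the test) and using a toy gradient tree:

import mlx.core as mx
import mlx.optimizers as opt

# Toy gradient tree; in practice this comes from mx.value_and_grad.
grads = {"w": mx.array([3.0, 4.0]), "b": mx.array([0.0])}

# Returns the clipped tree and the global L2 norm before clipping.
clipped_grads, total_norm = opt.clip_grad_norm(grads, max_norm=1.0)
print(total_norm.item())  # 5.0 = sqrt(3^2 + 4^2), so clipping kicks in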


@@ -376,6 +376,48 @@ class TestSchedulers(unittest.TestCase):
        update()
        self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())

    def test_clip_grad_norm(self):
        # Test with small gradients that do not require clipping
        small_grads = {
            "first": [mx.array([0.1, 0.2]), mx.array([0.1])],
            "second": mx.array([0.3]),
        }
        max_norm = 10.0  # A large max_norm that shouldn't trigger clipping
        clipped_grads, total_norm = opt.clip_grad_norm(small_grads, max_norm)
        self.assertTrue(
            tree_equal(lambda x, y: mx.array_equal(x, y), small_grads, clipped_grads),
            "Gradients should not be modified when clipping is not necessary.",
        )

        # Test with large gradients that require clipping
        large_grads = {
            "first": [mx.array([10, 20]), mx.array([10])],
            "second": mx.array([30]),
        }
        max_norm = 1.0  # A small max_norm that should trigger clipping
        clipped_grads, total_norm = opt.clip_grad_norm(large_grads, max_norm)

        # Extract only the gradient values for the norm calculation
        clipped_values = [value for _, value in tree_flatten(clipped_grads)]
        norm_of_clipped = mx.sqrt(
            sum(mx.square(g).sum() for g in clipped_values)
        ).item()
        self.assertAlmostEqual(
            norm_of_clipped,
            max_norm,
            places=6,
            msg="Clipped gradients norm should be close to the specified max_norm.",
        )

        # Ensure the scaling was done correctly
        scale = max_norm / total_norm
        expected_grads = tree_map(lambda g: g * scale, large_grads)
        self.assertTrue(
            tree_equal(
                lambda x, y: mx.allclose(x, y, atol=1e-6), expected_grads, clipped_grads
            ),
            "Gradients were not scaled correctly during clipping.",
        )


if __name__ == "__main__":
    unittest.main()
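
The property the test pins down is global-norm clipping: compute one L2 norm over every array in the gradient tree and, when it exceeds max_norm, scale every gradient by max_norm / total_norm. A minimal sketch of that logic, an illustrative re-implementation built on the same tree utilities the test imports, not the code this commit adds:

import mlx.core as mx
from mlx.utils import tree_flatten, tree_map


def clip_grad_norm_sketch(grads, max_norm):
    # Global L2 norm over every array in the (possibly nested) tree.
    total_norm = mx.sqrt(sum(mx.square(g).sum() for _, g in tree_flatten(grads)))
    # Shrink only when the norm is too large; a scale of 1 leaves grads intact.
    scale = mx.minimum(max_norm / total_norm, 1.0)
    return tree_map(lambda g: g * scale, grads), total_norm

With max_norm = 1.0 and the large_grads tree above, scale reduces to 1 / total_norm, which is exactly the factor the final assertion reconstructs.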