Mirror of https://github.com/ml-explore/mlx.git, synced 2025-09-01 12:49:44 +08:00
feat: implement clip_grad_norm (#1043)
* feat: implement `clip_grad_norm`
* pre-commit
* Add test for clip_grad_norm function in test_optimizers.py
* small fixes
* fix
* lint
* Update tree_reduce
* Update python/mlx/utils.py
  Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/utils.py
  Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/utils.py
  Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/utils.py
  Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/utils.py
  Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/mlx/utils.py
  Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Refactor clip_grad_norm function to include documentation and improve readability
* format docstring
* Add acknowledgements
* text wrap
* pre-commit
* nits in docs

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
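For readers of the diff below: the new helper is called as opt.clip_grad_norm(grads, max_norm) and returns the (possibly rescaled) gradient tree together with the global norm of the original gradients. A minimal usage sketch, assuming the test's `opt` alias is `import mlx.optimizers as opt`; the gradient tree here is illustrative (it mirrors the test's `large_grads` fixture) and not taken verbatim from this commit:

import mlx.core as mx
import mlx.optimizers as opt

# Illustrative gradient tree; in practice this would come from
# mx.value_and_grad on a model's loss.
grads = {
    "first": [mx.array([10.0, 20.0]), mx.array([10.0])],
    "second": mx.array([30.0]),
}

# Rescale the whole tree so its global L2 norm is at most 1.0.
clipped_grads, total_norm = opt.clip_grad_norm(grads, 1.0)
print(total_norm)  # global norm of the original gradients, about 38.73 here

The test exercises this same call twice: once with a `max_norm` far above the gradient norm (no-op case) and once with `max_norm = 1.0` (scaling case).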
@@ -376,6 +376,48 @@ class TestSchedulers(unittest.TestCase):
            update()
            self.assertAlmostEqual(lr_schedule(step), optimizer.learning_rate.item())

    def test_clip_grad_norm(self):
        # Test with small gradients that do not require clipping
        small_grads = {
            "first": [mx.array([0.1, 0.2]), mx.array([0.1])],
            "second": mx.array([0.3]),
        }
        max_norm = 10.0  # A large max_norm that shouldn't trigger clipping
        clipped_grads, total_norm = opt.clip_grad_norm(small_grads, max_norm)
        self.assertTrue(
            tree_equal(lambda x, y: mx.array_equal(x, y), small_grads, clipped_grads),
            "Gradients should not be modified when clipping is not necessary.",
        )

        # Test with large gradients that require clipping
        large_grads = {
            "first": [mx.array([10, 20]), mx.array([10])],
            "second": mx.array([30]),
        }
        max_norm = 1.0  # A small max_norm that should trigger clipping
        clipped_grads, total_norm = opt.clip_grad_norm(large_grads, max_norm)
        # Correctly extract only the gradient values for norm calculation
        clipped_values = [value for _, value in tree_flatten(clipped_grads)]
        norm_of_clipped = mx.sqrt(
            sum(mx.square(g).sum() for g in clipped_values)
        ).item()
        self.assertAlmostEqual(
            norm_of_clipped,
            max_norm,
            places=6,
            msg="Clipped gradients norm should be close to the specified max_norm.",
        )

        # Ensures that the scaling was done correctly
        scale = max_norm / total_norm
        expected_grads = tree_map(lambda g: g * scale, large_grads)
        self.assertTrue(
            tree_equal(
                lambda x, y: mx.allclose(x, y, atol=1e-6), expected_grads, clipped_grads
            ),
            "Gradients were not scaled correctly during clipping.",
        )


if __name__ == "__main__":
    unittest.main()
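The invariant the second half of the test asserts can be restated compactly: compute the global L2 norm over every array in the gradient tree, and if it exceeds max_norm, multiply each gradient by max_norm / total_norm. Below is a minimal reference sketch of that behaviour using the mlx.utils tree helpers already imported by the test; it illustrates the asserted invariant and is not the implementation added in this commit:

import mlx.core as mx
from mlx.utils import tree_flatten, tree_map

def reference_clip_by_global_norm(grads, max_norm):
    # Global L2 norm over every array in the (possibly nested) gradient tree.
    values = [v for _, v in tree_flatten(grads)]
    total_norm = mx.sqrt(sum(mx.square(g).sum() for g in values))
    # Scale by max_norm / total_norm only when the norm exceeds the limit;
    # a factor of 1.0 leaves already-small gradients numerically unchanged.
    scale = mx.minimum(max_norm / total_norm, 1.0)
    return tree_map(lambda g: g * scale, grads), total_norm

For the `large_grads` fixture, total_norm = sqrt(10^2 + 20^2 + 10^2 + 30^2) ≈ 38.73, so with max_norm = 1.0 each entry is scaled by roughly 1/38.73, which is exactly the tree_map(lambda g: g * scale, large_grads) comparison the test performs.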