From e88f2d4a8e35656e8b4918fca4858a70480678ae Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 1 Oct 2025 16:49:55 -0700
Subject: [PATCH] fix cross entropy axis param (#2641)

* fix cross entropy axis param

* faster grad clipping
---
 python/mlx/nn/losses.py             |  4 +++-
 python/mlx/optimizers/optimizers.py |  8 ++------
 python/tests/test_losses.py         | 14 ++++++++++++--
 3 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index aceb1f98a..b765d25de 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -86,7 +86,9 @@ def cross_entropy(
     if targets_as_probs:
         score = mx.sum(logits * targets, axis=axis)
     else:
-        score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
+        score = mx.take_along_axis(logits, mx.expand_dims(targets, axis), axis).squeeze(
+            axis
+        )

     logsumexp_logits = mx.logsumexp(logits, axis=axis)
     if label_smoothing > 0:
diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index 39f6d760c..2cc9e26b1 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -971,10 +971,6 @@ def clip_grad_norm(grads, max_norm):
     """
     norm_squared = tree_reduce(lambda acc, g: acc + g.square().sum(), grads, 0.0)
     total_norm = mx.sqrt(norm_squared)
-    normalizer = max_norm / (total_norm + 1e-6)
-
-    def clipper(g):
-        return mx.where(total_norm < max_norm, g, g * normalizer)
-
-    clipped_grads = tree_map(clipper, grads)
+    normalizer = mx.minimum(max_norm / (total_norm + 1e-6), 1.0)
+    clipped_grads = tree_map(lambda g: g * normalizer, grads)
     return clipped_grads, total_norm
diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py
index 2ef1fa36c..cb22eec8e 100644
--- a/python/tests/test_losses.py
+++ b/python/tests/test_losses.py
@@ -60,9 +60,19 @@ class TestLosses(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(loss, expected))

-        probs = mx.array([[1.0, 0.0], [0.0, 1.0]])
+        # Test a different axis
+        logits = mx.random.normal((4, 8))
+        targets = mx.array([1, 2, 3, 0])
         loss = nn.losses.cross_entropy(
-            logits, probs, weights=weights, label_smoothing=0.3, reduction="none"
+            logits.T,
+            targets,
+            axis=0,
+        )
+        targets = mx.array([1, 2, 3, 0])
+        expected = nn.losses.cross_entropy(
+            logits,
+            targets,
+            axis=-1,
+        )
         self.assertTrue(mx.allclose(loss, expected))

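Note on the python/mlx/nn/losses.py change: `targets[..., None]` always inserted the class-index axis at the end, so `take_along_axis` only lined up with the logits when `axis=-1`. Expanding the targets along `axis` (and squeezing that same axis) makes the gather follow the requested class axis. A minimal sketch of the indexing, assuming `mlx` is installed; the shapes are illustrative and not taken from the patch:

```python
import mlx.core as mx

# Classes on axis 0, batch of 4 (illustrative shapes, not from the patch).
logits = mx.random.normal((8, 4))
targets = mx.array([1, 2, 3, 0])  # one class index per batch element

axis = 0
# Old behavior: targets[..., None] has shape (4, 1), which only matches
# take_along_axis when the class axis is the last one.
# Patched behavior: insert the singleton dimension at the class axis instead.
indices = mx.expand_dims(targets, axis)                          # shape (1, 4)
score = mx.take_along_axis(logits, indices, axis).squeeze(axis)  # shape (4,)
print(score.shape)  # the selected logit per batch element
```

With the fix, `nn.losses.cross_entropy(logits.T, targets, axis=0)` should match the default-axis call on `logits`, which is what the new test asserts.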
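Note on the python/mlx/optimizers/optimizers.py change: clamping the normalizer with `mx.minimum(..., 1.0)` turns clipping into one unconditional scalar multiply per gradient array. When the total norm is below `max_norm` the factor saturates at 1.0 (up to the 1e-6 epsilon), so gradients pass through essentially unchanged and the per-element `mx.where` branch is no longer needed. A small usage sketch, assuming `clip_grad_norm` is importable from `mlx.optimizers`; the gradient values are made up for illustration:

```python
import mlx.core as mx
from mlx.optimizers import clip_grad_norm

# A toy gradient tree; values chosen so the global norm exceeds max_norm.
grads = {"w": mx.full((2, 2), 3.0), "b": mx.full((2,), 4.0)}

clipped, total_norm = clip_grad_norm(grads, max_norm=2.0)
print(total_norm)          # sqrt(4 * 9 + 2 * 16) ~ 8.25
print(clipped["w"][0, 0])  # scaled by ~2.0 / 8.25, so the clipped global norm is ~2.0
```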