fix cross entropy axis param (#2641)

* fix cross entropy axis param

* faster grad clipping
This commit is contained in:
Awni Hannun
2025-10-01 16:49:55 -07:00
committed by GitHub
parent 9cee557423
commit e88f2d4a8e
3 changed files with 17 additions and 9 deletions

View File

@@ -60,9 +60,19 @@ class TestLosses(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.allclose(loss, expected))
probs = mx.array([[1.0, 0.0], [0.0, 1.0]])
# Test a different axis
logits = mx.random.normal((4, 8))
targets = mx.array([1, 2, 3, 0])
loss = nn.losses.cross_entropy(
logits, probs, weights=weights, label_smoothing=0.3, reduction="none"
logits.T,
targets,
axis=0,
)
targets = mx.array([1, 2, 3, 0])
expected = nn.losses.cross_entropy(
logits,
targets,
axis=-1,
)
self.assertTrue(mx.allclose(loss, expected))