mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 23:08:11 +08:00
fix cross entropy axis param (#2641)
* fix cross entropy axis param * faster grad clipping
This commit is contained in:
@@ -60,9 +60,19 @@ class TestLosses(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(mx.allclose(loss, expected))
|
||||
|
||||
probs = mx.array([[1.0, 0.0], [0.0, 1.0]])
|
||||
# Test a different axis
|
||||
logits = mx.random.normal((4, 8))
|
||||
targets = mx.array([1, 2, 3, 0])
|
||||
loss = nn.losses.cross_entropy(
|
||||
logits, probs, weights=weights, label_smoothing=0.3, reduction="none"
|
||||
logits.T,
|
||||
targets,
|
||||
axis=0,
|
||||
)
|
||||
targets = mx.array([1, 2, 3, 0])
|
||||
expected = nn.losses.cross_entropy(
|
||||
logits,
|
||||
targets,
|
||||
axis=-1,
|
||||
)
|
||||
self.assertTrue(mx.allclose(loss, expected))
|
||||
|
||||
|
Reference in New Issue
Block a user