Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-19 00:04:41 +08:00)
fix cross entropy axis param (#2641)

* fix cross entropy axis param
* faster grad clipping
@@ -86,7 +86,9 @@ def cross_entropy(
     if targets_as_probs:
         score = mx.sum(logits * targets, axis=axis)
     else:
-        score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
+        score = mx.take_along_axis(logits, mx.expand_dims(targets, axis), axis).squeeze(
+            axis
+        )

     logsumexp_logits = mx.logsumexp(logits, axis=axis)
     if label_smoothing > 0:
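The old targets[..., None] / .squeeze(-1) pair always inserted and removed the class dimension at the last axis, so any non-default axis indexed the wrong dimension; the new mx.expand_dims(targets, axis) / .squeeze(axis) pair follows the requested axis. A minimal usage sketch of the fixed behavior (the shapes below are made up for illustration, not taken from the commit):

import mlx.core as mx
import mlx.nn as nn

# Illustrative shapes: 4 samples, 8 classes.
logits = mx.random.normal((4, 8))
targets = mx.array([1, 2, 3, 0])

# With the class dimension moved to axis 0, passing axis=0 should now
# reproduce the per-sample losses of the default axis=-1 layout.
loss_axis0 = nn.losses.cross_entropy(logits.T, targets, axis=0, reduction="none")
loss_default = nn.losses.cross_entropy(logits, targets, axis=-1, reduction="none")
print(mx.allclose(loss_axis0, loss_default))  # expected: True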
@@ -971,10 +971,6 @@ def clip_grad_norm(grads, max_norm):
     """
     norm_squared = tree_reduce(lambda acc, g: acc + g.square().sum(), grads, 0.0)
     total_norm = mx.sqrt(norm_squared)
-    normalizer = max_norm / (total_norm + 1e-6)
-
-    def clipper(g):
-        return mx.where(total_norm < max_norm, g, g * normalizer)
-
-    clipped_grads = tree_map(clipper, grads)
+    normalizer = mx.minimum(max_norm / (total_norm + 1e-6), 1.0)
+    clipped_grads = tree_map(lambda g: g * normalizer, grads)
     return clipped_grads, total_norm
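Clamping the normalizer with mx.minimum(..., 1.0) makes the scaling a no-op (up to the 1e-6 epsilon) whenever the global norm is already below max_norm, so the per-element mx.where select can be dropped and every gradient is handled by a single multiply. A short usage sketch, assuming the helper is exposed as mlx.optimizers.clip_grad_norm and with gradient values invented for illustration:

import mlx.core as mx
from mlx.optimizers import clip_grad_norm  # assumed import path

# Made-up gradient tree with a large norm so that clipping kicks in.
grads = {"w": 10.0 * mx.ones((3, 3)), "b": 10.0 * mx.ones((3,))}

clipped, total_norm = clip_grad_norm(grads, max_norm=1.0)
print(total_norm)  # global L2 norm over the whole tree (well above 1 here)
new_norm = mx.sqrt((clipped["w"] ** 2).sum() + (clipped["b"] ** 2).sum())
print(new_norm)  # approximately 1.0 after clipping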
@@ -60,9 +60,19 @@ class TestLosses(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(loss, expected))

-        probs = mx.array([[1.0, 0.0], [0.0, 1.0]])
+        # Test a different axis
+        logits = mx.random.normal((4, 8))
+        targets = mx.array([1, 2, 3, 0])
         loss = nn.losses.cross_entropy(
-            logits, probs, weights=weights, label_smoothing=0.3, reduction="none"
+            logits.T,
+            targets,
+            axis=0,
+        )
+        targets = mx.array([1, 2, 3, 0])
+        expected = nn.losses.cross_entropy(
+            logits,
+            targets,
+            axis=-1,
         )
         self.assertTrue(mx.allclose(loss, expected))
