fix cross entropy axis param (#2641)

* fix cross entropy axis param

* faster grad clipping
commit e88f2d4a8e
parent 9cee557423
Author: Awni Hannun
Date: 2025-10-01 16:49:55 -07:00
committed by GitHub
3 changed files with 17 additions and 9 deletions

@@ -86,7 +86,9 @@ def cross_entropy(
     if targets_as_probs:
         score = mx.sum(logits * targets, axis=axis)
     else:
-        score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
+        score = mx.take_along_axis(logits, mx.expand_dims(targets, axis), axis).squeeze(
+            axis
+        )
 
     logsumexp_logits = mx.logsumexp(logits, axis=axis)
     if label_smoothing > 0:
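Why the change matters: targets[..., None] always appends the singleton dimension at the end, so take_along_axis only lined up with logits when the class axis was the last one. mx.expand_dims(targets, axis) inserts the singleton at the class axis itself, and squeeze(axis) removes it from the matching position. A minimal sketch with hypothetical shapes (the sizes below are made up for illustration):

import mlx.core as mx

# Hypothetical shapes: class dimension on axis 1 instead of the default -1.
logits = mx.random.normal((4, 10, 8))        # (batch, classes, sequence)
targets = mx.random.randint(0, 10, (4, 8))   # class indices, (batch, sequence)
axis = 1

# Old: targets[..., None] has shape (4, 8, 1), which does not line up
# with the class axis of logits. New: put the singleton on `axis`.
indices = mx.expand_dims(targets, axis)                          # (4, 1, 8)
score = mx.take_along_axis(logits, indices, axis).squeeze(axis)  # (4, 8)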

@@ -971,10 +971,6 @@ def clip_grad_norm(grads, max_norm):
     """
     norm_squared = tree_reduce(lambda acc, g: acc + g.square().sum(), grads, 0.0)
     total_norm = mx.sqrt(norm_squared)
-    normalizer = max_norm / (total_norm + 1e-6)
-
-    def clipper(g):
-        return mx.where(total_norm < max_norm, g, g * normalizer)
-
-    clipped_grads = tree_map(clipper, grads)
+    normalizer = mx.minimum(max_norm / (total_norm + 1e-6), 1.0)
+    clipped_grads = tree_map(lambda g: g * normalizer, grads)
     return clipped_grads, total_norm
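The speedup comes from clamping the scale instead of branching per element: min(max_norm / total_norm, 1.0) equals 1 when the gradients are already under the threshold, so a single scalar multiply per array replaces the elementwise mx.where. A hedged usage sketch (the tiny model, data, and max_norm value are made up for illustration):

import mlx.core as mx
import mlx.nn as nn
from mlx.optimizers import clip_grad_norm

# Hypothetical tiny model purely for illustration.
model = nn.Linear(4, 2)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.normal((8, 4))
y = mx.random.normal((8, 2))
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)

# One scalar multiply per gradient array; gradient trees whose total
# norm is already under max_norm are multiplied by 1.0 and pass
# through unchanged.
clipped_grads, total_norm = clip_grad_norm(grads, max_norm=1.0)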