# Mirror of https://github.com/ml-explore/mlx.git
# Synced 2025-06-24 01:17:26 +08:00
# (7 lines, 224 B, Python)
import mlx.core as mx


def cross_entropy(logits: mx.array, targets: mx.array, axis: int = -1) -> mx.array:
    """Compute the unreduced cross-entropy loss from unnormalized logits.

    Args:
        logits: Unnormalized class scores; the class dimension is ``axis``.
        targets: Integer class indices, shaped like ``logits`` with ``axis``
            removed (each entry selects one class along ``axis``).
        axis: The class axis of ``logits``. Defaults to ``-1`` (last axis).

    Returns:
        ``logsumexp(logits, axis) - logits[target]`` for every position, i.e.
        the negative log-softmax probability of each target class. The result
        has the shape of ``logits`` with ``axis`` reduced; no mean/sum is taken.
    """
    # Insert a size-1 dimension at the class axis so the indices have the same
    # rank as ``logits``. Indexing with ``targets[..., None]`` and squeezing
    # ``-1`` is only correct when ``axis == -1``.
    indices = mx.expand_dims(targets, axis)
    score = mx.take_along_axis(logits, indices, axis).squeeze(axis)
    # log(sum(exp(logits))) - logit[target] == -log_softmax(logits)[target];
    # logsumexp keeps this numerically stable for large logits.
    return mx.logsumexp(logits, axis=axis) - score