mlx/python/mlx/nn/losses.py
Angelos Katharopoulos d1f86272a2 angelos's commit files
2023-11-29 10:42:59 -08:00

7 lines
224 B
Python

import mlx.core as mx
def cross_entropy(
    logits: mx.array, targets: mx.array, axis: int = -1
) -> mx.array:
    """Compute the unreduced cross-entropy loss for integer class targets.

    For each element the loss is ``logsumexp(logits) - logits[target]``
    taken along ``axis``, which equals the negative log-softmax probability
    of the target class. No reduction (mean/sum) is applied.

    Args:
        logits: Unnormalized class scores; the class dimension is ``axis``.
        targets: Integer class indices. Expected to have the shape of
            ``logits`` with ``axis`` removed.
        axis: The class axis of ``logits``. Default: ``-1``.

    Returns:
        The per-element loss, with the shape of ``logits`` reduced over
        ``axis`` (i.e. the shape of ``targets``).
    """
    # Insert the singleton index dimension at the class axis (which is not
    # necessarily the last one) so the indices line up with ``logits`` for
    # take_along_axis, then drop that same dimension from the gathered score.
    indices = mx.expand_dims(targets, axis)
    score = mx.take_along_axis(logits, indices, axis).squeeze(axis)
    return mx.logsumexp(logits, axis=axis) - score