logits (mx.array) – The predicted logits.
targets (mx.array) – The target values.
axis (int, optional) – The axis over which to compute softmax. Default: -1.
reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'none'.
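To make the roles of axis and reduction concrete, here is a minimal NumPy sketch of softmax cross entropy. It is not the MLX implementation. The function name cross_entropy_reference is hypothetical, and the sketch assumes targets holds integer class indices for each example (the class dimension being the one selected by axis). The per-example loss is logsumexp(logits) minus the logit of the target class, computed along axis.

```python
import numpy as np


def cross_entropy_reference(logits, targets, axis=-1, reduction="none"):
    """Hypothetical NumPy reference for softmax cross entropy.

    Assumes `targets` contains integer class indices, one per example,
    indexing the class dimension given by `axis`.
    """
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.int64)
    # Move the class (softmax) axis to the end so indexing is uniform.
    logits = np.moveaxis(logits, axis, -1)
    # Numerically stable log-sum-exp over the class axis.
    m = logits.max(axis=-1, keepdims=True)
    lse = np.log(np.exp(logits - m).sum(axis=-1)) + m.squeeze(-1)
    # Pick the logit of the target class for every example.
    score = np.take_along_axis(logits, targets[..., None], axis=-1).squeeze(-1)
    loss = lse - score
    # Apply the requested reduction.
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"Invalid reduction: {reduction!r}")


logits = [[2.0, 0.0], [0.0, 2.0]]
targets = [0, 1]
print(cross_entropy_reference(logits, targets))                  # one loss per example
print(cross_entropy_reference(logits, targets, reduction="mean"))  # scalar
```

With the default reduction='none' the result keeps one value per example, while 'mean' and 'sum' collapse it to a scalar. Passing axis=0 treats the first dimension as the class dimension, which is equivalent to transposing the logits and using axis=-1.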