logits (mx.array) – The predicted unnormalized logits. Softmax is applied to them along axis.
targets (mx.array) – The target values.
axis (int, optional) – The axis over which to compute softmax. Default: -1.
reduction (str, optional) – Specifies the reduction to apply to the output:
'none' | 'mean' | 'sum'. Default: 'none'.
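The parameters above describe a softmax cross-entropy: softmax is taken over logits along axis, each position is scored against its target, and the per-position losses are then reduced. The sketch below is a hypothetical NumPy reference written to make those semantics concrete. It is not the library's implementation. It assumes targets holds integer class indices whose shape equals logits.shape with axis removed, and the name cross_entropy_reference is made up for illustration.

```python
import numpy as np


def cross_entropy_reference(logits, targets, axis=-1, reduction="none"):
    """Hypothetical NumPy sketch of softmax cross-entropy (not the MLX code).

    Assumption: `targets` holds integer class indices and has the shape of
    `logits` with `axis` removed.
    """
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets)

    # Numerically stable log-softmax along `axis`: subtract the max first.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

    # Gather the log-probability of the target class at every position.
    idx = np.expand_dims(targets, axis=axis)
    picked = np.take_along_axis(log_probs, idx, axis=axis)
    loss = -np.squeeze(picked, axis=axis)

    # Apply the requested reduction: 'none' | 'mean' | 'sum'.
    if reduction == "none":
        return loss
    if reduction == "mean":
        return np.mean(loss)
    if reduction == "sum":
        return np.sum(loss)
    raise ValueError(f"Invalid reduction: {reduction!r}")


logits = np.array([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
targets = np.array([0, 1])
per_sample = cross_entropy_reference(logits, targets)  # one loss per row
total = cross_entropy_reference(logits, targets, reduction="sum")
```

With the default reduction='none', the result keeps one loss per position. In this example the second row is uniform, so its loss is log(3). Passing axis=0 with transposed logits yields the same per-sample values, which shows that axis only selects the class dimension.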