mlx.nn.losses.cross_entropy

mlx.nn.losses.cross_entropy(logits: array, targets: array, axis: int = -1, reduction: str = 'none')

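Computes the cross entropy loss between the predicted logits and the targets. The softmax is applied to logits internally, so pass unnormalized scores rather than probabilities.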
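For reference, here is the standard per-sample form of this loss. It assumes integer class-index targets. Let $x_1, \dots, x_C$ be one sample's scores along axis and let $y$ be its target class index. The loss is the negative log of the softmax probability assigned to the target class:

$$
\ell = -\log \frac{\exp(x_y)}{\sum_{c=1}^{C} \exp(x_c)} = \log \sum_{c=1}^{C} \exp(x_c) - x_y
$$

For $N$ samples, reduction 'mean' returns $\frac{1}{N}\sum_{i=1}^{N} \ell_i$, 'sum' returns $\sum_{i=1}^{N} \ell_i$, and 'none' returns the individual $\ell_i$ unchanged.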

Parameters:
  • logits (mx.array) – The predicted unnormalized scores (logits).

  • targets (mx.array) – The target class indices, as integers, with the shape of logits minus axis.

  • axis (int, optional) – The axis of logits that holds the class scores, over which the softmax is computed. Default: -1.

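  • reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none' returns the unreduced per-sample losses, 'mean' returns their average, and 'sum' returns their total. Default: 'none'.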
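The following pure-Python sketch illustrates how the loss and the three reduction modes relate. It is a reference illustration, not MLX's implementation. The helper name cross_entropy_reference is hypothetical. The sketch also assumes a 2-D batch of logits, class scores on the last axis (the default axis=-1), and integer class-index targets.

```python
import math
from typing import List, Sequence, Union


def cross_entropy_reference(
    logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    reduction: str = "none",
) -> Union[List[float], float]:
    """Hypothetical reference helper (not part of MLX).

    Computes cross entropy over the last axis of a 2-D batch of logits,
    assuming each target is an integer class index.
    """
    losses = []
    for row, target in zip(logits, targets):
        # Subtract the row maximum so exp() cannot overflow (stable log-sum-exp).
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        # -log softmax(row)[target] == logsumexp(row) - row[target]
        losses.append(log_sum_exp - row[target])

    if reduction == "none":
        return losses  # one loss per sample
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    raise ValueError(f"Invalid reduction: {reduction!r}")


if __name__ == "__main__":
    batch = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.0]]
    labels = [0, 1]
    print(cross_entropy_reference(batch, labels))          # per-sample losses
    print(cross_entropy_reference(batch, labels, "mean"))  # scalar average
```

With MLX itself, you would pass the same batch as arrays. For example, after `import mlx.core as mx` and `import mlx.nn as nn`, call `nn.losses.cross_entropy(mx.array(batch), mx.array(labels), reduction="mean")`.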

Returns:

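The computed cross entropy loss. With reduction 'none' this is an array of per-sample losses. With 'mean' or 'sum' it is a scalar array.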

Return type:

mx.array