mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
Enable cross_entropy loss to handle dense targets (#517)
* Enable cross_entropy loss to handle dense targets Dense targets means probabilities or one-hot encodings. * better shape check of weights * nits in docstring --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -31,9 +31,14 @@ def cross_entropy(
|
||||
Computes the cross entropy loss.
|
||||
|
||||
Args:
|
||||
logits (array): The unnormalized predicted logits.
|
||||
targets (array): The target values, as class indices.
|
||||
weights (array, optional): Weights for each target. Default: ``None``.
|
||||
logits (array): The unnormalized logits.
|
||||
targets (array): The ground truth values. These can be class indices or
|
||||
probabilities for each class. If the ``targets`` are class indices,
|
||||
then ``targets`` shape should match the ``logits`` shape with
|
||||
the ``axis`` dimension removed. If the ``targets`` are probabilities
|
||||
(or one-hot encoded), then the ``targets`` shape should be the same as
|
||||
the ``logits`` shape.
|
||||
weights (array, optional): Optional weights for each target. Default: ``None``.
|
||||
axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
|
||||
label_smoothing (float, optional): Label smoothing factor. Default: ``0``.
|
||||
reduction (str, optional): Specifies the reduction to apply to the output:
|
||||
@@ -41,11 +46,46 @@ def cross_entropy(
|
||||
|
||||
Returns:
|
||||
array: The computed cross entropy loss.
|
||||
|
||||
Examples:
|
||||
>>> import mlx.core as mx
|
||||
>>> import mlx.nn as nn
|
||||
>>>
|
||||
>>> # Class indices as targets
|
||||
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
|
||||
>>> targets = mx.array([0, 1])
|
||||
>>> nn.losses.cross_entropy(logits, targets)
|
||||
array([0.0485873, 0.0485873], dtype=float32)
|
||||
>>>
|
||||
>>> # Probabilities (or one-hot vectors) as targets
|
||||
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
|
||||
>>> targets = mx.array([[0.9, 0.1], [0.1, 0.9]])
|
||||
>>> nn.losses.cross_entropy(logits, targets)
|
||||
array([0.348587, 0.348587], dtype=float32)
|
||||
"""
|
||||
if label_smoothing < 0 or label_smoothing >= 1:
|
||||
raise ValueError(f"Label smoothing must in [0, 1), got {label_smoothing}.")
|
||||
|
||||
score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
|
||||
# Whether targets are class indices or probabilities
|
||||
targets_as_probs = targets.ndim == logits.ndim
|
||||
|
||||
def _drop_dim(shape, axis):
|
||||
shape.pop(axis)
|
||||
return shape
|
||||
|
||||
# Check shapes in two cases: targets as class indices and targets as probabilities
|
||||
if (targets_as_probs and targets.shape != logits.shape) or (
|
||||
not targets_as_probs and targets.shape != _drop_dim(logits.shape, axis)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Targets shape {targets.shape} does not match logits shape {logits.shape}."
|
||||
)
|
||||
|
||||
if targets_as_probs:
|
||||
score = mx.sum(logits * targets, axis=axis)
|
||||
else:
|
||||
score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
|
||||
|
||||
logsumexp_logits = mx.logsumexp(logits, axis=axis)
|
||||
if label_smoothing > 0:
|
||||
# Adjust the true class score with label smoothing
|
||||
@@ -62,10 +102,10 @@ def cross_entropy(
|
||||
|
||||
# Apply weights if provided
|
||||
if weights is not None:
|
||||
if weights.shape != targets.shape:
|
||||
if weights.shape != loss.shape:
|
||||
raise ValueError(
|
||||
f"Weights with shape {weights.shape} is not the same as "
|
||||
f"targets with shape {targets.shape}."
|
||||
f"output loss with shape {loss.shape}."
|
||||
)
|
||||
loss *= weights
|
||||
|
||||
|
Reference in New Issue
Block a user