Enable cross_entropy loss to handle dense targets (#517)

* Enable cross_entropy loss to handle dense targets

Dense targets means probabilities or one-hot encodings.

* better shape check of weights

* nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
AtomicVar 2024-01-24 04:17:22 +08:00 committed by GitHub
parent 6b4b30e3fc
commit 755dcf6137
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 87 additions and 86 deletions

View File

@ -31,9 +31,14 @@ def cross_entropy(
Computes the cross entropy loss.
Args:
logits (array): The unnormalized predicted logits.
targets (array): The target values, as class indices.
weights (array, optional): Weights for each target. Default: ``None``.
logits (array): The unnormalized logits.
targets (array): The ground truth values. These can be class indices or
probabilities for each class. If the ``targets`` are class indices,
then ``targets`` shape should match the ``logits`` shape with
the ``axis`` dimension removed. If the ``targets`` are probabilities
(or one-hot encoded), then the ``targets`` shape should be the same as
the ``logits`` shape.
weights (array, optional): Optional weights for each target. Default: ``None``.
axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
label_smoothing (float, optional): Label smoothing factor. Default: ``0``.
reduction (str, optional): Specifies the reduction to apply to the output:
@ -41,11 +46,46 @@ def cross_entropy(
Returns:
array: The computed cross entropy loss.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>>
>>> # Class indices as targets
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
>>> targets = mx.array([0, 1])
>>> nn.losses.cross_entropy(logits, targets)
array([0.0485873, 0.0485873], dtype=float32)
>>>
>>> # Probabilities (or one-hot vectors) as targets
>>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
>>> targets = mx.array([[0.9, 0.1], [0.1, 0.9]])
>>> nn.losses.cross_entropy(logits, targets)
array([0.348587, 0.348587], dtype=float32)
"""
if label_smoothing < 0 or label_smoothing >= 1:
raise ValueError(f"Label smoothing must in [0, 1), got {label_smoothing}.")
# Whether targets are class indices or probabilities
targets_as_probs = targets.ndim == logits.ndim
def _drop_dim(shape, axis):
shape.pop(axis)
return shape
# Check shapes in two cases: targets as class indices and targets as probabilities
if (targets_as_probs and targets.shape != logits.shape) or (
not targets_as_probs and targets.shape != _drop_dim(logits.shape, axis)
):
raise ValueError(
f"Targets shape {targets.shape} does not match logits shape {logits.shape}."
)
if targets_as_probs:
score = mx.sum(logits * targets, axis=axis)
else:
score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
logsumexp_logits = mx.logsumexp(logits, axis=axis)
if label_smoothing > 0:
# Adjust the true class score with label smoothing
@ -62,10 +102,10 @@ def cross_entropy(
# Apply weights if provided
if weights is not None:
if weights.shape != targets.shape:
if weights.shape != loss.shape:
raise ValueError(
f"Weights with shape {weights.shape} is not the same as "
f"targets with shape {targets.shape}."
f"output loss with shape {loss.shape}."
)
loss *= weights

View File

@ -10,100 +10,61 @@ import numpy as np
class TestLosses(mlx_tests.MLXTestCase):
def test_cross_entropy(self):
# No weights, no label smoothing
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
targets = mx.array([0, 1])
indices = mx.array([0, 1])
expected = mx.array([0.0, 0.0])
loss = nn.losses.cross_entropy(logits, indices, reduction="none")
self.assertTrue(mx.allclose(loss, expected))
# Test with reduction 'none'
losses_none = nn.losses.cross_entropy(logits, targets, reduction="none")
expected_none = mx.array([0.0, 0.0])
self.assertTrue(mx.array_equal(losses_none, expected_none))
probs = mx.array([[1.0, 0.0], [0.0, 1.0]])
loss = nn.losses.cross_entropy(logits, probs, reduction="none")
self.assertTrue(mx.isnan(loss).all()) # produce NaNs, like PyTorch
# Test with reduction 'mean'
losses_mean = nn.losses.cross_entropy(logits, targets, reduction="mean")
expected_mean = mx.mean(expected_none)
self.assertEqual(losses_mean, expected_mean)
# Test with reduction 'sum'
losses_sum = nn.losses.cross_entropy(logits, targets, reduction="sum")
expected_sum = mx.sum(expected_none)
self.assertEqual(losses_sum, expected_sum)
# Test cases with weights and no label smoothing
# With weights, no label smoothing
logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
targets = mx.array([0, 1])
indices = mx.array([0, 1])
weights = mx.array([1.0, 2.0])
expected = mx.array([0.04858735, 0.0971747])
loss = nn.losses.cross_entropy(
logits, indices, weights=weights, reduction="none"
)
self.assertTrue(mx.allclose(loss, expected))
# Reduction 'none'
losses_none = nn.losses.cross_entropy(
logits,
targets,
weights=weights,
reduction="none",
)
expected_none = mx.array([0.04858735, 0.0971747]) # Calculated losses
self.assertTrue(
np.allclose(losses_none, expected_none, atol=1e-5),
"Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]",
)
probs = mx.array([[1.0, 0.0], [0.0, 1.0]])
loss = nn.losses.cross_entropy(logits, probs, weights=weights, reduction="none")
self.assertTrue(mx.allclose(loss, expected))
# Reduction 'mean'
losses_mean = nn.losses.cross_entropy(
logits,
targets,
weights=weights,
reduction="mean",
)
expected_mean = mx.mean(expected_none)
self.assertTrue(
np.allclose(losses_mean, expected_mean, atol=1e-5),
"Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]",
# No weights, with label smoothing
logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
indices = mx.array([0, 1])
expected = mx.array([0.498587, 0.498587])
loss = nn.losses.cross_entropy(
logits, indices, label_smoothing=0.3, reduction="none"
)
self.assertTrue(mx.allclose(loss, expected))
# Reduction 'sum'
losses_sum = nn.losses.cross_entropy(
logits,
targets,
weights=weights,
reduction="sum",
)
expected_sum = mx.sum(expected_none)
self.assertTrue(
np.allclose(losses_sum, expected_sum, atol=1e-5),
"Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]",
probs = mx.array([[1.0, 0.0], [0.0, 1.0]])
loss = nn.losses.cross_entropy(
logits, probs, label_smoothing=0.3, reduction="none"
)
self.assertTrue(mx.allclose(loss, expected))
# Test case with equal weights and label smoothing > 0
logits = mx.array(
[[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]]
# With weights and label smoothing
logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
indices = mx.array([0, 1])
weights = mx.array([1.0, 2.0])
expected = mx.array([0.49858734, 0.9971747])
loss = nn.losses.cross_entropy(
logits, indices, weights=weights, label_smoothing=0.3, reduction="none"
)
target = mx.array([2, 1, 0])
self.assertTrue(mx.allclose(loss, expected))
losses_none = nn.losses.cross_entropy(
logits, target, label_smoothing=0.3, reduction="none"
)
expected_none = mx.array([1.29693, 1.38617, 1.48176])
self.assertTrue(
mx.allclose(expected_none, losses_none),
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'",
)
expected_mean = mx.mean(expected_none)
losses_mean = nn.losses.cross_entropy(
logits, target, label_smoothing=0.3, reduction="mean"
)
self.assertTrue(
mx.allclose(losses_mean, expected_mean),
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'",
)
expected_sum = mx.sum(expected_none)
losses_sum = nn.losses.cross_entropy(
logits, target, label_smoothing=0.3, reduction="sum"
)
self.assertTrue(
mx.allclose(losses_sum, expected_sum),
"Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'",
probs = mx.array([[1.0, 0.0], [0.0, 1.0]])
loss = nn.losses.cross_entropy(
logits, probs, weights=weights, label_smoothing=0.3, reduction="none"
)
self.assertTrue(mx.allclose(loss, expected))
def test_binary_cross_entropy(self):
def _test_logits_as_inputs():