diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index 0f0021050..d68092f4a 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -31,9 +31,14 @@ def cross_entropy(
     Computes the cross entropy loss.
 
     Args:
-        logits (array): The unnormalized predicted logits.
-        targets (array): The target values, as class indices.
-        weights (array, optional): Weights for each target. Default: ``None``.
+        logits (array): The unnormalized logits.
+        targets (array): The ground truth values. These can be class indices or
+            probabilities for each class. If the ``targets`` are class indices,
+            then ``targets`` shape should match the ``logits`` shape with
+            the ``axis`` dimension removed. If the ``targets`` are probabilities
+            (or one-hot encoded), then the ``targets`` shape should be the same as
+            the ``logits`` shape.
+        weights (array, optional): Optional weights for each target. Default: ``None``.
         axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
         label_smoothing (float, optional): Label smoothing factor. Default: ``0``.
         reduction (str, optional): Specifies the reduction to apply to the output:
@@ -41,11 +46,46 @@ def cross_entropy(
 
     Returns:
         array: The computed cross entropy loss.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn as nn
+        >>>
+        >>> # Class indices as targets
+        >>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
+        >>> targets = mx.array([0, 1])
+        >>> nn.losses.cross_entropy(logits, targets)
+        array([0.0485873, 0.0485873], dtype=float32)
+        >>>
+        >>> # Probabilities (or one-hot vectors) as targets
+        >>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
+        >>> targets = mx.array([[0.9, 0.1], [0.1, 0.9]])
+        >>> nn.losses.cross_entropy(logits, targets)
+        array([0.348587, 0.348587], dtype=float32)
     """
     if label_smoothing < 0 or label_smoothing >= 1:
         raise ValueError(f"Label smoothing must in [0, 1), got {label_smoothing}.")
 
-    score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
+    # Whether targets are class indices or probabilities
+    targets_as_probs = targets.ndim == logits.ndim
+
+    def _drop_dim(shape, axis):
+        shape.pop(axis)
+        return shape
+
+    # Check shapes in two cases: targets as class indices and targets as probabilities
+    if (targets_as_probs and targets.shape != logits.shape) or (
+        not targets_as_probs and targets.shape != _drop_dim(logits.shape, axis)
+    ):
+        raise ValueError(
+            f"Targets shape {targets.shape} does not match logits shape {logits.shape}."
+        )
+
+    if targets_as_probs:
+        score = mx.sum(logits * targets, axis=axis)
+    else:
+        score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
+
     logsumexp_logits = mx.logsumexp(logits, axis=axis)
     if label_smoothing > 0:
         # Adjust the true class score with label smoothing
@@ -62,10 +102,10 @@ def cross_entropy(
 
     # Apply weights if provided
     if weights is not None:
-        if weights.shape != targets.shape:
+        if weights.shape != loss.shape:
             raise ValueError(
                 f"Weights with shape {weights.shape} is not the same as "
-                f"targets with shape {targets.shape}."
+                f"output loss with shape {loss.shape}."
             )
         loss *= weights
 
diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py
index 2682cbadc..c6db19983 100644
--- a/python/tests/test_losses.py
+++ b/python/tests/test_losses.py
@@ -10,100 +10,61 @@ import numpy as np
 
 class TestLosses(mlx_tests.MLXTestCase):
     def test_cross_entropy(self):
+        # No weights, no label smoothing
         logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
-        targets = mx.array([0, 1])
+        indices = mx.array([0, 1])
+        expected = mx.array([0.0, 0.0])
+        loss = nn.losses.cross_entropy(logits, indices, reduction="none")
+        self.assertTrue(mx.allclose(loss, expected))
 
-        # Test with reduction 'none'
-        losses_none = nn.losses.cross_entropy(logits, targets, reduction="none")
-        expected_none = mx.array([0.0, 0.0])
-        self.assertTrue(mx.array_equal(losses_none, expected_none))
+        probs = mx.array([[1.0, 0.0], [0.0, 1.0]])
+        loss = nn.losses.cross_entropy(logits, probs, reduction="none")
+        self.assertTrue(mx.isnan(loss).all())  # produce NaNs, like PyTorch
 
-        # Test with reduction 'mean'
-        losses_mean = nn.losses.cross_entropy(logits, targets, reduction="mean")
-        expected_mean = mx.mean(expected_none)
-        self.assertEqual(losses_mean, expected_mean)
-
-        # Test with reduction 'sum'
-        losses_sum = nn.losses.cross_entropy(logits, targets, reduction="sum")
-        expected_sum = mx.sum(expected_none)
-        self.assertEqual(losses_sum, expected_sum)
-
-        # Test cases with weights and no label smoothing
+        # With weights, no label smoothing
         logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
-        targets = mx.array([0, 1])
+        indices = mx.array([0, 1])
         weights = mx.array([1.0, 2.0])
+        expected = mx.array([0.04858735, 0.0971747])
+        loss = nn.losses.cross_entropy(
+            logits, indices, weights=weights, reduction="none"
+        )
+        self.assertTrue(mx.allclose(loss, expected))
 
-        # Reduction 'none'
-        losses_none = nn.losses.cross_entropy(
-            logits,
-            targets,
-            weights=weights,
-            reduction="none",
-        )
-        expected_none = mx.array([0.04858735, 0.0971747])  # Calculated losses
-        self.assertTrue(
-            np.allclose(losses_none, expected_none, atol=1e-5),
-            "Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]",
-        )
+        probs = mx.array([[1.0, 0.0], [0.0, 1.0]])
+        loss = nn.losses.cross_entropy(logits, probs, weights=weights, reduction="none")
+        self.assertTrue(mx.allclose(loss, expected))
 
-        # Reduction 'mean'
-        losses_mean = nn.losses.cross_entropy(
-            logits,
-            targets,
-            weights=weights,
-            reduction="mean",
-        )
-        expected_mean = mx.mean(expected_none)
-        self.assertTrue(
-            np.allclose(losses_mean, expected_mean, atol=1e-5),
-            "Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]",
+        # No weights, with label smoothing
+        logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
+        indices = mx.array([0, 1])
+        expected = mx.array([0.498587, 0.498587])
+        loss = nn.losses.cross_entropy(
+            logits, indices, label_smoothing=0.3, reduction="none"
         )
+        self.assertTrue(mx.allclose(loss, expected))
 
-        # Reduction 'sum'
-        losses_sum = nn.losses.cross_entropy(
-            logits,
-            targets,
-            weights=weights,
-            reduction="sum",
-        )
-        expected_sum = mx.sum(expected_none)
-        self.assertTrue(
-            np.allclose(losses_sum, expected_sum, atol=1e-5),
-            "Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]",
+        probs = mx.array([[1.0, 0.0], [0.0, 1.0]])
+        loss = nn.losses.cross_entropy(
+            logits, probs, label_smoothing=0.3, reduction="none"
         )
+        self.assertTrue(mx.allclose(loss, expected))
 
-        # Test case with equal weights and label smoothing > 0
-        logits = mx.array(
-            [[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]]
+        # With weights and label smoothing
+        logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
+        indices = mx.array([0, 1])
+        weights = mx.array([1.0, 2.0])
+        expected = mx.array([0.49858734, 0.9971747])
+        loss = nn.losses.cross_entropy(
+            logits, indices, weights=weights, label_smoothing=0.3, reduction="none"
        )
-        target = mx.array([2, 1, 0])
+        self.assertTrue(mx.allclose(loss, expected))
 
-        losses_none = nn.losses.cross_entropy(
-            logits, target, label_smoothing=0.3, reduction="none"
-        )
-        expected_none = mx.array([1.29693, 1.38617, 1.48176])
-        self.assertTrue(
-            mx.allclose(expected_none, losses_none),
-            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'",
-        )
-
-        expected_mean = mx.mean(expected_none)
-        losses_mean = nn.losses.cross_entropy(
-            logits, target, label_smoothing=0.3, reduction="mean"
-        )
-        self.assertTrue(
-            mx.allclose(losses_mean, expected_mean),
-            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'",
-        )
-
-        expected_sum = mx.sum(expected_none)
-        losses_sum = nn.losses.cross_entropy(
-            logits, target, label_smoothing=0.3, reduction="sum"
-        )
-        self.assertTrue(
-            mx.allclose(losses_sum, expected_sum),
-            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'",
+        probs = mx.array([[1.0, 0.0], [0.0, 1.0]])
+        loss = nn.losses.cross_entropy(
+            logits, probs, weights=weights, label_smoothing=0.3, reduction="none"
         )
+        self.assertTrue(mx.allclose(loss, expected))
 
     def test_binary_cross_entropy(self):
         def _test_logits_as_inputs():
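For reference, a minimal sketch (not part of the patch) that checks the new docstring example by hand, assuming `mlx.core` and `mlx.nn` are importable as in the tests above. With probability targets and no label smoothing, the new branch reduces to the usual soft-target cross entropy, logsumexp(logits) - sum(targets * logits) along `axis`:

import mlx.core as mx
import mlx.nn as nn

logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
probs = mx.array([[0.9, 0.1], [0.1, 0.9]])

# Manual soft-target cross entropy: logsumexp(logits) - sum(probs * logits)
manual = mx.logsumexp(logits, axis=-1) - mx.sum(probs * logits, axis=-1)

print(manual)                                  # ~ array([0.348587, 0.348587])
print(nn.losses.cross_entropy(logits, probs))  # should match the value above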