diff --git a/docs/src/python/nn/losses.rst b/docs/src/python/nn/losses.rst
index 4808ce5ab..4c99ff15c 100644
--- a/docs/src/python/nn/losses.rst
+++ b/docs/src/python/nn/losses.rst
@@ -9,9 +9,10 @@ Loss Functions
    :toctree: _autosummary_functions
    :template: nn-module-template.rst
 
-   cross_entropy
    binary_cross_entropy
+   cross_entropy
+   kl_div_loss
    l1_loss
    mse_loss
    nll_loss
-   kl_div_loss
+   smooth_l1_loss
diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index b9d35d9b5..3b0f31ce1 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -4,145 +4,138 @@ import mlx.core as mx
 from mlx.nn.layers.base import Module
 
 
-def _make_loss_module(f):
-    def decorator(klass):
-        klass.__call__ = lambda self, inputs, targets: f(
-            inputs, targets, self.reduction
-        )
-        return klass
-
-    return decorator
-
-
 def cross_entropy(
-    logits: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
+    logits: mx.array,
+    targets: mx.array,
+    weights: mx.array = None,
+    axis: int = -1,
+    label_smoothing: float = 0.0,
+    reduction: str = "none",
 ) -> mx.array:
     """
-    Computes the cross entropy loss between logits and targets.
+    Computes the cross entropy loss.
 
     Args:
-        logits (mx.array): The predicted logits.
-        targets (mx.array): The target values.
+        logits (array): The unnormalized predicted logits.
+        targets (array): The target values, as class indices.
+        weights (array, optional): Weights for each target. Default: ``None``.
         axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
+        label_smoothing (float, optional): Label smoothing factor. Default: ``0``.
         reduction (str, optional): Specifies the reduction to apply to the output:
-            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
 
     Returns:
-        mx.array: The computed cross entropy loss.
+        array: The computed cross entropy loss.
     """
-    score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
-    loss = mx.logsumexp(logits, axis=axis) - score
+    if label_smoothing < 0 or label_smoothing >= 1:
+        raise ValueError(f"Label smoothing must be in [0, 1), got {label_smoothing}.")
 
+    score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
+    logsumexp_logits = mx.logsumexp(logits, axis=axis)
+    if label_smoothing > 0:
+        # Adjust the true class score with label smoothing
+        adjusted_score = (1 - label_smoothing) * score
+
+        # Calculate the mean logit across the classes for smoothed loss
+        mean_logits = logits.mean(axis=axis)
+        smoothed_loss = -mean_logits * label_smoothing
+
+        # Combine the adjusted score and smoothed loss with the logsumexp logits
+        loss = logsumexp_logits - adjusted_score + smoothed_loss
+    else:
+        loss = logsumexp_logits - score
+
+    # Apply weights if provided
+    if weights is not None:
+        if weights.shape != targets.shape:
+            raise ValueError(
+                f"Weights with shape {weights.shape} is not the same as "
+                f"targets with shape {targets.shape}."
+            )
+        loss *= weights
+
+    # Apply reduction
     return _reduce(loss, reduction)
 
 
 def binary_cross_entropy(
-    inputs: mx.array, targets: mx.array, reduction: str = "none"
+    logits: mx.array, targets: mx.array, reduction: str = "none"
 ) -> mx.array:
     """
-    Computes the binary cross entropy loss between inputs and targets.
+    Computes the binary cross entropy loss.
 
     Args:
-        inputs (mx.array): The predicted inputs (post-sigmoid probabilities).
-        targets (mx.array): The target values (binary labels).
+        logits (array): The unnormalized (pre-sigmoid) predicted logits.
+        targets (array): The binary target values in {0, 1}.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
 
     Returns:
-        mx.array: The computed binary cross entropy loss.
+        array: The computed binary cross entropy loss.
 
     Examples:
         >>> import mlx.core as mx
        >>> import mlx.nn as nn
-        >>> inputs = mx.array([0.1, 0.2, 0.3, 0.4])
+        >>> logits = mx.array([-2.197225, -1.386294, -0.847298, -0.405465])
         >>> targets = mx.array([0, 0, 1, 1])
-        >>> loss = nn.losses.binary_cross_entropy(inputs, targets)
+        >>> loss = nn.losses.binary_cross_entropy(logits, targets, "mean")
         >>> loss
-        array([0.612192])
+        array(0.612192, dtype=float32)
     """
-    loss = -targets * mx.log(inputs) - (1 - targets) * mx.log(1 - inputs)
+    loss = mx.logaddexp(0.0, logits) - targets * logits
     return _reduce(loss, reduction)
 
 
-@_make_loss_module(binary_cross_entropy)
-class BCELoss(Module):
-    """
-    Binary Cross Entropy Loss module.
-    It computes the binary cross entropy loss between predicted probabilities (post-sigmoid inputs) and target binary labels.
-
-    Args:
-        reduction (str, optional): Specifies the reduction to apply to the output:
-            - 'none': no reduction (default)
-            - 'mean': compute the mean loss
-            - 'sum': compute the sum of the loss
-
-    Examples:
-        >>> import mlx.core as mx
-        >>> from mlx.nn.losses import BCELoss
-        >>>
-        >>> # Create BCELoss module with default reduction ('none')
-        >>> loss_module_none = BCELoss()
-        >>> inputs = mx.array([0.5, 0.7, 0.3])
-        >>> targets = mx.array([1, 0, 1])
-        >>> loss_none = loss_module_none(inputs, targets)
-        >>> print(loss_none)
-        array([0.693147, 1.20397, 1.20397], dtype=float32)
-
-        >>> # Create BCELoss module with reduction 'mean'
-        >>> loss_module_mean = BCELoss(reduction='mean')
-        >>> loss_mean = loss_module_mean(inputs, targets)
-        >>> print(loss_mean)
-        array(1.0337, dtype=float32)
-
-        >>> # Create BCELoss module with reduction 'sum'
-        >>> loss_module_sum = BCELoss(reduction='sum')
-        >>> loss_sum = loss_module_sum(inputs, targets)
-        >>> print(loss_sum)
-        array(3.10109, dtype=float32)
-    """
-
-    def __init__(self, reduction: str = "none"):
-        super().__init__()
-
-        self.reduction = reduction
-
-
 def l1_loss(
-    predictions: mx.array, targets: mx.array, reduction: str = "none"
+    predictions: mx.array, targets: mx.array, reduction: str = "mean"
 ) -> mx.array:
     """
-    Computes the L1 loss between predictions and targets.
+    Computes the L1 loss.
 
     Args:
-        predictions (mx.array): The predicted values.
-        targets (mx.array): The target values.
+        predictions (array): The predicted values.
+        targets (array): The target values.
         reduction (str, optional): Specifies the reduction to apply to the output:
-            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
 
     Returns:
-        mx.array: The computed L1 loss.
+        array: The computed L1 loss.
     """
-    loss = mx.mean(mx.abs(predictions - targets))
+    if predictions.shape != targets.shape:
+        raise ValueError(
+            f"Predictions shape {predictions.shape} does not match "
+            f"targets shape {targets.shape}."
+        )
+    loss = mx.abs(predictions - targets)
     return _reduce(loss, reduction)
 
 
 def mse_loss(
-    predictions: mx.array, targets: mx.array, reduction: str = "none"
+    predictions: mx.array, targets: mx.array, reduction: str = "mean"
 ) -> mx.array:
     """
-    Computes the mean squared error loss between predictions and targets.
+    Computes the mean squared error loss.
 
     Args:
-        predictions (mx.array): The predicted values.
-        targets (mx.array): The target values.
+        predictions (array): The predicted values.
+        targets (array): The target values.
         reduction (str, optional): Specifies the reduction to apply to the output:
-            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
 
     Returns:
-        mx.array: The computed mean squared error loss.
+        array: The computed mean squared error loss.
     """
-    loss = mx.square(predictions - targets)
+    if predictions.shape != targets.shape:
+        raise ValueError(
+            f"Predictions shape {predictions.shape} does not match "
+            f"targets shape {targets.shape}."
+        )
+
+    loss = mx.square(predictions - targets)
     return _reduce(loss, reduction)
 
 
@@ -150,17 +143,17 @@ def nll_loss(
     inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
 ) -> mx.array:
     """
-    Computes the negative log likelihood loss between inputs and targets.
+    Computes the negative log likelihood loss.
 
     Args:
-        inputs (mx.array): The predicted distribution in log space.
-        targets (mx.array): The target values.
+        inputs (array): The predicted distribution in log space.
+        targets (array): The target values.
         axis (int, optional): The distribution axis. Default: ``-1``.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
 
     Returns:
-        mx.array: The computed NLL loss.
+        array: The computed NLL loss.
     """
     loss = -mx.take_along_axis(inputs, targets[..., None], axis).squeeze(-1)
 
@@ -171,8 +164,7 @@ def kl_div_loss(
     inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
 ) -> mx.array:
     """
-    Computes the Kullback-Leibler divergence loss between targets and the
-    inputs.
+    Computes the Kullback-Leibler divergence loss.
 
     Computes the following when ``reduction == 'none'``:
 
@@ -181,20 +173,65 @@
         mx.exp(targets) * (targets - inputs).sum(axis)
 
     Args:
-        inputs (mx.array): Log probabilities for the predicted distribution.
-        targets (mx.array): Log probabilities for the target distribution.
+        inputs (array): Log probabilities for the predicted distribution.
+        targets (array): Log probabilities for the target distribution.
         axis (int, optional): The distribution axis. Default: ``-1``.
         reduction (str, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
 
     Returns:
-        mx.array: The computed Kullback-Leibler divergence loss.
+        array: The computed Kullback-Leibler divergence loss.
     """
     loss = mx.sum(mx.exp(targets) * (targets - inputs), axis)
 
     return _reduce(loss, reduction)
 
 
+def smooth_l1_loss(
+    predictions: mx.array, targets: mx.array, beta: float = 1.0, reduction: str = "mean"
+) -> mx.array:
+    r"""
+    Computes the smooth L1 loss.
+
+    The smooth L1 loss is a variant of the L1 loss which replaces the absolute
+    difference with a squared difference when the absolute difference is less
+    than ``beta``.
+
+    The formula for the smooth L1 loss is:
+
+    .. math::
+
+       l =
+          \begin{cases}
+            \frac{0.5 (x - y)^2}{\beta}, & \text{if } |x - y| < \beta \\
+            |x - y| - 0.5 \beta, & \text{otherwise}
+          \end{cases}
+
+    Args:
+        predictions (array): Predicted values.
+        targets (array): Ground truth values.
+        beta (float, optional): The threshold after which the loss changes
+            from the squared to the absolute difference. Default: ``1.0``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
+
+    Returns:
+        array: The computed smooth L1 loss.
+    """
+    if predictions.shape != targets.shape:
+        raise ValueError(
+            f"Predictions shape {predictions.shape} does not match "
+            f"targets shape {targets.shape}."
+        )
+
+    diff = predictions - targets
+    loss = mx.where(
+        mx.abs(diff) < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta
+    )
+
+    return _reduce(loss, reduction)
+
+
 def _reduce(loss: mx.array, reduction: str = "none"):
     if reduction == "mean":
         return mx.mean(loss)
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 852816c20..d93aa3cb2 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -37,30 +37,169 @@ class TestNN(mlx_tests.MLXTestCase):
         expected_sum = mx.sum(expected_none)
         self.assertEqual(losses_sum, expected_sum)
 
+        # Test cases with weights and no label smoothing
+        logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
+        targets = mx.array([0, 1])
+        weights = mx.array([1.0, 2.0])
+
+        # Reduction 'none'
+        losses_none = nn.losses.cross_entropy(
+            logits,
+            targets,
+            weights=weights,
+            reduction="none",
+        )
+        expected_none = mx.array([0.04858735, 0.0971747])  # Calculated losses
+        self.assertTrue(
+            np.allclose(losses_none, expected_none, atol=1e-5),
+            "Test case failed for cross_entropy loss --reduction='none' --weights=[1.0, 2.0]",
+        )
+
+        # Reduction 'mean'
+        losses_mean = nn.losses.cross_entropy(
+            logits,
+            targets,
+            weights=weights,
+            reduction="mean",
+        )
+        expected_mean = mx.mean(expected_none)
+        self.assertTrue(
+            np.allclose(losses_mean, expected_mean, atol=1e-5),
+            "Test case failed for cross_entropy loss --reduction='mean' --weights=[1.0, 2.0]",
+        )
+
+        # Reduction 'sum'
+        losses_sum = nn.losses.cross_entropy(
+            logits,
+            targets,
+            weights=weights,
+            reduction="sum",
+        )
+        expected_sum = mx.sum(expected_none)
+        self.assertTrue(
+            np.allclose(losses_sum, expected_sum, atol=1e-5),
+            "Test case failed for cross_entropy loss --reduction='sum' --weights=[1.0, 2.0]",
+        )
+
+        # Test case with equal weights and label smoothing > 0
+        logits = mx.array(
+            [[0, 0.2, 0.7, 0.1, 0], [0, 0.9, 0.2, 0.2, 1], [1, 0.2, 0.7, 0.9, 1]]
+        )
+        target = mx.array([2, 1, 0])
+
+        losses_none = nn.losses.cross_entropy(
+            logits, target, label_smoothing=0.3, reduction="none"
+        )
+        expected_none = mx.array([1.29693, 1.38617, 1.48176])
+        self.assertTrue(
+            mx.allclose(expected_none, losses_none),
+            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='none'",
+        )
+
+        expected_mean = mx.mean(expected_none)
+        losses_mean = nn.losses.cross_entropy(
+            logits, target, label_smoothing=0.3, reduction="mean"
+        )
+        self.assertTrue(
+            mx.allclose(losses_mean, expected_mean),
+            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='mean'",
+        )
+
+        expected_sum = mx.sum(expected_none)
+        losses_sum = nn.losses.cross_entropy(
+            logits, target, label_smoothing=0.3, reduction="sum"
+        )
+        self.assertTrue(
+            mx.allclose(losses_sum, expected_sum),
+            "Test case failed for cross_entropy --label_smoothing=0.3 --reduction='sum'",
+        )
+
     def test_l1_loss(self):
         predictions = mx.array([0.5, 0.2, 0.9, 0.0])
         targets = mx.array([0.5, 0.2, 0.9, 0.0])
+
+        # Expected result
+        expected_none = mx.array([0, 0, 0, 0]).astype(mx.float32)
+        expected_sum = mx.sum(expected_none)
+        expected_mean = mx.mean(expected_none)
+
         losses = nn.losses.l1_loss(predictions, targets, reduction="none")
-        self.assertEqual(losses, 0.0)
+        self.assertTrue(
+ mx.array_equal(losses, expected_none), + "Test failed for l1_loss --reduction='none'", + ) + + losses = nn.losses.l1_loss(predictions, targets, reduction="sum") + self.assertTrue(mx.array_equal(losses, expected_sum)) + + losses = nn.losses.l1_loss(predictions, targets, reduction="mean") + self.assertTrue(mx.array_equal(losses, expected_mean)) def test_mse_loss(self): predictions = mx.array([0.5, 0.2, 0.9, 0.0]) targets = mx.array([0.7, 0.1, 0.8, 0.2]) + expected_none = mx.array([0.04, 0.01, 0.01, 0.04]) + expected_mean = mx.mean(expected_none) + expected_sum = mx.sum(expected_none) + # Test with reduction 'none' losses_none = nn.losses.mse_loss(predictions, targets, reduction="none") - expected_none = mx.array([0.04, 0.01, 0.01, 0.04]) - self.assertTrue(mx.allclose(losses_none, expected_none)) + self.assertTrue( + np.allclose(losses_none, expected_none, 1e-5), + "Test case failed for mse_loss --reduction='none'", + ) # Test with reduction 'mean' losses_mean = nn.losses.mse_loss(predictions, targets, reduction="mean") - expected_mean = mx.mean(expected_none) - self.assertEqual(losses_mean, expected_mean) + self.assertEqual( + losses_mean, + expected_mean, + "Test case failed for mse_loss --reduction='mean'", + ) # Test with reduction 'sum' losses_sum = nn.losses.mse_loss(predictions, targets, reduction="sum") + self.assertEqual( + losses_sum, expected_sum, "Test case failed for mse_loss --reduction='sum'" + ) + + def test_smooth_l1_loss(self): + predictions = mx.array([1.5, 2.5, 0.5, 3.5]) + targets = mx.array([1.0, 2.0, 0.5, 2.5]) + beta = 1.0 + + # Expected results + expected_none = mx.array([0.125, 0.125, 0.0, 0.5]) expected_sum = mx.sum(expected_none) - self.assertEqual(losses_sum, expected_sum) + expected_mean = mx.mean(expected_none) + + # Test with reduction 'none' + loss_none = nn.losses.smooth_l1_loss( + predictions, targets, beta, reduction="none" + ) + self.assertTrue( + mx.array_equal(loss_none, expected_none), + "Test case failed for smooth_l1_loss --reduction='none'", + ) + + # Test with reduction 'sum' + loss_sum = nn.losses.smooth_l1_loss(predictions, targets, beta, reduction="sum") + self.assertEqual( + loss_sum, + expected_sum, + "Test case failed for smooth_l1_loss --reduction='sum'", + ) + + # Test with reduction 'mean' + loss_mean = nn.losses.smooth_l1_loss( + predictions, targets, beta, reduction="mean" + ) + self.assertEqual( + loss_mean, + expected_mean, + "Test case failed for smooth_l1_loss --reduction='mean'", + ) def test_nll_loss(self): logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]]) @@ -100,77 +239,6 @@ class TestNN(mlx_tests.MLXTestCase): expected_sum = mx.sum(expected_none) self.assertTrue(mx.allclose(losses_sum, expected_sum)) - def test_binary_cross_entropy(self): - inputs = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]]) - targets = mx.array([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0]]) - - # Test with reduction 'none' - losses_none = nn.losses.binary_cross_entropy(inputs, targets, reduction="none") - expected_none = mx.array( - [ - [ - 0.6931471824645996, - 0.6931471824645996, - 0.2231435477733612, - 0.10536054521799088, - ], - [ - 2.3025851249694824, - 0.3566749691963196, - 0.6931471824645996, - 0.6931471824645996, - ], - ] - ) - self.assertTrue(mx.allclose(losses_none, expected_none, rtol=1e-5, atol=1e-8)) - - # Test with reduction 'mean' - losses_mean = nn.losses.binary_cross_entropy(inputs, targets, reduction="mean") - expected_mean = mx.mean(expected_none) - self.assertTrue(mx.allclose(losses_mean, expected_mean)) - - # Test 
with reduction 'sum' - losses_sum = nn.losses.binary_cross_entropy(inputs, targets, reduction="sum") - expected_sum = mx.sum(expected_none) - self.assertTrue(mx.allclose(losses_sum, expected_sum)) - - def test_bce_loss_module(self): - inputs = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]]) - targets = mx.array([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0]]) - - # Test with reduction 'none' - loss_module_none = nn.losses.BCELoss(reduction="none") - losses_none = loss_module_none(inputs, targets) - expected_none = mx.array( - [ - [ - 0.6931471824645996, - 0.6931471824645996, - 0.2231435477733612, - 0.10536054521799088, - ], - [ - 2.3025851249694824, - 0.3566749691963196, - 0.6931471824645996, - 0.6931471824645996, - ], - ] - ) - self.assertTrue(mx.allclose(losses_none, expected_none, rtol=1e-5, atol=1e-8)) - - # Test with reduction 'mean' - loss_module_mean = nn.losses.BCELoss(reduction="mean") - losses_mean = loss_module_mean(inputs, targets) - expected_mean = mx.mean(expected_none) - self.assertTrue(mx.allclose(losses_mean, expected_mean)) - - # Test with reduction 'sum' - loss_module_sum = nn.losses.BCELoss(reduction="sum") - losses_sum = loss_module_sum(inputs, targets) - expected_sum = mx.sum(expected_none) - self.assertTrue(mx.allclose(losses_sum, expected_sum)) - def test_gelu(self): inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]
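
For a quick sanity check of the API surface after this change, a minimal usage sketch follows. The argument values are illustrative only; in particular, combining ``weights`` with ``label_smoothing`` in a single ``cross_entropy`` call is an assumption and is not exercised by the tests above::

    import mlx.core as mx
    import mlx.nn as nn

    # cross_entropy now accepts per-target weights and a label smoothing factor.
    logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
    targets = mx.array([0, 1])
    ce = nn.losses.cross_entropy(
        logits,
        targets,
        weights=mx.array([1.0, 2.0]),
        label_smoothing=0.1,  # illustrative value, not from the tests
        reduction="mean",
    )

    # binary_cross_entropy now takes pre-sigmoid logits rather than probabilities.
    bce = nn.losses.binary_cross_entropy(
        mx.array([-2.197225, -1.386294, -0.847298, -0.405465]),
        mx.array([0, 0, 1, 1]),
        "mean",
    )

    # smooth_l1_loss uses the squared branch when |predictions - targets| < beta.
    sl1 = nn.losses.smooth_l1_loss(
        mx.array([1.5, 2.5, 0.5, 3.5]), mx.array([1.0, 2.0, 0.5, 2.5]), beta=1.0
    )

    print(ce, bce, sl1)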