From d35fa1db41b1d883998e73c6bb8f96cc1490ee03 Mon Sep 17 00:00:00 2001
From: Nicholas Santavas
Date: Fri, 22 Dec 2023 19:28:10 +0100
Subject: [PATCH] Add Hinge, Huber and LogCosh losses (#199)

---
 docs/src/python/nn.rst        |  4 +-
 docs/src/python/nn/losses.rst |  5 +-
 python/mlx/nn/losses.py       | 93 +++++++++++++++++++++++++++++++++++
 python/tests/test_nn.py       | 18 +++++++
 4 files changed, 117 insertions(+), 3 deletions(-)

diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst
index bc19a8162..4c9868171 100644
--- a/docs/src/python/nn.rst
+++ b/docs/src/python/nn.rst
@@ -123,7 +123,7 @@ To get more detailed information on the arrays in a :class:`Module` you can use
 all the parameters in a :class:`Module` do:
 
 .. code-block:: python
-  
+
     from mlx.utils import tree_map
     shapes = tree_map(lambda p: p.shape, mlp.parameters())
 
@@ -131,7 +131,7 @@ As another example, you can count the number of parameters in a :class:`Module`
 with:
 
 .. code-block:: python
-  
+
     from mlx.utils import tree_flatten
     num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
 
diff --git a/docs/src/python/nn/losses.rst b/docs/src/python/nn/losses.rst
index b6a202d4a..3fb7589f8 100644
--- a/docs/src/python/nn/losses.rst
+++ b/docs/src/python/nn/losses.rst
@@ -16,4 +16,7 @@ Loss Functions
    mse_loss
    nll_loss
    smooth_l1_loss
-   triplet_loss
\ No newline at end of file
+   triplet_loss
+   hinge_loss
+   huber_loss
+   log_cosh_loss
\ No newline at end of file
diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index 755656e4f..35aedf755 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
+import math
+
 import mlx.core as mx
 from mlx.nn.layers.base import Module
 
@@ -283,3 +285,94 @@ def _reduce(loss: mx.array, reduction: str = "none"):
         return loss
     else:
         raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
+
+
+def hinge_loss(
+    inputs: mx.array, targets: mx.array, reduction: str = "none"
+) -> mx.array:
+    """
+    Computes the hinge loss between inputs and targets.
+
+    .. math::
+
+       \text{hinge}(y, y_{\text{pred}}) = \max(0, 1 - y \cdot y_{\text{pred}})
+
+
+    Args:
+        inputs (array): The predicted values.
+        targets (array): The target values. They should be -1 or 1.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed hinge loss.
+    """
+    loss = mx.maximum(1 - inputs * targets, 0)
+
+    return _reduce(loss, reduction)
+
+
+def huber_loss(
+    inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
+) -> mx.array:
+    """
+    Computes the Huber loss between inputs and targets.
+
+    .. math::
+
+       L_{\delta}(a) =
+       \left\{ \begin{array}{ll}
+           \frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\
+           \delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.}
+       \end{array} \right.
+
+    Args:
+        inputs (array): The predicted values.
+        targets (array): The target values.
+        delta (float, optional): The threshold at which to change between L1 and L2 loss.
+            Default: ``1.0``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed Huber loss.
+    """
+    errors = inputs - targets
+    abs_errors = mx.abs(errors)
+    quadratic = mx.minimum(abs_errors, delta)
+    linear = abs_errors - quadratic
+    loss = 0.5 * quadratic**2 + delta * linear
+
+    return _reduce(loss, reduction)
+
+
+def log_cosh_loss(
+    inputs: mx.array, targets: mx.array, reduction: str = "none"
+) -> mx.array:
+    """
+    Computes the log cosh loss between inputs and targets.
+
+    Logcosh acts like L2 loss for small errors, ensuring stable gradients,
+    and like the L1 loss for large errors, reducing sensitivity to outliers. This
+    dual behavior offers a balanced, robust approach for regression tasks.
+
+    .. math::
+
+       \text{logcosh}(y_{\text{true}}, y_{\text{pred}}) =
+            \frac{1}{n} \sum_{i=1}^{n}
+            \log(\cosh(y_{\text{pred}}^{(i)} - y_{\text{true}}^{(i)}))
+
+
+    Args:
+        inputs (array): The predicted values.
+        targets (array): The target values.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed log cosh loss.
+    """
+    errors = inputs - targets
+    loss = mx.logaddexp(errors, -errors) - math.log(2)
+
+    return _reduce(loss, reduction)
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index ebc6f2b7a..0d1c8b2ff 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -581,6 +581,24 @@ class TestNN(mlx_tests.MLXTestCase):
         y = alibi(x.astype(mx.float16))
         self.assertTrue(y.dtype, mx.float16)
 
+    def test_hinge_loss(self):
+        inputs = mx.ones((2, 4))
+        targets = mx.zeros((2, 4))
+        loss = nn.losses.hinge_loss(inputs, targets, reduction="mean")
+        self.assertEqual(loss, 1.0)
+
+    def test_huber_loss(self):
+        inputs = mx.ones((2, 4))
+        targets = mx.zeros((2, 4))
+        loss = nn.losses.huber_loss(inputs, targets, reduction="mean")
+        self.assertEqual(loss, 0.5)
+
+    def test_log_cosh_loss(self):
+        inputs = mx.ones((2, 4))
+        targets = mx.zeros((2, 4))
+        loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
+        self.assertEqual(loss, 0.433781)
+
 
 if __name__ == "__main__":
     unittest.main()