Add Hinge, Huber and LogCosh losses (#199)

Nicholas Santavas 2023-12-22 19:28:10 +01:00 committed by GitHub
parent e8deca84e0
commit d35fa1db41
4 changed files with 117 additions and 3 deletions


@@ -123,7 +123,7 @@ To get more detailed information on the arrays in a :class:`Module` you can use
all the parameters in a :class:`Module` do:
.. code-block:: python
from mlx.utils import tree_map
shapes = tree_map(lambda p: p.shape, mlp.parameters())
@@ -131,7 +131,7 @@ As another example, you can count the number of parameters in a :class:`Module`
with:
.. code-block:: python
from mlx.utils import tree_flatten
num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))


@@ -16,4 +16,7 @@ Loss Functions
mse_loss
nll_loss
smooth_l1_loss
triplet_loss
hinge_loss
huber_loss
log_cosh_loss


@@ -1,5 +1,7 @@
# Copyright © 2023 Apple Inc.
import math
import mlx.core as mx
from mlx.nn.layers.base import Module
@@ -283,3 +285,94 @@ def _reduce(loss: mx.array, reduction: str = "none"):
return loss
else:
raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
def hinge_loss(
inputs: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array:
"""
Computes the hinge loss between inputs and targets.
.. math::
\text{hinge}(y, y_{\text{pred}}) = \max(0, 1 - y \cdot y_{\text{pred}})
Args:
inputs (array): The predicted values.
targets (array): The target values. They should be -1 or 1.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
array: The computed hinge loss.
"""
loss = mx.maximum(1 - inputs * targets, 0)
return _reduce(loss, reduction)
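
# A minimal usage sketch (illustrative, not part of this commit; assumes
# mlx.core and mlx.nn are imported as below). Targets take values in {-1, 1}
# and each element contributes max(0, 1 - y * y_pred):
#
#     import mlx.core as mx
#     import mlx.nn as nn
#
#     scores = mx.array([0.8, -0.4, 1.5, -2.0])
#     labels = mx.array([1, -1, -1, -1])
#     nn.losses.hinge_loss(scores, labels, reduction="mean")
#     # per-element losses: [0.2, 0.6, 2.5, 0.0] -> mean 0.825
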
def huber_loss(
inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
) -> mx.array:
"""
Computes the Huber loss between inputs and targets.
.. math::
L_{\delta}(a) =
\left\{ \begin{array}{ll}
\frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\
\delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.}
\end{array} \right.
Args:
inputs (array): The predicted values.
targets (array): The target values.
delta (float, optional): The threshold at which to change between L1 and L2 loss.
Default: ``1.0``.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
array: The computed Huber loss.
"""
errors = inputs - targets
abs_errors = mx.abs(errors)
quadratic = mx.minimum(abs_errors, delta)
linear = abs_errors - quadratic
loss = 0.5 * quadratic**2 + delta * linear
return _reduce(loss, reduction)
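
# A minimal usage sketch (illustrative, not part of this commit): with
# delta=1.0, an absolute error of 0.5 stays in the quadratic branch
# (0.5 * 0.5**2 = 0.125) while an absolute error of 3.0 falls in the linear
# branch (1.0 * (3.0 - 0.5 * 1.0) = 2.5):
#
#     import mlx.core as mx
#     import mlx.nn as nn
#
#     preds = mx.array([0.0, 2.0])
#     targets = mx.array([0.5, -1.0])
#     nn.losses.huber_loss(preds, targets, delta=1.0, reduction="none")
#     # -> array([0.125, 2.5], dtype=float32)
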
def log_cosh_loss(
inputs: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array:
"""
Computes the log cosh loss between inputs and targets.
Log cosh behaves like the L2 loss for small errors, ensuring stable gradients,
and like the L1 loss for large errors, reducing sensitivity to outliers. This
dual behavior offers a balanced, robust approach for regression tasks.
.. math::
\text{logcosh}(y_{\text{true}}, y_{\text{pred}}) =
\frac{1}{n} \sum_{i=1}^{n}
\log(\cosh(y_{\text{pred}}^{(i)} - y_{\text{true}}^{(i)}))
Args:
inputs (array): The predicted values.
targets (array): The target values.
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
Returns:
array: The computed log cosh loss.
"""
errors = inputs - targets
loss = mx.logaddexp(errors, -errors) - math.log(2)
return _reduce(loss, reduction)
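
# A minimal usage sketch (illustrative, not part of this commit): the
# logaddexp formulation computes log(cosh(x)) = log((e^x + e^-x) / 2) without
# overflowing for large |x|:
#
#     import mlx.core as mx
#     import mlx.nn as nn
#
#     preds = mx.array([1.0, 0.0])
#     targets = mx.array([0.0, 0.0])
#     nn.losses.log_cosh_loss(preds, targets, reduction="none")
#     # -> array([0.433781, 0.0], dtype=float32)   # log(cosh(1)) ≈ 0.4338, log(cosh(0)) = 0
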


@@ -581,6 +581,24 @@ class TestNN(mlx_tests.MLXTestCase):
y = alibi(x.astype(mx.float16))
self.assertTrue(y.dtype, mx.float16)
def test_hinge_loss(self):
inputs = mx.ones((2, 4))
targets = mx.zeros((2, 4))
loss = nn.losses.hinge_loss(inputs, targets, reduction="mean")
self.assertEqual(loss, 1.0)
def test_huber_loss(self):
inputs = mx.ones((2, 4))
targets = mx.zeros((2, 4))
loss = nn.losses.huber_loss(inputs, targets, reduction="mean")
self.assertEqual(loss, 0.5)
def test_log_cosh_loss(self):
inputs = mx.ones((2, 4))
targets = mx.zeros((2, 4))
loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
self.assertAlmostEqual(loss.item(), 0.433781, places=6)
if __name__ == "__main__":
unittest.main()
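
# A quick arithmetic check of the expected test values above (illustrative,
# not part of this commit). With inputs all ones and targets all zeros:
#   hinge:    max(0, 1 - 1 * 0) = 1.0
#   huber:    |error| = 1 <= delta, so 0.5 * 1.0**2 = 0.5
#   log cosh: log(cosh(1)) = log((e + 1/e) / 2) ≈ 0.433781
#
#     import math
#     print(math.log(math.cosh(1.0)))  # ≈ 0.433781
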