Add Hinge, Huber and LogCosh losses (#199)
commit d35fa1db41 (parent e8deca84e0)
@@ -123,7 +123,7 @@ To get more detailed information on the arrays in a :class:`Module` you can use
 all the parameters in a :class:`Module` do:

 .. code-block:: python

    from mlx.utils import tree_map
    shapes = tree_map(lambda p: p.shape, mlp.parameters())
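As a minimal sketch of what that snippet produces (the `mlp` here is a
hypothetical two-layer model, not one from this commit):

    import mlx.nn as nn
    from mlx.utils import tree_map

    mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    # Replaces every parameter array with its shape, preserving the tree layout,
    # e.g. the first Linear's weight reports shape 8 x 4 and its bias 8.
    shapes = tree_map(lambda p: p.shape, mlp.parameters())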
@@ -131,7 +131,7 @@ As another example, you can count the number of parameters in a :class:`Module`
 with:

 .. code-block:: python

    from mlx.utils import tree_flatten
    num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
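A worked instance of that count, using the same hypothetical `mlp` as above:

    from mlx.utils import tree_flatten

    # First Linear: 8*4 weights + 8 biases = 40; second: 2*8 + 2 = 18.
    num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
    assert num_params == 58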
@@ -16,4 +16,7 @@ Loss Functions
    mse_loss
    nll_loss
    smooth_l1_loss
    triplet_loss
+   hinge_loss
+   huber_loss
+   log_cosh_loss
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.

+import math
+
 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -283,3 +285,94 @@ def _reduce(loss: mx.array, reduction: str = "none"):
         return loss
     else:
         raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
+
+
+def hinge_loss(
+    inputs: mx.array, targets: mx.array, reduction: str = "none"
+) -> mx.array:
+    r"""
+    Computes the hinge loss between inputs and targets.
+
+    .. math::
+
+       \text{hinge}(y, y_{\text{pred}}) = \max(0, 1 - y \cdot y_{\text{pred}})
+
+    Args:
+        inputs (array): The predicted values.
+        targets (array): The target values. They should be -1 or 1.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed hinge loss.
+    """
+    loss = mx.maximum(1 - inputs * targets, 0)
+
+    return _reduce(loss, reduction)
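A quick usage sketch for the new function, with the per-element values worked
out by hand from the formula above (labels in {-1, 1}, as the docstring asks):

    import mlx.core as mx
    import mlx.nn as nn

    preds = mx.array([0.7, -0.3, 2.0])
    labels = mx.array([1.0, 1.0, -1.0])
    # Per element: max(0, 1 - 0.7) = 0.3, max(0, 1 + 0.3) = 1.3, max(0, 1 + 2.0) = 3.0
    loss = nn.losses.hinge_loss(preds, labels, reduction="mean")
    # (0.3 + 1.3 + 3.0) / 3 = 1.5333...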
+
+
+def huber_loss(
+    inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
+) -> mx.array:
+    r"""
+    Computes the Huber loss between inputs and targets.
+
+    .. math::
+
+       L_{\delta}(a) =
+       \left\{ \begin{array}{ll}
+           \frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\
+           \delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.}
+       \end{array} \right.
+
+    Args:
+        inputs (array): The predicted values.
+        targets (array): The target values.
+        delta (float, optional): The threshold at which to change between L1 and L2 loss.
+            Default: ``1.0``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed Huber loss.
+    """
+    errors = inputs - targets
+    abs_errors = mx.abs(errors)
+    quadratic = mx.minimum(abs_errors, delta)
+    linear = abs_errors - quadratic
+    loss = 0.5 * quadratic**2 + delta * linear
+
+    return _reduce(loss, reduction)
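A sketch tracing huber_loss through both branches (delta = 1.0, targets of zero
chosen for easy arithmetic):

    import mlx.core as mx
    import mlx.nn as nn

    preds = mx.array([0.5, 3.0])
    targets = mx.array([0.0, 0.0])
    # |0.5| <= delta, quadratic branch: 0.5 * 0.5**2 = 0.125
    # |3.0| >  delta, linear branch:    1.0 * (3.0 - 0.5 * 1.0) = 2.5
    loss = nn.losses.huber_loss(preds, targets, delta=1.0, reduction="none")
    # loss == [0.125, 2.5]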
+
+
+def log_cosh_loss(
+    inputs: mx.array, targets: mx.array, reduction: str = "none"
+) -> mx.array:
+    r"""
+    Computes the log cosh loss between inputs and targets.
+
+    Log cosh acts like the L2 loss for small errors, ensuring stable gradients,
+    and like the L1 loss for large errors, reducing sensitivity to outliers. This
+    dual behavior offers a balanced, robust approach for regression tasks.
+
+    .. math::
+
+       \text{logcosh}(y_{\text{true}}, y_{\text{pred}}) =
+           \frac{1}{n} \sum_{i=1}^{n}
+           \log(\cosh(y_{\text{pred}}^{(i)} - y_{\text{true}}^{(i)}))
+
+    Args:
+        inputs (array): The predicted values.
+        targets (array): The target values.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed log cosh loss.
+    """
+    errors = inputs - targets
+    loss = mx.logaddexp(errors, -errors) - math.log(2)
+
+    return _reduce(loss, reduction)
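A sketch for log_cosh_loss. The implementation relies on the identity
log(cosh(x)) = logaddexp(x, -x) - log(2), which stays stable for large x:

    import math
    import mlx.core as mx
    import mlx.nn as nn

    preds = mx.array([1.0, -1.0])
    targets = mx.array([0.0, 0.0])
    loss = nn.losses.log_cosh_loss(preds, targets, reduction="mean")
    # Both errors are +/-1 and cosh is even, so the mean is
    # log(cosh(1)) = log((e + 1/e) / 2), approximately 0.433781,
    # the same value the test below asserts.
    assert abs(loss.item() - math.log(math.cosh(1.0))) < 1e-5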
@@ -581,6 +581,24 @@ class TestNN(mlx_tests.MLXTestCase):
         y = alibi(x.astype(mx.float16))
         self.assertTrue(y.dtype, mx.float16)
+
+    def test_hinge_loss(self):
+        inputs = mx.ones((2, 4))
+        targets = mx.zeros((2, 4))
+        loss = nn.losses.hinge_loss(inputs, targets, reduction="mean")
+        self.assertEqual(loss, 1.0)
+
+    def test_huber_loss(self):
+        inputs = mx.ones((2, 4))
+        targets = mx.zeros((2, 4))
+        loss = nn.losses.huber_loss(inputs, targets, reduction="mean")
+        self.assertEqual(loss, 0.5)
+
+    def test_log_cosh_loss(self):
+        inputs = mx.ones((2, 4))
+        targets = mx.zeros((2, 4))
+        loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
+        self.assertEqual(loss, 0.433781)
+

 if __name__ == "__main__":
     unittest.main()
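The expected values fall straight out of the formulas: with inputs of ones and
targets of zeros, hinge gives max(0, 1 - 1*0) = 1.0 per element, Huber gives
0.5 * 1**2 = 0.5 (since |error| = 1 <= delta), and log cosh gives
log(cosh(1)), approximately 0.433781; the mean over the (2, 4) array equals the
per-element value in each case.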