mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Add Margin Ranking Loss (#536)
This commit is contained in:
parent
cb6156d35d
commit
83f63f2184
@ -18,6 +18,7 @@ Loss Functions
|
||||
kl_div_loss
|
||||
l1_loss
|
||||
log_cosh_loss
|
||||
margin_ranking_loss
|
||||
mse_loss
|
||||
nll_loss
|
||||
smooth_l1_loss
|
||||
|
@ -533,3 +533,58 @@ def cosine_similarity_loss(
|
||||
loss = mx.sum(x1 * x2, axis=axis) / mx.maximum(x1_norm * x2_norm, eps)
|
||||
|
||||
return _reduce(loss, reduction)
|
||||
|
||||
|
||||
def margin_ranking_loss(
|
||||
inputs1: mx.array,
|
||||
inputs2: mx.array,
|
||||
targets: mx.array,
|
||||
margin: float = 0.0,
|
||||
reduction: Reduction = "none",
|
||||
) -> mx.array:
|
||||
r"""
|
||||
Calculate the margin ranking loss that loss given inputs :math:`x_1`, :math:`x_2` and a label
|
||||
:math:`y` (containing 1 or -1).
|
||||
|
||||
The loss is given by:
|
||||
|
||||
.. math::
|
||||
\text{loss} = \max (0, -y * (x_1 - x_2) + \text{margin})
|
||||
|
||||
Where :math:`y` represents ``targets``, :math:`x_1` represents ``inputs1`` and :math:`x_2`
|
||||
represents ``inputs2``.
|
||||
|
||||
Args:
|
||||
inputs1 (array): Scores for the first input.
|
||||
inputs2 (array): Scores for the second input.
|
||||
targets (array): Labels indicating whether samples in ``inputs1`` should be ranked higher
|
||||
than samples in ``inputs2``. Values should be 1 or -1.
|
||||
margin (float, optional): The margin by which the scores should be separated.
|
||||
Default: ``0.0``.
|
||||
reduction (str, optional): Specifies the reduction to apply to the output:
|
||||
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
|
||||
|
||||
Returns:
|
||||
array: The computed margin ranking loss.
|
||||
|
||||
Examples:
|
||||
>>> import mlx.core as mx
|
||||
>>> import mlx.nn as nn
|
||||
>>> targets = mx.array([1, 1, -1])
|
||||
>>> inputs1 = mx.array([-0.573409, -0.765166, -0.0638])
|
||||
>>> inputs2 = mx.array([0.75596, 0.225763, 0.256995])
|
||||
>>> loss = nn.losses.margin_ranking_loss(inputs1, inputs2, targets)
|
||||
>>> loss
|
||||
array(0.773433, dtype=float32)
|
||||
"""
|
||||
if not (inputs1.shape == inputs2.shape == targets.shape):
|
||||
raise ValueError(
|
||||
f"The shapes of the arguments do not match. The provided shapes are "
|
||||
f"inputs1.shape={inputs1.shape}, inputs2.shape={inputs2.shape}, and "
|
||||
f"targets.shape={targets.shape}."
|
||||
)
|
||||
|
||||
differences = inputs1 - inputs2
|
||||
loss = mx.maximum(0, -targets * differences + margin)
|
||||
|
||||
return _reduce(loss, reduction)
|
||||
|
@ -359,6 +359,25 @@ class TestLosses(mlx_tests.MLXTestCase):
|
||||
expected_sum = mx.sum(expected_none)
|
||||
self.assertTrue(mx.allclose(losses_sum, expected_sum))
|
||||
|
||||
def test_margin_ranking_loss(self):
|
||||
inputs1 = mx.array([-0.573409, -0.765166, -0.0638])
|
||||
inputs2 = mx.array([0.75596, 0.225763, 0.256995])
|
||||
targets = mx.array([1, 1, -1])
|
||||
|
||||
# Test with no margin
|
||||
losses = nn.losses.margin_ranking_loss(
|
||||
inputs1, inputs2, targets, reduction="none"
|
||||
)
|
||||
expected = mx.array([1.329369, 0.990929, 0.0])
|
||||
self.assertTrue(mx.allclose(losses, expected))
|
||||
|
||||
# Test with margin
|
||||
losses = nn.losses.margin_ranking_loss(
|
||||
inputs1, inputs2, targets, margin=0.5, reduction="none"
|
||||
)
|
||||
expected = mx.array([1.829369, 1.490929, 0.179205])
|
||||
self.assertTrue(mx.allclose(losses, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user