mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Add Margin Ranking Loss (#536)
This commit is contained in:
parent
cb6156d35d
commit
83f63f2184
@ -18,6 +18,7 @@ Loss Functions
|
|||||||
kl_div_loss
|
kl_div_loss
|
||||||
l1_loss
|
l1_loss
|
||||||
log_cosh_loss
|
log_cosh_loss
|
||||||
|
margin_ranking_loss
|
||||||
mse_loss
|
mse_loss
|
||||||
nll_loss
|
nll_loss
|
||||||
smooth_l1_loss
|
smooth_l1_loss
|
||||||
|
@ -533,3 +533,58 @@ def cosine_similarity_loss(
|
|||||||
loss = mx.sum(x1 * x2, axis=axis) / mx.maximum(x1_norm * x2_norm, eps)
|
loss = mx.sum(x1 * x2, axis=axis) / mx.maximum(x1_norm * x2_norm, eps)
|
||||||
|
|
||||||
return _reduce(loss, reduction)
|
return _reduce(loss, reduction)
|
||||||
|
|
||||||
|
|
||||||
|
def margin_ranking_loss(
|
||||||
|
inputs1: mx.array,
|
||||||
|
inputs2: mx.array,
|
||||||
|
targets: mx.array,
|
||||||
|
margin: float = 0.0,
|
||||||
|
reduction: Reduction = "none",
|
||||||
|
) -> mx.array:
|
||||||
|
r"""
|
||||||
|
Calculate the margin ranking loss that loss given inputs :math:`x_1`, :math:`x_2` and a label
|
||||||
|
:math:`y` (containing 1 or -1).
|
||||||
|
|
||||||
|
The loss is given by:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\text{loss} = \max (0, -y * (x_1 - x_2) + \text{margin})
|
||||||
|
|
||||||
|
Where :math:`y` represents ``targets``, :math:`x_1` represents ``inputs1`` and :math:`x_2`
|
||||||
|
represents ``inputs2``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs1 (array): Scores for the first input.
|
||||||
|
inputs2 (array): Scores for the second input.
|
||||||
|
targets (array): Labels indicating whether samples in ``inputs1`` should be ranked higher
|
||||||
|
than samples in ``inputs2``. Values should be 1 or -1.
|
||||||
|
margin (float, optional): The margin by which the scores should be separated.
|
||||||
|
Default: ``0.0``.
|
||||||
|
reduction (str, optional): Specifies the reduction to apply to the output:
|
||||||
|
``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The computed margin ranking loss.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import mlx.core as mx
|
||||||
|
>>> import mlx.nn as nn
|
||||||
|
>>> targets = mx.array([1, 1, -1])
|
||||||
|
>>> inputs1 = mx.array([-0.573409, -0.765166, -0.0638])
|
||||||
|
>>> inputs2 = mx.array([0.75596, 0.225763, 0.256995])
|
||||||
|
>>> loss = nn.losses.margin_ranking_loss(inputs1, inputs2, targets)
|
||||||
|
>>> loss
|
||||||
|
array(0.773433, dtype=float32)
|
||||||
|
"""
|
||||||
|
if not (inputs1.shape == inputs2.shape == targets.shape):
|
||||||
|
raise ValueError(
|
||||||
|
f"The shapes of the arguments do not match. The provided shapes are "
|
||||||
|
f"inputs1.shape={inputs1.shape}, inputs2.shape={inputs2.shape}, and "
|
||||||
|
f"targets.shape={targets.shape}."
|
||||||
|
)
|
||||||
|
|
||||||
|
differences = inputs1 - inputs2
|
||||||
|
loss = mx.maximum(0, -targets * differences + margin)
|
||||||
|
|
||||||
|
return _reduce(loss, reduction)
|
||||||
|
@ -359,6 +359,25 @@ class TestLosses(mlx_tests.MLXTestCase):
|
|||||||
expected_sum = mx.sum(expected_none)
|
expected_sum = mx.sum(expected_none)
|
||||||
self.assertTrue(mx.allclose(losses_sum, expected_sum))
|
self.assertTrue(mx.allclose(losses_sum, expected_sum))
|
||||||
|
|
||||||
|
def test_margin_ranking_loss(self):
|
||||||
|
inputs1 = mx.array([-0.573409, -0.765166, -0.0638])
|
||||||
|
inputs2 = mx.array([0.75596, 0.225763, 0.256995])
|
||||||
|
targets = mx.array([1, 1, -1])
|
||||||
|
|
||||||
|
# Test with no margin
|
||||||
|
losses = nn.losses.margin_ranking_loss(
|
||||||
|
inputs1, inputs2, targets, reduction="none"
|
||||||
|
)
|
||||||
|
expected = mx.array([1.329369, 0.990929, 0.0])
|
||||||
|
self.assertTrue(mx.allclose(losses, expected))
|
||||||
|
|
||||||
|
# Test with margin
|
||||||
|
losses = nn.losses.margin_ranking_loss(
|
||||||
|
inputs1, inputs2, targets, margin=0.5, reduction="none"
|
||||||
|
)
|
||||||
|
expected = mx.array([1.829369, 1.490929, 0.179205])
|
||||||
|
self.assertTrue(mx.allclose(losses, expected))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user