Add Margin Ranking Loss (#536)

This commit is contained in:
AtomicVar 2024-02-03 02:57:31 +08:00 committed by GitHub
parent cb6156d35d
commit 83f63f2184
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 75 additions and 0 deletions

View File

@ -18,6 +18,7 @@ Loss Functions
   kl_div_loss
   l1_loss
   log_cosh_loss
   margin_ranking_loss
   mse_loss
   nll_loss
   smooth_l1_loss

View File

@ -533,3 +533,58 @@ def cosine_similarity_loss(
    loss = mx.sum(x1 * x2, axis=axis) / mx.maximum(x1_norm * x2_norm, eps)
    return _reduce(loss, reduction)
def margin_ranking_loss(
    inputs1: mx.array,
    inputs2: mx.array,
    targets: mx.array,
    margin: float = 0.0,
    reduction: Reduction = "none",
) -> mx.array:
    r"""
    Calculate the margin ranking loss given inputs :math:`x_1`, :math:`x_2`
    and a label :math:`y` (containing 1 or -1).

    The loss is given by:

    .. math::
        \text{loss} = \max (0, -y * (x_1 - x_2) + \text{margin})

    Where :math:`y` represents ``targets``, :math:`x_1` represents ``inputs1``
    and :math:`x_2` represents ``inputs2``.

    Args:
        inputs1 (array): Scores for the first input.
        inputs2 (array): Scores for the second input.
        targets (array): Labels indicating whether samples in ``inputs1`` should
            be ranked higher than samples in ``inputs2``. Values should be 1 or -1.
        margin (float, optional): The margin by which the scores should be
            separated. Default: ``0.0``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        array: The computed margin ranking loss.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn as nn
        >>> targets = mx.array([1, 1, -1])
        >>> inputs1 = mx.array([-0.573409, -0.765166, -0.0638])
        >>> inputs2 = mx.array([0.75596, 0.225763, 0.256995])
        >>> loss = nn.losses.margin_ranking_loss(inputs1, inputs2, targets)
        >>> loss
        array(0.773433, dtype=float32)
    """
    # The three arrays are paired elementwise, so their shapes must match
    # exactly (no broadcasting is attempted).
    if not (inputs1.shape == inputs2.shape == targets.shape):
        raise ValueError(
            f"The shapes of the arguments do not match. The provided shapes are "
            f"inputs1.shape={inputs1.shape}, inputs2.shape={inputs2.shape}, and "
            f"targets.shape={targets.shape}."
        )
    differences = inputs1 - inputs2
    # Hinge form: loss is zero once the correctly-ordered score gap
    # (sign given by targets) exceeds `margin`.
    loss = mx.maximum(0, -targets * differences + margin)
    return _reduce(loss, reduction)

View File

@ -359,6 +359,25 @@ class TestLosses(mlx_tests.MLXTestCase):
        expected_sum = mx.sum(expected_none)
        self.assertTrue(mx.allclose(losses_sum, expected_sum))
def test_margin_ranking_loss(self):
    """Check margin_ranking_loss elementwise values with and without a margin."""
    scores_a = mx.array([-0.573409, -0.765166, -0.0638])
    scores_b = mx.array([0.75596, 0.225763, 0.256995])
    labels = mx.array([1, 1, -1])

    # Default margin (0.0): loss is max(0, -y * (a - b)).
    got = nn.losses.margin_ranking_loss(
        scores_a, scores_b, labels, reduction="none"
    )
    want = mx.array([1.329369, 0.990929, 0.0])
    self.assertTrue(mx.allclose(got, want))

    # A positive margin shifts the hinge threshold, so every loss grows.
    got = nn.losses.margin_ranking_loss(
        scores_a, scores_b, labels, margin=0.5, reduction="none"
    )
    want = mx.array([1.829369, 1.490929, 0.179205])
    self.assertTrue(mx.allclose(got, want))
if __name__ == "__main__":
    unittest.main()