From 83f63f2184e88e79df911ebfd02b95da0bee5c9f Mon Sep 17 00:00:00 2001 From: AtomicVar Date: Sat, 3 Feb 2024 02:57:31 +0800 Subject: [PATCH] Add Margin Ranking Loss (#536) --- docs/src/python/nn/losses.rst | 1 + python/mlx/nn/losses.py | 55 +++++++++++++++++++++++++++++++++++ python/tests/test_losses.py | 19 ++++++++++++ 3 files changed, 75 insertions(+) diff --git a/docs/src/python/nn/losses.rst b/docs/src/python/nn/losses.rst index 6c4327eb8..6a2e128c5 100644 --- a/docs/src/python/nn/losses.rst +++ b/docs/src/python/nn/losses.rst @@ -18,6 +18,7 @@ Loss Functions kl_div_loss l1_loss log_cosh_loss + margin_ranking_loss mse_loss nll_loss smooth_l1_loss diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index ae64ab3ac..a466c10ed 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -533,3 +533,58 @@ def cosine_similarity_loss( loss = mx.sum(x1 * x2, axis=axis) / mx.maximum(x1_norm * x2_norm, eps) return _reduce(loss, reduction) + + +def margin_ranking_loss( + inputs1: mx.array, + inputs2: mx.array, + targets: mx.array, + margin: float = 0.0, + reduction: Reduction = "none", +) -> mx.array: + r""" + Calculate the margin ranking loss given inputs :math:`x_1`, :math:`x_2` and a label + :math:`y` (containing 1 or -1). + + The loss is given by: + + .. math:: + \text{loss} = \max (0, -y * (x_1 - x_2) + \text{margin}) + + Where :math:`y` represents ``targets``, :math:`x_1` represents ``inputs1`` and :math:`x_2` + represents ``inputs2``. + + Args: + inputs1 (array): Scores for the first input. + inputs2 (array): Scores for the second input. + targets (array): Labels indicating whether samples in ``inputs1`` should be ranked higher + than samples in ``inputs2``. Values should be 1 or -1. + margin (float, optional): The margin by which the scores should be separated. + Default: ``0.0``. + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``. 
+ + Returns: + array: The computed margin ranking loss. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn as nn + >>> targets = mx.array([1, 1, -1]) + >>> inputs1 = mx.array([-0.573409, -0.765166, -0.0638]) + >>> inputs2 = mx.array([0.75596, 0.225763, 0.256995]) + >>> loss = nn.losses.margin_ranking_loss(inputs1, inputs2, targets, reduction="mean") + >>> loss + array(0.773433, dtype=float32) + """ + if not (inputs1.shape == inputs2.shape == targets.shape): + raise ValueError( + f"The shapes of the arguments do not match. The provided shapes are " + f"inputs1.shape={inputs1.shape}, inputs2.shape={inputs2.shape}, and " + f"targets.shape={targets.shape}." + ) + + differences = inputs1 - inputs2 + loss = mx.maximum(0, -targets * differences + margin) + + return _reduce(loss, reduction) diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py index c6db19983..2160b0a6e 100644 --- a/python/tests/test_losses.py +++ b/python/tests/test_losses.py @@ -359,6 +359,25 @@ class TestLosses(mlx_tests.MLXTestCase): expected_sum = mx.sum(expected_none) self.assertTrue(mx.allclose(losses_sum, expected_sum)) + def test_margin_ranking_loss(self): + inputs1 = mx.array([-0.573409, -0.765166, -0.0638]) + inputs2 = mx.array([0.75596, 0.225763, 0.256995]) + targets = mx.array([1, 1, -1]) + + # Test with no margin + losses = nn.losses.margin_ranking_loss( + inputs1, inputs2, targets, reduction="none" + ) + expected = mx.array([1.329369, 0.990929, 0.0]) + self.assertTrue(mx.allclose(losses, expected)) + + # Test with margin + losses = nn.losses.margin_ranking_loss( + inputs1, inputs2, targets, margin=0.5, reduction="none" + ) + expected = mx.array([1.829369, 1.490929, 0.179205]) + self.assertTrue(mx.allclose(losses, expected)) + if __name__ == "__main__": unittest.main()