diff --git a/docs/src/python/nn/losses.rst b/docs/src/python/nn/losses.rst index 4c99ff15c..b6a202d4a 100644 --- a/docs/src/python/nn/losses.rst +++ b/docs/src/python/nn/losses.rst @@ -16,3 +16,4 @@ Loss Functions mse_loss nll_loss smooth_l1_loss + triplet_loss \ No newline at end of file diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index 3b0f31ce1..755656e4f 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -232,6 +232,48 @@ def smooth_l1_loss( return _reduce(loss, reduction) +def triplet_loss( + anchors: mx.array, + positives: mx.array, + negatives: mx.array, + axis: int = -1, + p: int = 2, + margin: float = 1.0, + eps: float = 1e-6, + reduction: str = "none", +) -> mx.array: + r""" + Computes the triplet loss for a set of anchor, positive, and negative samples. + Margin is represented with alpha in the math section. + + .. math:: + + L_{\text{triplet}} = \max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right) + + Args: + anchors (array): The anchor samples. + positives (array): The positive samples. + negatives (array): The negative samples. + axis (int, optional): The distribution axis. Default: ``-1``. + p (int, optional): The norm degree for pairwise distance. Default: ``2``. + margin (float, optional): Margin for the triplet loss. Defaults to ``1.0``. + eps (float, optional): Small positive constant to prevent numerical instability. Defaults to ``1e-6``. + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``. + + Returns: + array: Computed triplet loss. If reduction is "none", returns a tensor of the same shape as input; + if reduction is "mean" or "sum", returns a scalar tensor. + """ + loss = mx.maximum( + mx.sqrt(mx.power(anchors - positives, p).sum(axis) + eps) + - mx.sqrt(mx.power(anchors - negatives, p).sum(axis) + eps) + + margin, + 0, + ) + return _reduce(loss, reduction) + + def _reduce(loss: mx.array, reduction: str = "none"): if reduction == "mean": return mx.mean(loss) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index d93aa3cb2..17ec5175c 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -239,6 +239,32 @@ class TestNN(mlx_tests.MLXTestCase): expected_sum = mx.sum(expected_none) self.assertTrue(mx.allclose(losses_sum, expected_sum)) + def test_triplet_loss(self): + anchors = mx.array([[1, 2, 3], [1, 2, 3]]) + positives = mx.array([[4, 5, 6], [0, -1, 2]]) + negatives = mx.array([[7, 8, 9], [3, 2, 3]]) + + # Test with reduction 'none' + losses_none = nn.losses.triplet_loss( + anchors, positives, negatives, reduction="none" + ) + expected_none = mx.array([0, 2.31662]) + self.assertTrue(mx.allclose(losses_none, expected_none)) + + # Test with reduction 'mean' + losses_mean = nn.losses.triplet_loss( + anchors, positives, negatives, reduction="mean" + ) + expected_mean = mx.mean(expected_none) + self.assertTrue(mx.allclose(losses_mean, expected_mean)) + + # Test with reduction 'sum' + losses_sum = nn.losses.triplet_loss( + anchors, positives, negatives, reduction="sum" + ) + expected_sum = mx.sum(expected_none) + self.assertTrue(mx.allclose(losses_sum, expected_sum)) + def test_gelu(self): inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]