From 0b8aeddac6d5dce8a71cf3b2d78699de08c7bc89 Mon Sep 17 00:00:00 2001 From: "YUN, Junwoo" <61632100+Jyun1998@users.noreply.github.com> Date: Tue, 9 Jan 2024 07:01:13 +0900 Subject: [PATCH] Additional losses (#336) * cosine similarity loss --------- Co-authored-by: Awni Hannun * Docstring nits --- docs/src/python/nn/losses.rst | 3 +- python/mlx/nn/losses.py | 58 ++++++++++++++++++++++++++++------- python/tests/test_losses.py | 25 +++++++++++++++ 3 files changed, 74 insertions(+), 12 deletions(-) diff --git a/docs/src/python/nn/losses.rst b/docs/src/python/nn/losses.rst index 3fb7589f8..5a80ba947 100644 --- a/docs/src/python/nn/losses.rst +++ b/docs/src/python/nn/losses.rst @@ -19,4 +19,5 @@ Loss Functions triplet_loss hinge_loss huber_loss - log_cosh_loss \ No newline at end of file + log_cosh_loss + cosine_similarity_loss \ No newline at end of file diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index 91316fd04..2a4c5bd9b 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -6,6 +6,17 @@ import mlx.core as mx from mlx.nn.layers.base import Module +def _reduce(loss: mx.array, reduction: str = "none"): + if reduction == "mean": + return mx.mean(loss) + elif reduction == "sum": + return mx.sum(loss) + elif reduction == "none": + return loss + else: + raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.") + + def cross_entropy( logits: mx.array, targets: mx.array, @@ -272,17 +283,6 @@ def triplet_loss( return _reduce(loss, reduction) -def _reduce(loss: mx.array, reduction: str = "none"): - if reduction == "mean": - return mx.mean(loss) - elif reduction == "sum": - return mx.sum(loss) - elif reduction == "none": - return loss - else: - raise ValueError("Invalid reduction. 
Must be 'none', 'mean', or 'sum'.") - - def hinge_loss( inputs: mx.array, targets: mx.array, reduction: str = "none" ) -> mx.array: @@ -372,3 +372,39 @@ def log_cosh_loss( loss = mx.logaddexp(errors, -errors) - math.log(2) return _reduce(loss, reduction) + + +def cosine_similarity_loss( + x1: mx.array, + x2: mx.array, + axis: int = 1, + eps: float = 1e-8, + reduction: str = "none", +) -> mx.array: + r""" + Computes the cosine similarity between the two inputs. + + The cosine similarity loss is given by + + .. math:: + + \frac{x_1 \cdot x_2}{\max(\|x_1\| \cdot \|x_2\|, \epsilon)} + + Args: + x1 (mx.array): The first set of inputs. + x2 (mx.array): The second set of inputs. + axis (int, optional): The embedding axis. Default: ``1``. + eps (float, optional): The minimum value of the denominator used for + numerical stability. Default: ``1e-8``. + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``. + + Returns: + mx.array: The computed cosine similarity loss. 
+ """ + x1_norm = mx.linalg.norm(x1, axis=axis) + x2_norm = mx.linalg.norm(x2, axis=axis) + + loss = mx.sum(x1 * x2, axis=axis) / mx.maximum(x1_norm * x2_norm, eps) + + return _reduce(loss, reduction) diff --git a/python/tests/test_losses.py b/python/tests/test_losses.py index 706e336c2..63bd5a20e 100644 --- a/python/tests/test_losses.py +++ b/python/tests/test_losses.py @@ -274,6 +274,31 @@ class TestLosses(mlx_tests.MLXTestCase): loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean") self.assertAlmostEqual(loss.item(), 0.433781, places=6) + def test_cosine_similarity_loss(self): + embeddings1 = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]]) + embeddings2 = mx.array([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]]) + + # Test with reduction 'none' + losses_none = nn.losses.cosine_similarity_loss( + embeddings1, embeddings2, reduction="none" + ) + expected_none = mx.array([0.985344, 0.961074]) + self.assertTrue(mx.allclose(losses_none, expected_none)) + + # Test with reduction 'mean' + losses_mean = nn.losses.cosine_similarity_loss( + embeddings1, embeddings2, reduction="mean" + ) + expected_mean = mx.mean(expected_none) + self.assertTrue(mx.allclose(losses_mean, expected_mean)) + + # Test with reduction 'sum' + losses_sum = nn.losses.cosine_similarity_loss( + embeddings1, embeddings2, reduction="sum" + ) + expected_sum = mx.sum(expected_none) + self.assertTrue(mx.allclose(losses_sum, expected_sum)) + if __name__ == "__main__": unittest.main()