From d40854186911ce4ba3c2c73027cb9c95625d7ecf Mon Sep 17 00:00:00 2001 From: "YUN, Junwoo" <61632100+Jyun1998@users.noreply.github.com> Date: Tue, 2 Jan 2024 23:11:46 +0900 Subject: [PATCH] Update test_nn.py cos sim test --- python/tests/test_nn.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 2bab7e8ef..d1eedfa99 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -265,6 +265,32 @@ class TestNN(mlx_tests.MLXTestCase): expected_sum = mx.sum(expected_none) self.assertTrue(mx.allclose(losses_sum, expected_sum)) + def test_cosine_similarity_loss(self): + embeddings1 = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]]) + embeddings2 = mx.array([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]]) + targets = mx.array([1, -1]) + + # Test with reduction 'none' + losses_none = nn.losses.cosine_similarity_loss( + embeddings1, embeddings2, targets, reduction="none" + ) + expected_none = mx.array([0.0146555, 0.961074]) + self.assertTrue(mx.allclose(losses_none, expected_none)) + + # Test with reduction 'mean' + losses_mean = nn.losses.cosine_similarity_loss( + embeddings1, embeddings2, targets, reduction="mean" + ) + expected_mean = mx.mean(expected_none) + self.assertTrue(mx.allclose(losses_mean, expected_mean)) + + # Test with reduction 'sum' + losses_sum = nn.losses.cosine_similarity_loss( + embeddings1, embeddings2, targets, reduction="sum" + ) + expected_sum = mx.sum(expected_none) + self.assertTrue(mx.allclose(losses_sum, expected_sum)) + def test_gelu(self): inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414] @@ -842,4 +868,5 @@ class TestNN(mlx_tests.MLXTestCase): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() +