diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 5ddd0c0ca..fce5dc661 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -6,7 +6,6 @@ import unittest import mlx.core as mx import mlx.nn as nn - import mlx_tests import numpy as np from mlx.utils import tree_flatten, tree_map, tree_unflatten @@ -197,12 +196,18 @@ class TestNN(mlx_tests.MLXTestCase): delta = 1.0 # Test with reduction 'none' - losses_none = nn.losses.huber_loss(predictions, targets, delta, reduction="none") - expected_none = mx.array([0.125, 0.125, 0.125, 0.125]) # Example expected values + losses_none = nn.losses.huber_loss( + predictions, targets, delta, reduction="none" + ) + expected_none = mx.array( + [0.125, 0.125, 0.125, 0.125] + ) # Example expected values self.assertTrue(mx.allclose(losses_none, expected_none)) # Test with reduction 'mean' - losses_mean = nn.losses.huber_loss(predictions, targets, delta, reduction="mean") + losses_mean = nn.losses.huber_loss( + predictions, targets, delta, reduction="mean" + ) expected_mean = mx.mean(expected_none) self.assertTrue(mx.allclose(losses_mean, expected_mean)) @@ -237,17 +242,28 @@ class TestNN(mlx_tests.MLXTestCase): gamma = 2.0 # Test with reduction 'none' - losses_none = nn.losses.focal_loss(inputs, targets, alpha, gamma, reduction="none") - expected_none = mx.array([[0.0433217, 0.0433217, 0.25751, 0.466273], [0.000263401, 0.147487, 0.0433217, 0.0433217]]) + losses_none = nn.losses.focal_loss( + inputs, targets, alpha, gamma, reduction="none" + ) + expected_none = mx.array( + [ + [0.0433217, 0.0433217, 0.25751, 0.466273], + [0.000263401, 0.147487, 0.0433217, 0.0433217], + ] + ) self.assertTrue(mx.allclose(losses_none, expected_none)) # Test with reduction 'mean' - losses_mean = nn.losses.focal_loss(inputs, targets, alpha, gamma, reduction="mean") + losses_mean = nn.losses.focal_loss( + inputs, targets, alpha, gamma, reduction="mean" + ) expected_mean = mx.mean(expected_none) self.assertTrue(mx.allclose(losses_mean, expected_mean)) # Test with reduction 'sum' - losses_sum = nn.losses.focal_loss(inputs, targets, alpha, gamma, reduction="sum") + losses_sum = nn.losses.focal_loss( + inputs, targets, alpha, gamma, reduction="sum" + ) expected_sum = mx.sum(expected_none) self.assertTrue(mx.allclose(losses_sum, expected_sum)) @@ -258,37 +274,49 @@ class TestNN(mlx_tests.MLXTestCase): margin = 1.0 # Test with reduction 'none' - losses_none = nn.losses.contrastive_loss(embeddings1, embeddings2, targets, margin, reduction="none") + losses_none = nn.losses.contrastive_loss( + embeddings1, embeddings2, targets, margin, reduction="none" + ) expected_none = mx.array([0.2, 0.735425]) self.assertTrue(mx.allclose(losses_none, expected_none)) # Test with reduction 'mean' - losses_mean = nn.losses.contrastive_loss(embeddings1, embeddings2, targets, margin, reduction="mean") + losses_mean = nn.losses.contrastive_loss( + embeddings1, embeddings2, targets, margin, reduction="mean" + ) expected_mean = mx.mean(expected_none) self.assertTrue(mx.allclose(losses_mean, expected_mean)) # Test with reduction 'sum' - losses_sum = nn.losses.contrastive_loss(embeddings1, embeddings2, targets, margin, reduction="sum") + losses_sum = nn.losses.contrastive_loss( + embeddings1, embeddings2, targets, margin, reduction="sum" + ) expected_sum = mx.sum(expected_none) self.assertTrue(mx.allclose(losses_sum, expected_sum)) def test_cosine_similarity_loss(self): embeddings1 = mx.array([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]]) embeddings2 = mx.array([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]]) - targets = mx.array([1, -1]) + targets = mx.array([1, -1]) # Test with reduction 'none' - losses_none = nn.losses.cosine_similarity_loss(embeddings1, embeddings2, targets, reduction="none") + losses_none = nn.losses.cosine_similarity_loss( + embeddings1, embeddings2, targets, reduction="none" + ) expected_none = mx.array([0.0146555, 0.961074]) self.assertTrue(mx.allclose(losses_none, expected_none)) # Test with reduction 'mean' - losses_mean = nn.losses.cosine_similarity_loss(embeddings1, embeddings2, targets, reduction="mean") + losses_mean = nn.losses.cosine_similarity_loss( + embeddings1, embeddings2, targets, reduction="mean" + ) expected_mean = mx.mean(expected_none) self.assertTrue(mx.allclose(losses_mean, expected_mean)) # Test with reduction 'sum' - losses_sum = nn.losses.cosine_similarity_loss(embeddings1, embeddings2, targets, reduction="sum") + losses_sum = nn.losses.cosine_similarity_loss( + embeddings1, embeddings2, targets, reduction="sum" + ) expected_sum = mx.sum(expected_none) self.assertTrue(mx.allclose(losses_sum, expected_sum))