From 69505b4e9b4b4623f7abe145d62e2e3c070cd684 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 11 Dec 2023 09:26:49 -0800 Subject: [PATCH] fixes (#131) --- python/mlx/nn/losses.py | 2 +- python/tests/test_nn.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py index c5795574c..b9d35d9b5 100644 --- a/python/mlx/nn/losses.py +++ b/python/mlx/nn/losses.py @@ -203,4 +203,4 @@ def _reduce(loss: mx.array, reduction: str = "none"): elif reduction == "none": return loss else: - raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.") \ No newline at end of file + raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.") diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 54ef9b32c..b93b77e66 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -107,17 +107,17 @@ class TestNN(mlx_tests.MLXTestCase): # Test with reduction 'none' losses_none = nn.losses.binary_cross_entropy(inputs, targets, reduction="none") expected_none = mx.array([[0.693147, 0.693147], [0.693147, 0.693147]]) - self.assertTrue(mx.array_equal(losses_none, expected_none)) + self.assertTrue(mx.allclose(losses_none, expected_none)) # Test with reduction 'mean' losses_mean = nn.losses.binary_cross_entropy(inputs, targets, reduction="mean") expected_mean = mx.mean(expected_none) - self.assertEqual(losses_mean, expected_mean) + self.assertTrue(mx.allclose(losses_mean, expected_mean)) # Test with reduction 'sum' losses_sum = nn.losses.binary_cross_entropy(inputs, targets, reduction="sum") expected_sum = mx.sum(expected_none) - self.assertEqual(losses_sum, expected_sum) + self.assertTrue(mx.allclose(losses_sum, expected_sum)) def test_gelu(self): inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]