diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 0d1c8b2ff..cc56bc430 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -597,7 +597,7 @@ class TestNN(mlx_tests.MLXTestCase): inputs = mx.ones((2, 4)) targets = mx.zeros((2, 4)) loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean") - self.assertEqual(loss, 0.433781) + self.assertAlmostEqual(loss.item(), 0.433781, places=6) if __name__ == "__main__":