From 8c3da54c7d4e6ffc061ec5985e7c0037d7e72110 Mon Sep 17 00:00:00 2001 From: Vidit Agarwal Date: Sun, 24 Dec 2023 05:56:46 +0530 Subject: [PATCH] Fix failing test for log cosh loss (#275) * fix assert statement in log_cosh_loss * reformatted by pre-commit black --- python/tests/test_nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 0d1c8b2ff..cc56bc430 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -597,7 +597,7 @@ class TestNN(mlx_tests.MLXTestCase): inputs = mx.ones((2, 4)) targets = mx.zeros((2, 4)) loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean") - self.assertEqual(loss, 0.433781) + self.assertAlmostEqual(loss.item(), 0.433781, places=6) if __name__ == "__main__":