From 4d3c451b3f163cd41b9d2d6ae9a3ed20150b1085 Mon Sep 17 00:00:00 2001 From: vidit Date: Sun, 24 Dec 2023 04:38:03 +0530 Subject: [PATCH] fix assert statement in log_cosh_loss --- python/tests/test_nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 0d1c8b2ff..cc56bc430 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -597,7 +597,7 @@ class TestNN(mlx_tests.MLXTestCase): inputs = mx.ones((2, 4)) targets = mx.zeros((2, 4)) loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean") - self.assertEqual(loss, 0.433781) + self.assertAlmostEqual(loss.item(), 0.433781, places=6) if __name__ == "__main__":