Fix failing test for log cosh loss (#275)

* fix assert statement in log_cosh_loss

* reformatted by pre-commit black
This commit is contained in:
Vidit Agarwal 2023-12-24 05:56:46 +05:30 committed by GitHub
parent acf1721b98
commit 8c3da54c7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -597,7 +597,7 @@ class TestNN(mlx_tests.MLXTestCase):
inputs = mx.ones((2, 4))
targets = mx.zeros((2, 4))
loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
self.assertEqual(loss, 0.433781)
self.assertAlmostEqual(loss.item(), 0.433781, places=6)
if __name__ == "__main__":