fix assert statement in log_cosh_loss

This commit is contained in:
vidit 2023-12-24 04:38:03 +05:30
parent f91f450141
commit 4d3c451b3f

View File

@ -597,7 +597,7 @@ class TestNN(mlx_tests.MLXTestCase):
inputs = mx.ones((2, 4)) inputs = mx.ones((2, 4))
targets = mx.zeros((2, 4)) targets = mx.zeros((2, 4))
loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean") loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
self.assertEqual(loss, 0.433781) self.assertAlmostEqual(loss.item(), 0.433781, places=6)
if __name__ == "__main__": if __name__ == "__main__":