mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
fix assert statement in log_cosh_loss
This commit is contained in:
parent
f91f450141
commit
4d3c451b3f
@ -597,7 +597,7 @@ class TestNN(mlx_tests.MLXTestCase):
|
|||||||
inputs = mx.ones((2, 4))
|
inputs = mx.ones((2, 4))
|
||||||
targets = mx.zeros((2, 4))
|
targets = mx.zeros((2, 4))
|
||||||
loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
|
loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
|
||||||
self.assertEqual(loss, 0.433781)
|
self.assertAlmostEqual(loss.item(), 0.433781, places=6)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user