Fix logsumexp edge case (#740)

* fix logsumexp

* fix inf constant

* also fix power grad

* fix ternary dispatch
This commit is contained in:
Awni Hannun
2024-02-25 08:39:55 -08:00
committed by GitHub
parent ac02cf33bd
commit e6418781ab
12 changed files with 112 additions and 64 deletions

View File

@@ -415,6 +415,14 @@ class TestAutograd(mlx_tests.MLXTestCase):
_, vjps = mx.vjp(func, (arr,), (cotan,))
self.assertEqual(vjps[0].item(), 8.0)
def test_power_grad(self):
def fun(x, y):
res = x - y
return res**x
grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0))
self.assertEqual(grad.item(), 1.0)
if __name__ == "__main__":
unittest.main()