Fix logsumexp edge case (#740)

* fix logsumexp

* fix inf constant

* also fix power grad

* fix ternary dispatch
This commit is contained in:
Awni Hannun
2024-02-25 08:39:55 -08:00
committed by GitHub
parent ac02cf33bd
commit e6418781ab
12 changed files with 112 additions and 64 deletions

View File

@@ -539,6 +539,15 @@ class TestCompile(mlx_tests.MLXTestCase):
z = fun(mx.array(1), "two")
self.assertEqual(z.item(), 3)
def test_compile_inf(self):
@mx.compile
def fun(x):
return mx.isinf(x + 2)
out = fun(mx.array([0.0]))
self.assertEqual(out.item(), False)
if __name__ == "__main__":
unittest.main()