Fix cpu compile (#934)

* fix one cpu bug, test for another

* format hooks

* simplify contiguity check for cpu compile

* fix

* add back donation

* comment
This commit is contained in:
Awni Hannun
2024-04-01 17:37:12 -07:00
committed by GitHub
parent 639e06e1f3
commit 2427fa171e
5 changed files with 157 additions and 106 deletions

View File

@@ -671,6 +671,26 @@ class TestCompile(mlx_tests.MLXTestCase):
out = cmean(x)
self.assertTrue(mx.allclose(out, mean(x)))
def test_compile_broadcast_only(self):
def fn(a):
a = mx.broadcast_to(a, (1,))
return a + a
out = mx.compile(fn)(mx.array(2.0))
# Make sure repr can be called
self.assertTrue(repr(out) is not None)
self.assertTrue(mx.array_equal(out, mx.array([4.0])))
def test_compile_with_long_name(self):
def fn(a, b):
for _ in range(10):
a = a - 1.0
b = b - 1.0
return a + b
out = mx.compile(fn)(mx.array(10.0), mx.array(20.0))
self.assertEqual(out.item(), 10.0)
if __name__ == "__main__":
unittest.main()