mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Fix cpu compile (#934)
* fix one cpu bug, test for another * format hooks * simplify contiguity check for cpu compile * fix * add back donation * comment
This commit is contained in:
@@ -671,6 +671,26 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
out = cmean(x)
|
||||
self.assertTrue(mx.allclose(out, mean(x)))
|
||||
|
||||
def test_compile_broadcast_only(self):
|
||||
def fn(a):
|
||||
a = mx.broadcast_to(a, (1,))
|
||||
return a + a
|
||||
|
||||
out = mx.compile(fn)(mx.array(2.0))
|
||||
# Make sure repr can be called
|
||||
self.assertTrue(repr(out) is not None)
|
||||
self.assertTrue(mx.array_equal(out, mx.array([4.0])))
|
||||
|
||||
def test_compile_with_long_name(self):
|
||||
def fn(a, b):
|
||||
for _ in range(10):
|
||||
a = a - 1.0
|
||||
b = b - 1.0
|
||||
return a + b
|
||||
|
||||
out = mx.compile(fn)(mx.array(10.0), mx.array(20.0))
|
||||
self.assertEqual(out.item(), 10.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user