Faster metal compiled kernels + some fixes (#1486)

* bump mac tests to use py39

* work per thread for compiled kernels

* fix for large arrays

* fix
This commit is contained in:
Awni Hannun
2024-10-14 12:45:38 -07:00
committed by GitHub
parent 0eef4febfd
commit 881615b072
12 changed files with 157 additions and 108 deletions

View File

@@ -758,6 +758,20 @@ class TestCompile(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
y = mx.compile(fn)(x)
def test_compile_dynamic_dims(self):
    """Check that a compiled function matches eager execution on 10-D inputs.

    Both operands have ten axes of size two, and ``a`` is transposed before
    use. NOTE(review): the test name suggests this targets the compiled
    kernel path used for high-rank inputs -- confirm against the backend.
    """
    a = mx.random.uniform(shape=(2,) * 10)
    b = mx.random.uniform(shape=(2,) * 10)
    a = a.T
    # Materialize the inputs before they are handed to the compiled function.
    mx.eval(a, b)

    def fn(a, b):
        return mx.abs(a + b)

    out = mx.compile(fn)(a, b)
    expected = fn(a, b)

    # Report the largest deviation only when the check fails, instead of
    # printing it to stdout on every run.
    max_diff = (out - expected).abs().max().item()
    self.assertTrue(
        mx.allclose(out, expected),
        msg=f"max abs difference: {max_diff}",
    )
# Run the test suite only when this file is executed directly, not on import.
if __name__ == "__main__":
    unittest.main()