Fix compile with non standard types (#745)

* refactor tree utils

* fix compile + tree code refactor

* Add an extra test

* add a few missing activations to docs

* hash structure

* Encode the full argument structure

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-02-26 19:28:53 -08:00
committed by GitHub
parent 08226ab491
commit fe1dabf272
8 changed files with 438 additions and 282 deletions

View File

@@ -539,6 +539,48 @@ class TestCompile(mlx_tests.MLXTestCase):
z = fun(mx.array(1), "two")
self.assertEqual(z.item(), 3)
# Test nested constant
@partial(mx.compile)
def fun(x, y):
if y[0][0] == 1:
return x + 1
else:
return x + 2
z = fun(mx.array(1), [[1]])
self.assertEqual(z.item(), 2)
z = fun(mx.array(1), [[0]])
self.assertEqual(z.item(), 3)
@partial(mx.compile)
def fun(x, a, b):
for ai in a:
for bi in b:
x = bi * x + ai
return x
z = fun(mx.array(1), [1, 1], [2])
self.assertEqual(z.item(), 7)
z = fun(mx.array(1), [1], [1, 2])
self.assertEqual(z.item(), 5)
counter = [0]
@partial(mx.compile)
def fun(x, y):
counter[0] += 1
return x + y
z = fun(mx.array(1), 1)
self.assertEqual(z.item(), 2)
z = fun(1, mx.array(1))
self.assertEqual(z.item(), 2)
self.assertEqual(counter[0], 2)
def test_compile_inf(self):
@mx.compile
@@ -548,6 +590,21 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(mx.array([0.0]))
self.assertEqual(out.item(), False)
def test_unsupported_input_types(self):
class MyClass:
value = 1
@mx.compile
def fun(x, y):
return x + y.value
with self.assertRaises(ValueError):
out = fun(mx.array(0.0), MyClass())
with self.assertRaises(ValueError):
out = fun(mx.array(0.0), y=MyClass())
if __name__ == "__main__":
unittest.main()