mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-02 09:18:11 +08:00
Fix compile with non standard types (#745)
* refactor tree utils * fix compile + tree code refactor * Add an extra test * add a few missing activations to docs * hash structure * Encode the full argument structure --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -539,6 +539,48 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
z = fun(mx.array(1), "two")
|
||||
self.assertEqual(z.item(), 3)
|
||||
|
||||
# Test nested constant
|
||||
@partial(mx.compile)
|
||||
def fun(x, y):
|
||||
if y[0][0] == 1:
|
||||
return x + 1
|
||||
else:
|
||||
return x + 2
|
||||
|
||||
z = fun(mx.array(1), [[1]])
|
||||
self.assertEqual(z.item(), 2)
|
||||
|
||||
z = fun(mx.array(1), [[0]])
|
||||
self.assertEqual(z.item(), 3)
|
||||
|
||||
@partial(mx.compile)
|
||||
def fun(x, a, b):
|
||||
for ai in a:
|
||||
for bi in b:
|
||||
x = bi * x + ai
|
||||
return x
|
||||
|
||||
z = fun(mx.array(1), [1, 1], [2])
|
||||
self.assertEqual(z.item(), 7)
|
||||
|
||||
z = fun(mx.array(1), [1], [1, 2])
|
||||
self.assertEqual(z.item(), 5)
|
||||
|
||||
counter = [0]
|
||||
|
||||
@partial(mx.compile)
|
||||
def fun(x, y):
|
||||
counter[0] += 1
|
||||
return x + y
|
||||
|
||||
z = fun(mx.array(1), 1)
|
||||
self.assertEqual(z.item(), 2)
|
||||
|
||||
z = fun(1, mx.array(1))
|
||||
self.assertEqual(z.item(), 2)
|
||||
|
||||
self.assertEqual(counter[0], 2)
|
||||
|
||||
def test_compile_inf(self):
|
||||
|
||||
@mx.compile
|
||||
@@ -548,6 +590,21 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
out = fun(mx.array([0.0]))
|
||||
self.assertEqual(out.item(), False)
|
||||
|
||||
def test_unsupported_input_types(self):
|
||||
|
||||
class MyClass:
|
||||
value = 1
|
||||
|
||||
@mx.compile
|
||||
def fun(x, y):
|
||||
return x + y.value
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
out = fun(mx.array(0.0), MyClass())
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
out = fun(mx.array(0.0), y=MyClass())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user