mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Shapeless compilation for some graphs (#687)
* shapeless compilation for some graphs * update compile benchmark * default compile a few activations * buffer donation * bugfix * shapeless fix * update tests to work for cpu and gpu fusion * test kwargs * add kwargs to compile * Recompile when python arguments change * no compile for tanh * some constant tests --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -381,6 +381,164 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertFalse(mx.allclose(fun(), fun(), 1e-2, 1e-2))
|
||||
|
||||
def test_compile_kwargs(self):
|
||||
|
||||
@mx.compile
|
||||
def fun(x, y, z):
|
||||
return x + y + z
|
||||
|
||||
x = mx.array(1)
|
||||
y = mx.array(2)
|
||||
z = mx.array(3)
|
||||
out = fun(x, y=y, z=z)
|
||||
self.assertEqual(out.item(), 6)
|
||||
|
||||
def test_shapeless_compile(self):
|
||||
y = 1
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def fun(x):
|
||||
return x + y
|
||||
|
||||
x = mx.array([1, 2])
|
||||
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))
|
||||
|
||||
# The function is not recompiled, so the change
|
||||
# to y should not be reflected in the output
|
||||
y = 2
|
||||
x = mx.array([1, 2, 3])
|
||||
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))
|
||||
|
||||
# Type change recompiles
|
||||
x = mx.array([1.0, 2.0, 3.0])
|
||||
self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))
|
||||
fun(x, y=y, z=z)
|
||||
|
||||
def test_shapeless_compile(self):
|
||||
y = 1
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def fun(x):
|
||||
return x + y
|
||||
|
||||
x = mx.array([1, 2])
|
||||
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))
|
||||
|
||||
# The function is not recompiled, so the change
|
||||
# to y should not be reflected in the output
|
||||
y = 2
|
||||
x = mx.array([1, 2, 3])
|
||||
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))
|
||||
|
||||
# Type change recompiles
|
||||
x = mx.array([1.0, 2.0, 3.0])
|
||||
self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))
|
||||
|
||||
# Dim change recompiles
|
||||
x = mx.array([[1, 2, 3]])
|
||||
self.assertTrue(mx.array_equal(fun(x), mx.array([[3, 4, 5]])))
|
||||
|
||||
def test_shapeless_compile_with_broadcasts(self):
|
||||
x = mx.ones((2, 2))
|
||||
y = mx.array([2, 2])
|
||||
|
||||
def fun(x, y):
|
||||
return x * y
|
||||
|
||||
cfun = mx.compile(fun, shapeless=True)
|
||||
self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))
|
||||
self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))
|
||||
y = mx.array([[3]])
|
||||
self.assertTrue(mx.array_equal(cfun(x, y), fun(x, y)))
|
||||
self.assertTrue(mx.array_equal(cfun(y, x), fun(y, x)))
|
||||
|
||||
def test_shapeless_compile_with_reduction(self):
|
||||
# Test shapeless compile with a reduction
|
||||
z = 1
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def fun(x, y):
|
||||
return x + y.sum(0, keepdims=True) + z
|
||||
|
||||
x = mx.ones((2, 2), mx.int32)
|
||||
y = mx.ones((2, 2), mx.int32)
|
||||
self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(2, 2), vals=4)))
|
||||
x = mx.ones((3, 3), mx.int32)
|
||||
y = mx.ones((3, 3), mx.int32)
|
||||
z = 2
|
||||
self.assertTrue(mx.array_equal(fun(x, y), mx.full(shape=(3, 3), vals=5)))
|
||||
|
||||
x1 = mx.array([[1, 2], [3, 4], [5, 6]])
|
||||
x2 = mx.array([[1, 2]])
|
||||
|
||||
def fun(x):
|
||||
return x * x.sum(-1, keepdims=True)
|
||||
|
||||
cfun = mx.compile(fun, shapeless=True)
|
||||
mx.eval(cfun(x1))
|
||||
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
|
||||
|
||||
def test_compile_with_constant(self):
|
||||
|
||||
# Test float
|
||||
@partial(mx.compile)
|
||||
def fun(x, y):
|
||||
return x + y
|
||||
|
||||
z = fun(mx.array(1.0), 1.0)
|
||||
self.assertEqual(z.item(), 2.0)
|
||||
|
||||
z = fun(mx.array(1.0), 2.0)
|
||||
self.assertEqual(z.item(), 3.0)
|
||||
|
||||
z = fun(mx.array(1.0), y=1.0)
|
||||
self.assertEqual(z.item(), 2.0)
|
||||
|
||||
z = fun(mx.array(1.0), y=3.0)
|
||||
self.assertEqual(z.item(), 4.0)
|
||||
|
||||
# Test tuple
|
||||
@partial(mx.compile)
|
||||
def fun(x, y=(1, 2)):
|
||||
return x + y[0] + y[1]
|
||||
|
||||
z = fun(mx.array(1))
|
||||
self.assertEqual(z.item(), 4)
|
||||
|
||||
z = fun(mx.array(1), (2, 2))
|
||||
self.assertEqual(z.item(), 5)
|
||||
|
||||
z = fun(mx.array(1), (2, 1))
|
||||
self.assertEqual(z.item(), 4)
|
||||
|
||||
# Test bool
|
||||
@partial(mx.compile)
|
||||
def fun(x, y):
|
||||
if y:
|
||||
return x + 1
|
||||
else:
|
||||
return x + 2
|
||||
|
||||
z = fun(mx.array(1), True)
|
||||
self.assertEqual(z.item(), 2)
|
||||
|
||||
z = fun(mx.array(1), False)
|
||||
self.assertEqual(z.item(), 3)
|
||||
|
||||
# Test string
|
||||
@partial(mx.compile)
|
||||
def fun(x, y):
|
||||
if y == "one":
|
||||
return x + 1
|
||||
else:
|
||||
return x + 2
|
||||
|
||||
z = fun(mx.array(1), "one")
|
||||
self.assertEqual(z.item(), 2)
|
||||
|
||||
z = fun(mx.array(1), "two")
|
||||
self.assertEqual(z.item(), 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user