mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast * shapeless broadcast * fix build + nits * use broadcast arrays in quantize matmul * some cleanup / consistency * mend * some comments * add vjp, jvp for broadcast axes
This commit is contained in:
@@ -849,6 +849,79 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
compiled_fun(x)
|
||||
|
||||
def test_compile_shapeless_with_broadcast(self):
|
||||
a = mx.array(0.0)
|
||||
b = mx.ones((2, 2))
|
||||
|
||||
def fun(a):
|
||||
return mx.broadcast_to(a, b.shape)
|
||||
|
||||
cfun = mx.compile(fun, shapeless=True)
|
||||
# Works on the first shape
|
||||
cfun(a)
|
||||
|
||||
# Fails on a different shape
|
||||
with self.assertRaises(ValueError):
|
||||
cfun(mx.array(0.0).reshape(1, 1, 1))
|
||||
|
||||
def fun(a, b):
|
||||
return mx.broadcast_arrays(a, b)
|
||||
|
||||
cfun = mx.compile(fun, shapeless=True)
|
||||
a, b = cfun(a, b)
|
||||
self.assertEqual(a.shape, (2, 2))
|
||||
self.assertEqual(b.shape, (2, 2))
|
||||
|
||||
# Batched matmul
|
||||
a = mx.zeros((2, 1, 4, 2))
|
||||
b = mx.zeros((3, 2, 5))
|
||||
|
||||
def fun(a, b):
|
||||
return a @ b
|
||||
|
||||
cfun = mx.compile(fun, shapeless=True)
|
||||
out = cfun(a, b)
|
||||
self.assertEqual(out.shape, (2, 3, 4, 5))
|
||||
|
||||
# Shapeless compile should be preserved over vjp, jvp, vmap
|
||||
def fun(args):
|
||||
return sum(args).sum()
|
||||
|
||||
a = mx.array(0.0)
|
||||
b = mx.ones((2, 2))
|
||||
|
||||
cfun = mx.compile(mx.grad(fun), shapeless=True)
|
||||
out = cfun((a, b))
|
||||
|
||||
self.assertEqual(out[0].shape, ())
|
||||
self.assertEqual(out[1].shape, (2, 2))
|
||||
|
||||
out = cfun((b, a))
|
||||
|
||||
self.assertEqual(out[0].shape, (2, 2))
|
||||
self.assertEqual(out[1].shape, ())
|
||||
|
||||
# Shapeless compile should be preserved over vjp, jvp, vmap
|
||||
def fun(args):
|
||||
return (args[0] @ args[1]).sum()
|
||||
|
||||
a = mx.zeros((2, 1, 4, 2))
|
||||
b = mx.zeros((3, 2, 5))
|
||||
|
||||
cfun = mx.compile(mx.grad(fun), shapeless=True)
|
||||
out = cfun((a, b))
|
||||
|
||||
self.assertEqual(out[0].shape, (2, 1, 4, 2))
|
||||
self.assertEqual(out[1].shape, (3, 2, 5))
|
||||
|
||||
a = mx.zeros((3, 1, 4, 2))
|
||||
b = mx.zeros((2, 2, 5))
|
||||
|
||||
out = cfun((a, b))
|
||||
|
||||
self.assertEqual(out[0].shape, (3, 1, 4, 2))
|
||||
self.assertEqual(out[1].shape, (2, 2, 5))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user