Dynamic broadcasting for shapeless compile/export (#1722)

* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
This commit is contained in:
Awni Hannun
2025-01-09 11:04:24 -08:00
committed by GitHub
parent ec36bfa317
commit 1ccaf80575
20 changed files with 471 additions and 163 deletions

View File

@@ -849,6 +849,79 @@ class TestCompile(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
compiled_fun(x)
def test_compile_shapeless_with_broadcast(self):
    """Exercise broadcasting ops under ``mx.compile(..., shapeless=True)``.

    Covers ``broadcast_to``, ``broadcast_arrays``, batched matmul, and
    gradients (via ``mx.grad``) of functions whose inputs are broadcast.
    """
    scalar = mx.array(0.0)
    target = mx.ones((2, 2))

    def broadcast_to_target(x):
        # The target shape is captured from the enclosing scope.
        return mx.broadcast_to(x, target.shape)

    compiled = mx.compile(broadcast_to_target, shapeless=True)

    # The call with the initial input shape succeeds.
    compiled(scalar)

    # A call with a differently shaped input is expected to raise.
    with self.assertRaises(ValueError):
        compiled(mx.array(0.0).reshape(1, 1, 1))

    def broadcast_pair(x, y):
        return mx.broadcast_arrays(x, y)

    compiled = mx.compile(broadcast_pair, shapeless=True)
    first, second = compiled(scalar, target)
    self.assertEqual(first.shape, (2, 2))
    self.assertEqual(second.shape, (2, 2))

    # Batched matmul: the leading (batch) dimensions broadcast together.
    lhs = mx.zeros((2, 1, 4, 2))
    rhs = mx.zeros((3, 2, 5))

    def matmul(x, y):
        return x @ y

    compiled = mx.compile(matmul, shapeless=True)
    product = compiled(lhs, rhs)
    self.assertEqual(product.shape, (2, 3, 4, 5))

    # Shapeless compile of a gradient function: every gradient must have
    # the shape of its corresponding input, whichever input is the scalar.
    def summed_total(args):
        return sum(args).sum()

    scalar = mx.array(0.0)
    matrix = mx.ones((2, 2))
    compiled_grad = mx.compile(mx.grad(summed_total), shapeless=True)

    grads = compiled_grad((scalar, matrix))
    self.assertEqual(grads[0].shape, ())
    self.assertEqual(grads[1].shape, (2, 2))

    grads = compiled_grad((matrix, scalar))
    self.assertEqual(grads[0].shape, (2, 2))
    self.assertEqual(grads[1].shape, ())

    # Same check for the gradient of a batched matmul, called a second
    # time with different batch sizes.
    def matmul_total(args):
        x, y = args
        return (x @ y).sum()

    lhs = mx.zeros((2, 1, 4, 2))
    rhs = mx.zeros((3, 2, 5))
    compiled_grad = mx.compile(mx.grad(matmul_total), shapeless=True)

    grads = compiled_grad((lhs, rhs))
    self.assertEqual(grads[0].shape, (2, 1, 4, 2))
    self.assertEqual(grads[1].shape, (3, 2, 5))

    lhs = mx.zeros((3, 1, 4, 2))
    rhs = mx.zeros((2, 2, 5))
    grads = compiled_grad((lhs, rhs))
    self.assertEqual(grads[0].shape, (3, 1, 4, 2))
    self.assertEqual(grads[1].shape, (2, 2, 5))
# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()