Dynamic broadcasting for shapeless compile/export (#1722)

* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
This commit is contained in:
Awni Hannun
2025-01-09 11:04:24 -08:00
committed by GitHub
parent ec36bfa317
commit 1ccaf80575
20 changed files with 471 additions and 163 deletions

View File

@@ -2799,6 +2799,27 @@ void init_ops(nb::module_& m) {
Returns:
array: The output array with the new shape.
)pbdoc");
// Python binding for `broadcast_arrays`: collects every positional argument
// into a std::vector<mx::array> and forwards it, together with the
// keyword-only `stream`, to the C++ `broadcast_arrays`.
m.def(
    "broadcast_arrays",
    [](const nb::args& args, mx::StreamOrDevice s) {
      return broadcast_arrays(nb::cast<std::vector<mx::array>>(args), s);
    },
    nb::arg(),
    nb::kw_only(),
    "stream"_a = nb::none(),
    // The signature must be valid Python: parameters that follow `*arrays`
    // are already keyword-only, and a bare `*` after a var-positional
    // parameter is a SyntaxError, so no extra `*` separator is written here.
    nb::sig(
        "def broadcast_arrays(*arrays: array, stream: Union[None, Stream, Device] = None) -> Tuple[array, ...]"),
    R"pbdoc(
      Broadcast arrays against one another.

      The broadcasting semantics are the same as Numpy.

      Args:
          *arrays (array): The input arrays.

      Returns:
          tuple(array): The output arrays with the broadcasted shape.
    )pbdoc");
m.def(
"softmax",
[](const mx::array& a,
@@ -3853,8 +3874,8 @@ void init_ops(nb::module_& m) {
Args:
file (file, str): Path to file to which the arrays are saved.
args (arrays): Arrays to be saved.
kwargs (arrays): Arrays to be saved. Each array will be saved
*args (arrays): Arrays to be saved.
**kwargs (arrays): Arrays to be saved. Each array will be saved
with the associated keyword as the output file name.
)pbdoc");
m.def(

View File

@@ -849,6 +849,79 @@ class TestCompile(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
compiled_fun(x)
def test_compile_shapeless_with_broadcast(self):
    """Exercise broadcasting ops under ``mx.compile(..., shapeless=True)``.

    Covers ``broadcast_to`` with a fixed target shape, ``broadcast_arrays``,
    batched matmul, and gradients of functions that broadcast their inputs.
    """
    scalar = mx.array(0.0)
    target = mx.ones((2, 2))

    def to_target_shape(x):
        return mx.broadcast_to(x, target.shape)

    compiled = mx.compile(to_target_shape, shapeless=True)

    # The shape seen on the first call is accepted ...
    compiled(scalar)

    # ... but a differently shaped input is rejected.
    with self.assertRaises(ValueError):
        compiled(mx.array(0.0).reshape(1, 1, 1))

    def broadcast_pair(x, y):
        return mx.broadcast_arrays(x, y)

    compiled = mx.compile(broadcast_pair, shapeless=True)
    lhs, rhs = compiled(scalar, target)
    self.assertEqual(lhs.shape, (2, 2))
    self.assertEqual(rhs.shape, (2, 2))

    # Batched matmul: the leading batch dimensions broadcast together.
    left = mx.zeros((2, 1, 4, 2))
    right = mx.zeros((3, 2, 5))

    def matmul(x, y):
        return x @ y

    compiled = mx.compile(matmul, shapeless=True)
    product = compiled(left, right)
    self.assertEqual(product.shape, (2, 3, 4, 5))

    # Gradient of a broadcasting sum: each gradient has the shape of its
    # input, including after the argument order (and so the shapes) swap.
    def summed(args):
        return sum(args).sum()

    scalar = mx.array(0.0)
    square = mx.ones((2, 2))
    compiled_grad = mx.compile(mx.grad(summed), shapeless=True)
    sum_cases = (
        ((scalar, square), ((), (2, 2))),
        ((square, scalar), ((2, 2), ())),
    )
    for inputs, expected_shapes in sum_cases:
        grads = compiled_grad(inputs)
        for index, expected in enumerate(expected_shapes):
            self.assertEqual(grads[index].shape, expected)

    # Gradient of a batched matmul: the same compiled function is reused
    # with a second pair of inputs whose batch sizes differ from the first.
    def matmul_total(args):
        return (args[0] @ args[1]).sum()

    left = mx.zeros((2, 1, 4, 2))
    right = mx.zeros((3, 2, 5))
    compiled_grad = mx.compile(mx.grad(matmul_total), shapeless=True)
    grads = compiled_grad((left, right))
    self.assertEqual(grads[0].shape, (2, 1, 4, 2))
    self.assertEqual(grads[1].shape, (3, 2, 5))

    left = mx.zeros((3, 1, 4, 2))
    right = mx.zeros((2, 2, 5))
    grads = compiled_grad((left, right))
    self.assertEqual(grads[0].shape, (3, 1, 4, 2))
    self.assertEqual(grads[1].shape, (2, 2, 5))
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
unittest.main()

View File

@@ -2782,6 +2782,19 @@ class TestOps(mlx_tests.MLXTestCase):
expected[1:, 2:, 3:] = update
self.assertTrue(mx.array_equal(expected, out))
def test_broadcast_arrays(self):
    """Check output shapes and dtypes of ``mx.broadcast_arrays``."""
    # Two scalars: shapes stay empty and each output keeps its own dtype.
    int_scalar = mx.array(1)
    float_scalar = mx.array(1.0)
    int_out, float_out = mx.broadcast_arrays(int_scalar, float_scalar)
    self.assertEqual(int_out.shape, ())
    self.assertEqual(int_out.dtype, mx.int32)
    self.assertEqual(float_out.shape, ())
    self.assertEqual(float_out.dtype, mx.float32)

    # (3, 1, 2) against (4, 1): both outputs take the shape (3, 4, 2).
    first, second = mx.broadcast_arrays(mx.zeros((3, 1, 2)), mx.zeros((4, 1)))
    for result in (first, second):
        self.assertEqual(result.shape, (3, 4, 2))
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
unittest.main()