mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast
* shapeless broadcast
* fix build + nits
* use broadcast arrays in quantize matmul
* some cleanup / consistency
* mend
* some comments
* add vjp, jvp for broadcast axes
This commit is contained in:
@@ -2799,6 +2799,27 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The output array with the new shape.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"broadcast_arrays",
|
||||
[](const nb::args& args, mx::StreamOrDevice s) {
|
||||
return broadcast_arrays(nb::cast<std::vector<mx::array>>(args), s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def broadcast_arrays(*arrays: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, ...]"),
|
||||
R"pbdoc(
|
||||
Broadcast arrays against one another.
|
||||
|
||||
The broadcasting semantics are the same as Numpy.
|
||||
|
||||
Args:
|
||||
*arrays (array): The input arrays.
|
||||
|
||||
Returns:
|
||||
tuple(array): The output arrays with the broadcasted shape.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"softmax",
|
||||
[](const mx::array& a,
|
||||
@@ -3853,8 +3874,8 @@ void init_ops(nb::module_& m) {
|
||||
|
||||
Args:
|
||||
file (file, str): Path to file to which the arrays are saved.
|
||||
args (arrays): Arrays to be saved.
|
||||
kwargs (arrays): Arrays to be saved. Each array will be saved
|
||||
*args (arrays): Arrays to be saved.
|
||||
**kwargs (arrays): Arrays to be saved. Each array will be saved
|
||||
with the associated keyword as the output file name.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
|
@@ -849,6 +849,79 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
compiled_fun(x)
|
||||
|
||||
def test_compile_shapeless_with_broadcast(self):
    """Exercise shapeless compilation of functions that broadcast.

    Covers an explicit ``broadcast_to``, ``broadcast_arrays``, a batched
    matmul, and gradients of a broadcasting sum and of a batched matmul.
    """
    scalar = mx.array(0.0)
    square = mx.ones((2, 2))

    def to_square(x):
        return mx.broadcast_to(x, square.shape)

    compiled = mx.compile(to_square, shapeless=True)

    # The first call, with the original input shape, is expected to work.
    compiled(scalar)

    # An input with a different shape is expected to raise.
    with self.assertRaises(ValueError):
        compiled(mx.array(0.0).reshape(1, 1, 1))

    def broadcast_pair(x, y):
        return mx.broadcast_arrays(x, y)

    compiled = mx.compile(broadcast_pair, shapeless=True)
    first, second = compiled(scalar, square)
    self.assertEqual(first.shape, (2, 2))
    self.assertEqual(second.shape, (2, 2))

    # Batched matmul
    lhs = mx.zeros((2, 1, 4, 2))
    rhs = mx.zeros((3, 2, 5))

    def matmul(x, y):
        return x @ y

    compiled = mx.compile(matmul, shapeless=True)
    product = compiled(lhs, rhs)
    self.assertEqual(product.shape, (2, 3, 4, 5))

    # Gradient of a broadcasting sum under shapeless compile: each gradient
    # must have the shape of its own input, in either argument order.
    def total(args):
        return sum(args).sum()

    scalar = mx.array(0.0)
    square = mx.ones((2, 2))

    grad_total = mx.compile(mx.grad(total), shapeless=True)
    sum_cases = (
        ((scalar, square), ((), (2, 2))),
        ((square, scalar), ((2, 2), ())),
    )
    for inputs, (shape0, shape1) in sum_cases:
        grads = grad_total(inputs)
        self.assertEqual(grads[0].shape, shape0)
        self.assertEqual(grads[1].shape, shape1)

    # Gradient of a batched matmul under shapeless compile, evaluated on two
    # different sets of batch dimensions with the same compiled function.
    def matmul_total(args):
        return (args[0] @ args[1]).sum()

    grad_matmul = mx.compile(mx.grad(matmul_total), shapeless=True)
    matmul_cases = (
        ((2, 1, 4, 2), (3, 2, 5)),
        ((3, 1, 4, 2), (2, 2, 5)),
    )
    for lhs_shape, rhs_shape in matmul_cases:
        grads = grad_matmul((mx.zeros(lhs_shape), mx.zeros(rhs_shape)))
        self.assertEqual(grads[0].shape, lhs_shape)
        self.assertEqual(grads[1].shape, rhs_shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -2782,6 +2782,19 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
expected[1:, 2:, 3:] = update
|
||||
self.assertTrue(mx.array_equal(expected, out))
|
||||
|
||||
def test_broadcast_arrays(self):
    """Check output shapes and dtypes of ``mx.broadcast_arrays``."""
    # Two scalars: shapes stay empty and each output keeps its own dtype.
    int_out, float_out = mx.broadcast_arrays(mx.array(1), mx.array(1.0))
    for out, dtype in ((int_out, mx.int32), (float_out, mx.float32)):
        self.assertEqual(out.shape, ())
        self.assertEqual(out.dtype, dtype)

    # Inputs of different rank: both outputs take the common shape.
    wide, narrow = mx.broadcast_arrays(mx.zeros((3, 1, 2)), mx.zeros((4, 1)))
    for out in (wide, narrow):
        self.assertEqual(out.shape, (3, 4, 2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user