diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 9bfdb6ee9..072948137 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -80,7 +80,8 @@ bool allows_shapeless(const Primitive& p) { typeid(p) == typeid(Partition) || typeid(p) == typeid(Select) || typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) || typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) || - typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) || + typeid(p) == typeid(Reshape) || typeid(p) == typeid(Matmul) || + typeid(p) == typeid(QuantizedMatmul) || typeid(p) == typeid(fast::AffineQuantize) || typeid(p) == typeid(fast::LayerNorm) || typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) || diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 1386bc486..960175700 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4884,19 +4884,19 @@ void init_ops(nb::module_& m) { "dynamic_reshape", &dynamic_reshape, nb::arg(), - "expression"_a, + "expressions"_a, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def dynamic_reshape(a: array, /, expression: Sequence[Union[int, str]], *, stream: " + "def dynamic_reshape(a: array, /, expressions: Sequence[Union[int, str]], *, stream: " "Union[None, Stream, Device] = None) -> array"), R"pbdoc( Dynamically reshape an array based on the given expression. Args: a (array): Input array. - expression (tuple(int or str)): The expression which determines the - output shape. + expressions (tuple(int or str)): The expressions which determine + the output shape. stream (Stream, optional): Stream or device. Defaults to ``None`` in which case the default stream of the default device is used. diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index feb8e6da6..235c552fc 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -809,6 +809,29 @@ class TestCompile(mlx_tests.MLXTestCase): out = fun(*inputs) self.assertTrue(mx.allclose(out, mx.full((2, 2), 20))) + def test_compile_shapeless_with_reshape(self): + def fun(a): + return mx.reshape(a, (4, 7, 4, 2)) + + cfun = mx.compile(fun, shapeless=True) + + a = mx.zeros((4, 7, 8)) + + with self.assertRaises(ValueError): + b = cfun(a) + + def fun(a): + return mx.dynamic_reshape(a, ("B", "L", 4, 2)) + + cfun = mx.compile(fun, shapeless=True) + + b = cfun(a) + self.assertEqual(b.shape, (4, 7, 4, 2)) + + a = mx.zeros((4, 9, 8)) + b = cfun(a) + self.assertEqual(b.shape, (4, 9, 4, 2)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index dcf81d197..43c08262e 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2718,6 +2718,16 @@ class TestOps(mlx_tests.MLXTestCase): a = mx.dynamic_reshape(a, ()) self.assertEqual(a.shape, ()) + a = mx.zeros((4, 4, 4)) + b = mx.dynamic_reshape(a, ("a", "b", "c")) + self.assertEqual(b.shape, (4, 4, 4)) + + b = mx.dynamic_reshape(a, ("a*b", "c")) + self.assertEqual(b.shape, (4 * 4, 4)) + + b = mx.dynamic_reshape(a, ("a*b*c", 1, 1)) + self.assertEqual(b.shape, (4 * 4 * 4, 1, 1)) + if __name__ == "__main__": unittest.main()