binding + tests

This commit is contained in:
Awni Hannun 2024-12-06 20:05:00 -08:00
parent 2b9c24c517
commit 0c1155faf5
4 changed files with 39 additions and 5 deletions

View File

@ -80,7 +80,8 @@ bool allows_shapeless(const Primitive& p) {
typeid(p) == typeid(Partition) || typeid(p) == typeid(Select) ||
typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) ||
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
typeid(p) == typeid(Reshape) || typeid(p) == typeid(Matmul) ||
typeid(p) == typeid(QuantizedMatmul) ||
typeid(p) == typeid(fast::AffineQuantize) ||
typeid(p) == typeid(fast::LayerNorm) ||
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||

View File

@ -4884,19 +4884,19 @@ void init_ops(nb::module_& m) {
"dynamic_reshape",
&dynamic_reshape,
nb::arg(),
"expression"_a,
"expressions"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def dynamic_reshape(a: array, /, expression: Sequence[Union[int, str]], *, stream: "
"def dynamic_reshape(a: array, /, expressions: Sequence[Union[int, str]], *, stream: "
"Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Dynamically reshape an array based on the given expression.
Args:
a (array): Input array.
expression (tuple(int or str)): The expression which determines the
output shape.
expressions (tuple(int or str)): The expressions which determine
the output shape.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.

View File

@ -809,6 +809,29 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun(*inputs)
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
def test_compile_shapeless_with_reshape(self):
def fun(a):
return mx.reshape(a, (4, 7, 4, 2))
cfun = mx.compile(fun, shapeless=True)
a = mx.zeros((4, 7, 8))
with self.assertRaises(ValueError):
b = cfun(a)
def fun(a):
return mx.dynamic_reshape(a, ("B", "L", 4, 2))
cfun = mx.compile(fun, shapeless=True)
b = cfun(a)
self.assertEqual(b.shape, (4, 7, 4, 2))
a = mx.zeros((4, 9, 8))
b = cfun(a)
self.assertEqual(b.shape, (4, 9, 4, 2))
if __name__ == "__main__":
unittest.main()

View File

@ -2718,6 +2718,16 @@ class TestOps(mlx_tests.MLXTestCase):
a = mx.dynamic_reshape(a, ())
self.assertEqual(a.shape, ())
a = mx.zeros((4, 4, 4))
b = mx.dynamic_reshape(a, ("a", "b", "c"))
self.assertEqual(b.shape, (4, 4, 4))
b = mx.dynamic_reshape(a, ("a*b", "c"))
self.assertEqual(b.shape, (4 * 4, 4))
b = mx.dynamic_reshape(a, ("a*b*c", 1, 1))
self.assertEqual(b.shape, (4 * 4 * 4, 1, 1))
if __name__ == "__main__":
unittest.main()