try dynamic reshape

This commit is contained in:
Awni Hannun
2024-12-06 12:09:08 -08:00
parent 40c62c1321
commit ee59d50293
7 changed files with 164 additions and 0 deletions

View File

@@ -4880,4 +4880,27 @@ void init_ops(nb::module_& m) {
Returns:
array: The imaginary part of ``a``.
)pbdoc");
m.def(
"dynamic_reshape",
&dynamic_reshape,
nb::arg(),
"expression"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def dynamic_reshape(a: array, /, expression: Sequence[Union[int, str]], *, stream: "
"Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Dynamically reshape an array based on the given expression.
Args:
a (array): Input array.
expression (tuple(int or str)): The expression which determines the
output shape.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The reshaped array.
)pbdoc");
}

View File

@@ -2713,6 +2713,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.imag(z).dtype, mx.float32)
self.assertTrue(mx.array_equal(mx.imag(z), y))
def test_dynamic_reshape(self):
a = mx.array(1)[None, None]
a = mx.dynamic_reshape(a, ())
self.assertEqual(a.shape, ())
if __name__ == "__main__":
unittest.main()