mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
binding + tests
This commit is contained in:
parent
2b9c24c517
commit
0c1155faf5
@ -80,7 +80,8 @@ bool allows_shapeless(const Primitive& p) {
|
||||
typeid(p) == typeid(Partition) || typeid(p) == typeid(Select) ||
|
||||
typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) ||
|
||||
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
|
||||
typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
|
||||
typeid(p) == typeid(Reshape) || typeid(p) == typeid(Matmul) ||
|
||||
typeid(p) == typeid(QuantizedMatmul) ||
|
||||
typeid(p) == typeid(fast::AffineQuantize) ||
|
||||
typeid(p) == typeid(fast::LayerNorm) ||
|
||||
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||
|
||||
|
@ -4884,19 +4884,19 @@ void init_ops(nb::module_& m) {
|
||||
"dynamic_reshape",
|
||||
&dynamic_reshape,
|
||||
nb::arg(),
|
||||
"expression"_a,
|
||||
"expressions"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def dynamic_reshape(a: array, /, expression: Sequence[Union[int, str]], *, stream: "
|
||||
"def dynamic_reshape(a: array, /, expressions: Sequence[Union[int, str]], *, stream: "
|
||||
"Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Dynamically reshape an array based on the given expression.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
expression (tuple(int or str)): The expression which determines the
|
||||
output shape.
|
||||
expressions (tuple(int or str)): The expressions which determine
|
||||
the output shape.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
|
@ -809,6 +809,29 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
out = fun(*inputs)
|
||||
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
|
||||
|
||||
def test_compile_shapeless_with_reshape(self):
|
||||
def fun(a):
|
||||
return mx.reshape(a, (4, 7, 4, 2))
|
||||
|
||||
cfun = mx.compile(fun, shapeless=True)
|
||||
|
||||
a = mx.zeros((4, 7, 8))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
b = cfun(a)
|
||||
|
||||
def fun(a):
|
||||
return mx.dynamic_reshape(a, ("B", "L", 4, 2))
|
||||
|
||||
cfun = mx.compile(fun, shapeless=True)
|
||||
|
||||
b = cfun(a)
|
||||
self.assertEqual(b.shape, (4, 7, 4, 2))
|
||||
|
||||
a = mx.zeros((4, 9, 8))
|
||||
b = cfun(a)
|
||||
self.assertEqual(b.shape, (4, 9, 4, 2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -2718,6 +2718,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a = mx.dynamic_reshape(a, ())
|
||||
self.assertEqual(a.shape, ())
|
||||
|
||||
a = mx.zeros((4, 4, 4))
|
||||
b = mx.dynamic_reshape(a, ("a", "b", "c"))
|
||||
self.assertEqual(b.shape, (4, 4, 4))
|
||||
|
||||
b = mx.dynamic_reshape(a, ("a*b", "c"))
|
||||
self.assertEqual(b.shape, (4 * 4, 4))
|
||||
|
||||
b = mx.dynamic_reshape(a, ("a*b*c", 1, 1))
|
||||
self.assertEqual(b.shape, (4 * 4 * 4, 1, 1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user