Flatten and unflatten (#1692)

* flatten and unflatten

* fix grad

* fix shape infer

* use squeeze + unsqueeze in get_item
This commit is contained in:
Awni Hannun
2024-12-11 21:51:37 -08:00
committed by GitHub
parent 0bf19037ca
commit 4e1e9520e1
19 changed files with 363 additions and 93 deletions

View File

@@ -462,6 +462,22 @@ class TestCompile(mlx_tests.MLXTestCase):
cfun = mx.compile(fun, shapeless=True)
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
def test_shapeless_compile_unflatten(self):
x = mx.zeros((1, 1, 4 * 32))
def fun(x):
return mx.unflatten(x, -1, (4, -1))
self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 1, 4, 32))
def test_shapeless_compile_gather(self):
x = mx.zeros((1, 1, 32))
def fun(x):
return x[:, -1, :]
self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 32))
def test_compile_with_constant(self):
# Test float
@partial(mx.compile)