mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-03 18:18:15 +08:00
Flatten and unflatten (#1692)
* flatten and unflatten * fix grad * fix shape infer * use squeeze + unsqueeze in get_item
This commit is contained in:
@@ -462,6 +462,22 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
cfun = mx.compile(fun, shapeless=True)
|
||||
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
|
||||
|
||||
def test_shapeless_compile_unflatten(self):
|
||||
x = mx.zeros((1, 1, 4 * 32))
|
||||
|
||||
def fun(x):
|
||||
return mx.unflatten(x, -1, (4, -1))
|
||||
|
||||
self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 1, 4, 32))
|
||||
|
||||
def test_shapeless_compile_gather(self):
|
||||
x = mx.zeros((1, 1, 32))
|
||||
|
||||
def fun(x):
|
||||
return x[:, -1, :]
|
||||
|
||||
self.assertEqual(mx.compile(fun, shapeless=True)(x).shape, (1, 32))
|
||||
|
||||
def test_compile_with_constant(self):
|
||||
# Test float
|
||||
@partial(mx.compile)
|
||||
|
||||
Reference in New Issue
Block a user