mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
ExpandDims
primitive (#1687)
* add squeeze primitive * simplify squeeze, use in gather * fix * fix * fix * fix * fix no cpu * use squeeze in matmul and friends * expand dims primitive * comment
This commit is contained in:
@@ -392,27 +392,6 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
out = fun(x, y=y, z=z)
|
||||
self.assertEqual(out.item(), 6)
|
||||
|
||||
def test_shapeless_compile(self):
|
||||
y = 1
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def fun(x):
|
||||
return x + y
|
||||
|
||||
x = mx.array([1, 2])
|
||||
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3])))
|
||||
|
||||
# The function is not recompiled, so the change
|
||||
# to y should not be reflected in the output
|
||||
y = 2
|
||||
x = mx.array([1, 2, 3])
|
||||
self.assertTrue(mx.array_equal(fun(x), mx.array([2, 3, 4])))
|
||||
|
||||
# Type change recompiles
|
||||
x = mx.array([1.0, 2.0, 3.0])
|
||||
self.assertTrue(mx.array_equal(fun(x), mx.array([3.0, 4.0, 5.0])))
|
||||
fun(x, y=y, z=z)
|
||||
|
||||
def test_shapeless_compile(self):
|
||||
y = 1
|
||||
|
||||
@@ -477,6 +456,12 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
mx.eval(cfun(x1))
|
||||
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
|
||||
|
||||
def fun(x):
|
||||
return x * x.sum(-1, keepdims=False)
|
||||
|
||||
cfun = mx.compile(fun, shapeless=True)
|
||||
self.assertTrue(mx.array_equal(fun(x2), cfun(x2)))
|
||||
|
||||
def test_compile_with_constant(self):
|
||||
# Test float
|
||||
@partial(mx.compile)
|
||||
@@ -809,6 +794,13 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
out = fun(*inputs)
|
||||
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
|
||||
|
||||
def test_shapeless_compile_matmul(self):
|
||||
a = mx.array([0.0, 1.0, 2.0])
|
||||
b = mx.array([0.0, 1.0, 2.0])
|
||||
|
||||
fun = mx.compile(lambda a, b: a @ b, shapeless=True)
|
||||
self.assertTrue(mx.allclose(fun(a, b), a @ b))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user