NumberOfElements for shapeless compile and vmap fixes (#802)

This commit is contained in:
Angelos Katharopoulos
2024-03-13 10:34:14 -07:00
committed by GitHub
parent 29d0c10ee5
commit 76c919b4ec
13 changed files with 289 additions and 72 deletions

View File

@@ -653,6 +653,24 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(expected[0], out[0]))
self.assertTrue(mx.allclose(expected[1], out[1]))
def test_shapeless_mean(self):
def mean(x):
return mx.mean(x, keepdims=True)
cmean = mx.compile(mean, shapeless=True)
x = mx.ones(2)
out = cmean(x)
self.assertTrue(mx.allclose(out, mean(x)))
x = mx.ones(4)
out = cmean(x)
self.assertTrue(mx.allclose(out, mean(x)))
x = mx.ones(7)
out = cmean(x)
self.assertTrue(mx.allclose(out, mean(x)))
if __name__ == "__main__":
unittest.main()

View File

@@ -253,6 +253,17 @@ class TestVmap(mlx_tests.MLXTestCase):
expected = mx.array([2, 1])
self.assertTrue(mx.array_equal(out, expected))
def test_vmap_mean(self):
a = mx.arange(8).reshape(2, 4)
out = mx.vmap(mx.mean)(a)
expected = mx.mean(a, axis=1)
self.assertTrue(mx.allclose(out, expected))
a = mx.arange(16).reshape(2, 2, 4)
out = mx.vmap(mx.vmap(mx.mean))(a)
expected = mx.mean(a, axis=2)
self.assertTrue(mx.allclose(out, expected))
def test_mismatch_input_sizes(self):
a = mx.ones((10, 1))
b = mx.ones((1, 1, 1, 5))