NumberOfElements for shapeless compile and vmap fixes (#802)

This commit is contained in:
Angelos Katharopoulos
2024-03-13 10:34:14 -07:00
committed by GitHub
parent 29d0c10ee5
commit 76c919b4ec
13 changed files with 289 additions and 72 deletions

View File

@@ -253,6 +253,17 @@ class TestVmap(mlx_tests.MLXTestCase):
expected = mx.array([2, 1])
self.assertTrue(mx.array_equal(out, expected))
def test_vmap_mean(self):
a = mx.arange(8).reshape(2, 4)
out = mx.vmap(mx.mean)(a)
expected = mx.mean(a, axis=1)
self.assertTrue(mx.allclose(out, expected))
a = mx.arange(16).reshape(2, 2, 4)
out = mx.vmap(mx.vmap(mx.mean))(a)
expected = mx.mean(a, axis=2)
self.assertTrue(mx.allclose(out, expected))
def test_mismatch_input_sizes(self):
a = mx.ones((10, 1))
b = mx.ones((1, 1, 1, 5))