Fix vmap constant output size (#1524)

* use inputs to determine output size

* remove noop vmap tests
This commit is contained in:
Alex Barron
2024-10-30 16:16:53 -07:00
committed by GitHub
parent 917252a5a1
commit 048fabdabd
3 changed files with 38 additions and 36 deletions

View File

@@ -462,6 +462,26 @@ class TestVmap(mlx_tests.MLXTestCase):
expected[:, 0] = mx.array([1, 2, 3])[:, None]
self.assertTrue(mx.allclose(out, expected))
def test_vmap_const_func(self):
a = mx.random.uniform(shape=(2, 3, 4))
b = mx.random.uniform(shape=(4, 3))
def const_func(a, b):
return mx.array(2)
out = mx.vmap(const_func, in_axes=(0, None))(a, b)
self.assertTrue(mx.array_equal(mx.full((2,), 2), out))
out = mx.vmap(const_func, in_axes=(None, 0))(a, b)
self.assertTrue(mx.array_equal(mx.full((4,), 2), out))
out = mx.vmap(const_func, in_axes=(1, 1))(a, b)
self.assertTrue(mx.array_equal(mx.full((3,), 2), out))
with self.assertRaises(ValueError):
out = mx.vmap(const_func, in_axes=(None, None))(a, b)
with self.assertRaises(ValueError):
out = mx.vmap(const_func, in_axes=(0, 0))(a, b)
if __name__ == "__main__":
unittest.main()