mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Fix vmap constant output size (#1524)
* use inputs to determine output size * remove noop vmap tests
This commit is contained in:
@@ -462,6 +462,26 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
expected[:, 0] = mx.array([1, 2, 3])[:, None]
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def test_vmap_const_func(self):
|
||||
a = mx.random.uniform(shape=(2, 3, 4))
|
||||
b = mx.random.uniform(shape=(4, 3))
|
||||
|
||||
def const_func(a, b):
|
||||
return mx.array(2)
|
||||
|
||||
out = mx.vmap(const_func, in_axes=(0, None))(a, b)
|
||||
self.assertTrue(mx.array_equal(mx.full((2,), 2), out))
|
||||
out = mx.vmap(const_func, in_axes=(None, 0))(a, b)
|
||||
self.assertTrue(mx.array_equal(mx.full((4,), 2), out))
|
||||
out = mx.vmap(const_func, in_axes=(1, 1))(a, b)
|
||||
self.assertTrue(mx.array_equal(mx.full((3,), 2), out))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
out = mx.vmap(const_func, in_axes=(None, None))(a, b)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
out = mx.vmap(const_func, in_axes=(0, 0))(a, b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user