mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix vmap constant output size (#1524)
* use inputs to determine output size * remove noop vmap tests
This commit is contained in:
		| @@ -462,6 +462,26 @@ class TestVmap(mlx_tests.MLXTestCase): | ||||
|         expected[:, 0] = mx.array([1, 2, 3])[:, None] | ||||
|         self.assertTrue(mx.allclose(out, expected)) | ||||
|  | ||||
|     def test_vmap_const_func(self): | ||||
|         a = mx.random.uniform(shape=(2, 3, 4)) | ||||
|         b = mx.random.uniform(shape=(4, 3)) | ||||
|  | ||||
|         def const_func(a, b): | ||||
|             return mx.array(2) | ||||
|  | ||||
|         out = mx.vmap(const_func, in_axes=(0, None))(a, b) | ||||
|         self.assertTrue(mx.array_equal(mx.full((2,), 2), out)) | ||||
|         out = mx.vmap(const_func, in_axes=(None, 0))(a, b) | ||||
|         self.assertTrue(mx.array_equal(mx.full((4,), 2), out)) | ||||
|         out = mx.vmap(const_func, in_axes=(1, 1))(a, b) | ||||
|         self.assertTrue(mx.array_equal(mx.full((3,), 2), out)) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             out = mx.vmap(const_func, in_axes=(None, None))(a, b) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             out = mx.vmap(const_func, in_axes=(0, 0))(a, b) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Alex Barron
					Alex Barron