mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix concatenate vmap (#1600)
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							2af7e8a9a6
						
					
				
				
					commit
					5e89aace9b
				
			| @@ -527,6 +527,26 @@ class TestVmap(mlx_tests.MLXTestCase): | ||||
|         with self.assertRaises(ValueError): | ||||
|             out = mx.vmap(const_func, in_axes=(0, 0))(a, b) | ||||
|  | ||||
|     def test_vmap_concatenate(self): | ||||
|         x = mx.random.uniform(shape=(2, 2, 2)) | ||||
|  | ||||
|         def cat_fun(x, y): | ||||
|             return mx.concatenate([x, y], axis=1) | ||||
|  | ||||
|         def cat_constant(x): | ||||
|             y = mx.ones((2, 1)) | ||||
|             return mx.concatenate([x, y], 1) | ||||
|  | ||||
|         out = mx.vmap(cat_fun, in_axes=(0, 2))(x, x) | ||||
|         target = mx.stack( | ||||
|             [mx.concatenate([x[i], x[:, :, i]], axis=1) for i in range(2)] | ||||
|         ) | ||||
|         self.assertTrue(mx.array_equal(out, target)) | ||||
|  | ||||
|         out = mx.vmap(cat_constant)(x) | ||||
|         target = mx.concatenate([x, mx.ones((2, 2, 1))], axis=2) | ||||
|         self.assertTrue(mx.array_equal(out, target)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user