mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Fix Split::vmap (#1845)
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							1c0c118f7c
						
					
				
				
					commit
					9eb7d7362f
				
			| @@ -4273,7 +4273,9 @@ std::pair<std::vector<array>, std::vector<int>> Split::vmap( | ||||
|     const std::vector<array>& inputs, | ||||
|     const std::vector<int>& axes) { | ||||
|   int axis_left = axes[0] >= 0 && axes[0] <= axis_; | ||||
|   return {{split(inputs[0], indices_, axis_ + axis_left, stream())}, axes}; | ||||
|   auto output = split(inputs[0], indices_, axis_ + axis_left, stream()); | ||||
|   std::vector<int> output_axes(output.size(), axes[0]); | ||||
|   return {std::move(output), std::move(output_axes)}; | ||||
| } | ||||
|  | ||||
| std::vector<array> Split::vjp( | ||||
|   | ||||
| @@ -596,6 +596,18 @@ class TestVmap(mlx_tests.MLXTestCase): | ||||
|         out = mx.vmap(fun, in_axes=(None, 1, 1))(a, idx, upd) | ||||
|         self.assertEqual(out.shape, (4, 5, 1)) | ||||
|  | ||||
|     def test_vmap_split_vmap(self): | ||||
|         def fun(x): | ||||
|             a, b = mx.split(x, 2, 1) | ||||
|             return mx.concatenate([b, a], 1) | ||||
|  | ||||
|         x = mx.ones((5, 6, 7)) | ||||
|         y = mx.ones((5, 4, 6, 7)) | ||||
|         fx = fun(x) | ||||
|         fy = mx.vmap(fun, in_axes=1)(y) | ||||
|         self.assertEqual(fx.shape, (5, 6, 7)) | ||||
|         self.assertEqual(fy.shape, (4, 5, 6, 7)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user