mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Fix Split::vmap (#1845)
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							1c0c118f7c
						
					
				
				
					commit
					9eb7d7362f
				
			@@ -596,6 +596,18 @@ class TestVmap(mlx_tests.MLXTestCase):
 | 
			
		||||
        out = mx.vmap(fun, in_axes=(None, 1, 1))(a, idx, upd)
 | 
			
		||||
        self.assertEqual(out.shape, (4, 5, 1))
 | 
			
		||||
 | 
			
		||||
    def test_vmap_split_vmap(self):
 | 
			
		||||
        def fun(x):
 | 
			
		||||
            a, b = mx.split(x, 2, 1)
 | 
			
		||||
            return mx.concatenate([b, a], 1)
 | 
			
		||||
 | 
			
		||||
        x = mx.ones((5, 6, 7))
 | 
			
		||||
        y = mx.ones((5, 4, 6, 7))
 | 
			
		||||
        fx = fun(x)
 | 
			
		||||
        fy = mx.vmap(fun, in_axes=1)(y)
 | 
			
		||||
        self.assertEqual(fx.shape, (5, 6, 7))
 | 
			
		||||
        self.assertEqual(fy.shape, (4, 5, 6, 7))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user