mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	fix vmap for flatten (#1955)
This commit is contained in:
		@@ -1766,13 +1766,19 @@ std::pair<std::vector<array>, std::vector<int>> Flatten::vmap(
 | 
			
		||||
  auto ax = axes[0];
 | 
			
		||||
  auto start_axis = start_axis_;
 | 
			
		||||
  auto end_axis = end_axis_;
 | 
			
		||||
  auto in = inputs[0];
 | 
			
		||||
  if (ax < start_axis) {
 | 
			
		||||
    start_axis++;
 | 
			
		||||
    end_axis++;
 | 
			
		||||
  } else if (ax <= end_axis_) {
 | 
			
		||||
    start_axis++;
 | 
			
		||||
    end_axis++;
 | 
			
		||||
    in = moveaxis(in, ax, 0, stream());
 | 
			
		||||
    ax = 0;
 | 
			
		||||
  } else {
 | 
			
		||||
    ax -= (end_axis - start_axis);
 | 
			
		||||
  }
 | 
			
		||||
  return {{flatten(inputs[0], start_axis, end_axis, stream())}, {ax}};
 | 
			
		||||
  return {{flatten(in, start_axis, end_axis, stream())}, {ax}};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool Flatten::is_equivalent(const Primitive& other) const {
 | 
			
		||||
 
 | 
			
		||||
@@ -659,6 +659,16 @@ class TestVmap(mlx_tests.MLXTestCase):
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(mem_pre, mem_post)
 | 
			
		||||
 | 
			
		||||
    def test_vmap_flatten(self):
 | 
			
		||||
        def fun(x):
 | 
			
		||||
            return mx.flatten(x, 0, 1)
 | 
			
		||||
 | 
			
		||||
        x = mx.zeros((2, 3, 4))
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(mx.vmap(fun)(x).shape, (2, 12))
 | 
			
		||||
        self.assertEqual(mx.vmap(fun, in_axes=(1,))(x).shape, (3, 8))
 | 
			
		||||
        self.assertEqual(mx.vmap(fun, in_axes=(2,))(x).shape, (4, 6))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user