mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	| @@ -1,5 +1,4 @@ | |||||||
| // Copyright © 2023-2024 Apple Inc. | // Copyright © 2023-2024 Apple Inc. | ||||||
|  |  | ||||||
| #include <algorithm> | #include <algorithm> | ||||||
| #include <cassert> | #include <cassert> | ||||||
| #include <cmath> | #include <cmath> | ||||||
| @@ -1683,48 +1682,58 @@ std::pair<std::vector<array>, std::vector<int>> Gather::vmap( | |||||||
|   auto gather_axes = axes_; |   auto gather_axes = axes_; | ||||||
|   auto slice_sizes = slice_sizes_; |   auto slice_sizes = slice_sizes_; | ||||||
|   auto src_vmapped = axes[0] >= 0; |   auto src_vmapped = axes[0] >= 0; | ||||||
|   auto indices_vmapped = |   auto ind_vmap_ax_ptr = | ||||||
|       std::any_of(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; }); |       std::find_if(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; }); | ||||||
|   auto out_ax = |   int out_ax = -1; | ||||||
|       *std::find_if(axes.begin(), axes.end(), [](int a) { return a >= 0; }); |   bool indices_vmapped = (ind_vmap_ax_ptr != axes.end()); | ||||||
|  |   if (indices_vmapped) { | ||||||
|  |     out_ax = *ind_vmap_ax_ptr; | ||||||
|  |   } else if (src_vmapped) { | ||||||
|  |     out_ax = axes[0]; | ||||||
|  |   } | ||||||
|  |  | ||||||
|   // Reorder all the index arrays so the vmap axis is in the same spot. |   // Reorder all the index arrays so the vmap axis is in the same spot. | ||||||
|   for (int i = 1; i < axes.size(); ++i) { |   if (indices_vmapped) { | ||||||
|     if (out_ax != axes[i] && axes[i] >= 0) { |     for (int i = 1; i < axes.size(); ++i) { | ||||||
|       indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream()); |       if (out_ax != axes[i] && axes[i] >= 0) { | ||||||
|  |         indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream()); | ||||||
|  |       } else if (axes[i] < 0) { | ||||||
|  |         indices[i - 1] = expand_dims(indices[i - 1], out_ax, stream()); | ||||||
|  |       } | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   int idx_dims = indices.empty() ? 0 : indices[0].ndim(); | ||||||
|  |  | ||||||
|   if (src_vmapped) { |   if (src_vmapped) { | ||||||
|     int max_dims = 0; |     for (auto& ax : gather_axes) { | ||||||
|     for (auto& idx : indices) { |       if (ax >= axes[0]) { | ||||||
|       max_dims = std::max(static_cast<int>(idx.ndim()), max_dims); |         ax++; | ||||||
|     } |       } | ||||||
|     auto new_ax_loc = |  | ||||||
|         std::find_if(gather_axes.begin(), gather_axes.end(), [&out_ax](int a) { |  | ||||||
|           return a >= out_ax; |  | ||||||
|         }); |  | ||||||
|     for (; new_ax_loc < gather_axes.end(); new_ax_loc++) { |  | ||||||
|       (*new_ax_loc)++; |  | ||||||
|     } |     } | ||||||
|     if (indices_vmapped) { |     if (indices_vmapped) { | ||||||
|       // Make a new index array for the vmapped dimension |       // Make a new index array for the vmapped dimension | ||||||
|  |       auto vmap_inds = arange(0, src.shape(axes[0]), stream()); | ||||||
|       // Reshape it so it broadcasts with other index arrays |       // Reshape it so it broadcasts with other index arrays | ||||||
|  |       { | ||||||
|  |         auto shape = std::vector<int>(idx_dims, 1); | ||||||
|  |         shape[out_ax] = vmap_inds.size(); | ||||||
|  |         vmap_inds = reshape(vmap_inds, std::move(shape), stream()); | ||||||
|  |       } | ||||||
|       // Update gather axes and slice sizes accordingly |       // Update gather axes and slice sizes accordingly | ||||||
|       auto shape = std::vector<int>(max_dims - out_ax, 1); |       slice_sizes.insert(slice_sizes.begin() + axes[0], 1); | ||||||
|       auto vmap_inds = arange(0, src.shape(out_ax), stream()); |       gather_axes.push_back(axes[0]); | ||||||
|       shape[0] = vmap_inds.shape(0); |       indices.push_back(vmap_inds); | ||||||
|       vmap_inds = reshape(vmap_inds, shape, stream()); |  | ||||||
|       slice_sizes.insert(slice_sizes.begin() + out_ax, 1); |  | ||||||
|       auto new_ax_idx = new_ax_loc - gather_axes.begin(); |  | ||||||
|       gather_axes.insert(new_ax_loc, out_ax); |  | ||||||
|       indices.insert(indices.begin() + new_ax_idx, vmap_inds); |  | ||||||
|     } else { |     } else { | ||||||
|       slice_sizes.insert(slice_sizes.begin() + axes[0], src.shape(axes[0])); |       slice_sizes.insert(slice_sizes.begin() + out_ax, src.shape(out_ax)); | ||||||
|       out_ax = max_dims + axes[0]; |       out_ax += idx_dims; | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|   return {{gather(src, indices, gather_axes, slice_sizes, stream())}, {out_ax}}; |   auto out = gather(src, indices, gather_axes, slice_sizes, stream()); | ||||||
|  |   if (src_vmapped && indices_vmapped) { | ||||||
|  |     out = squeeze(out, idx_dims + axes[0], stream()); | ||||||
|  |   } | ||||||
|  |   return {{out}, {out_ax}}; | ||||||
| } | } | ||||||
|  |  | ||||||
| std::vector<array> Gather::vjp( | std::vector<array> Gather::vjp( | ||||||
|   | |||||||
| @@ -370,6 +370,51 @@ class TestVmap(mlx_tests.MLXTestCase): | |||||||
|                 mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5) |                 mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5) | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |     def test_vmap_gather(self): | ||||||
|  |         def gather(a, idx): | ||||||
|  |             return a[idx] | ||||||
|  |  | ||||||
|  |         a = mx.array([[1, 2], [3, 4]]) | ||||||
|  |         idx = mx.array(0) | ||||||
|  |         out = mx.vmap(gather, (0, None))(a, idx) | ||||||
|  |         self.assertTrue(mx.array_equal(out, mx.array([1, 3]))) | ||||||
|  |  | ||||||
|  |         out = mx.vmap(gather, (1, None))(a, idx) | ||||||
|  |         self.assertTrue(mx.array_equal(out, mx.array([1, 2]))) | ||||||
|  |  | ||||||
|  |         idx = mx.array([0, 1]) | ||||||
|  |         out = mx.vmap(gather, (0, 0))(a, idx) | ||||||
|  |         self.assertTrue(mx.array_equal(out, mx.array([1, 4]))) | ||||||
|  |  | ||||||
|  |         a = mx.ones((2, 3, 4)) | ||||||
|  |         idx = mx.zeros(4, mx.int32) | ||||||
|  |         out = mx.vmap(gather, (2, 0))(a, idx) | ||||||
|  |         self.assertEqual(out.shape, (4, 3)) | ||||||
|  |  | ||||||
|  |         f = mx.vmap(gather, (0, None)) | ||||||
|  |         f = mx.vmap(gather, (0, 0)) | ||||||
|  |         out = f(mx.ones((2, 3, 4)), mx.zeros(2, dtype=mx.int32)) | ||||||
|  |         self.assertEqual(out.shape, (2, 4)) | ||||||
|  |  | ||||||
|  |         def gather(a, idxa, idxb): | ||||||
|  |             return a[idxa, idxb] | ||||||
|  |  | ||||||
|  |         a = mx.ones((2, 3, 4)) | ||||||
|  |         idxa = mx.zeros((2, 3), mx.int32) | ||||||
|  |         idxb = mx.zeros(3, mx.int32) | ||||||
|  |         out = mx.vmap(gather, (0, 0, None))(a, idxa, idxb) | ||||||
|  |         self.assertEqual(out.shape, (2, 3)) | ||||||
|  |  | ||||||
|  |         idxa = mx.zeros((3, 1, 2), mx.int32) | ||||||
|  |         idxb = mx.zeros((2, 3, 1, 2), mx.int32) | ||||||
|  |         out = mx.vmap(gather, (0, None, 0))(a, idxa, idxb) | ||||||
|  |         self.assertEqual(out.shape, (2, 3, 1, 2)) | ||||||
|  |  | ||||||
|  |         idxa = mx.zeros((3, 1, 2), mx.int32) | ||||||
|  |         idxb = mx.zeros((3, 1, 2, 2), mx.int32) | ||||||
|  |         out = mx.vmap(gather, (0, None, 3))(a, idxa, idxb) | ||||||
|  |         self.assertEqual(out.shape, (2, 3, 1, 2)) | ||||||
|  |  | ||||||
|     def test_vmap_scatter(self): |     def test_vmap_scatter(self): | ||||||
|         def scatter(a): |         def scatter(a): | ||||||
|             a[mx.array(0)] = mx.array(0.0) |             a[mx.array(0)] = mx.array(0.0) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun