mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Scatter vjp (#394)
* Add a first scatter vjp * Implement the scatter_add vjp * Add array.at to implement user friendly scatters
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							e9ca65c939
						
					
				
				
					commit
					961435a243
				
			| @@ -462,6 +462,37 @@ array create_array(array_init_type v, std::optional<Dtype> t) { | ||||
|   } | ||||
| } | ||||
|  | ||||
| // Helper object returned by the `array.at` property. It holds the array and, | ||||
| // once `__getitem__` has been called, the index expression, and forwards each | ||||
| // update (`add`, `subtract`, ...) to the matching `mlx_*_item` function, which | ||||
| // returns a new array rather than assigning in place. | ||||
| class ArrayAt { | ||||
|  public: | ||||
|   explicit ArrayAt(array x) : x_(std::move(x)) {} | ||||
|  | ||||
|   // Remember the index expression; returns *this so `x.at[idx].op(v)` chains. | ||||
|   ArrayAt& set_indices(py::object indices) { | ||||
|     indices_ = std::move(indices); | ||||
|     return *this; | ||||
|   } | ||||
|  | ||||
|   array add(const ScalarOrArray& v) const { | ||||
|     return mlx_add_item(x_, checked_indices(), v); | ||||
|   } | ||||
|   array subtract(const ScalarOrArray& v) const { | ||||
|     return mlx_subtract_item(x_, checked_indices(), v); | ||||
|   } | ||||
|   array multiply(const ScalarOrArray& v) const { | ||||
|     return mlx_multiply_item(x_, checked_indices(), v); | ||||
|   } | ||||
|   array divide(const ScalarOrArray& v) const { | ||||
|     return mlx_divide_item(x_, checked_indices(), v); | ||||
|   } | ||||
|   array maximum(const ScalarOrArray& v) const { | ||||
|     return mlx_maximum_item(x_, checked_indices(), v); | ||||
|   } | ||||
|   array minimum(const ScalarOrArray& v) const { | ||||
|     return mlx_minimum_item(x_, checked_indices(), v); | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   // Returns the stored index expression. | ||||
|   // Throws std::invalid_argument when `__getitem__` was never called (e.g. | ||||
|   // `x.at.add(1)`): `indices_` is then a default-constructed, null handle | ||||
|   // and must not reach the `py::isinstance` checks used for indexing. | ||||
|   // An explicit `None` index is a non-null object and is still accepted. | ||||
|   const py::object& checked_indices() const { | ||||
|     if (!indices_) { | ||||
|       throw std::invalid_argument( | ||||
|           "Must give indices to array.at (e.g. `x.at[0].add(4)`)."); | ||||
|     } | ||||
|     return indices_; | ||||
|   } | ||||
|  | ||||
|   array x_; | ||||
|   py::object indices_; | ||||
| }; | ||||
|  | ||||
| void init_array(py::module_& m) { | ||||
|   // Types | ||||
|   py::class_<Dtype>( | ||||
| @@ -501,6 +532,26 @@ void init_array(py::module_& m) { | ||||
|   m.attr("bfloat16") = py::cast(bfloat16); | ||||
|   m.attr("complex64") = py::cast(complex64); | ||||
|  | ||||
|   py::class_<ArrayAt>( | ||||
|       m, | ||||
|       "_ArrayAt", | ||||
|       R"pbdoc( | ||||
|       A helper object to apply updates at specific indices. | ||||
|       )pbdoc") | ||||
|       .def( | ||||
|           py::init([](const array& x) { return ArrayAt(x); }), | ||||
|           "x"_a, | ||||
|           R"pbdoc( | ||||
|           __init__(self, x: array) | ||||
|         )pbdoc") | ||||
|       .def("__getitem__", &ArrayAt::set_indices, "indices"_a) | ||||
|       .def("add", &ArrayAt::add, "value"_a) | ||||
|       .def("subtract", &ArrayAt::subtract, "value"_a) | ||||
|       .def("multiply", &ArrayAt::multiply, "value"_a) | ||||
|       .def("divide", &ArrayAt::divide, "value"_a) | ||||
|       .def("maximum", &ArrayAt::maximum, "value"_a) | ||||
|       .def("minimum", &ArrayAt::minimum, "value"_a); | ||||
|  | ||||
|   auto array_class = py::class_<array>( | ||||
|       m, | ||||
|       "array", | ||||
| @@ -610,6 +661,38 @@ void init_array(py::module_& m) { | ||||
|           )pbdoc") | ||||
|       .def("__getitem__", mlx_get_item) | ||||
|       .def("__setitem__", mlx_set_item) | ||||
|       .def_property_readonly( | ||||
|           "at", | ||||
|           [](const array& a) { return ArrayAt(a); }, | ||||
|           R"pbdoc( | ||||
|             Used to apply updates at the given indices. | ||||
|  | ||||
|             .. note:: | ||||
|  | ||||
|                Python in-place updates for all array frameworks map to | ||||
|                assignment. For instance ``x[idx] += y`` maps to | ||||
|                ``x[idx] = x[idx] + y``. As a result, when the same index | ||||
|                appears more than once, all but one of the updates are ignored. | ||||
|                Using ``x.at[idx].add(y)`` will correctly apply all the updates | ||||
|                to all indices. | ||||
|  | ||||
|             .. list-table:: | ||||
|                :header-rows: 1 | ||||
|  | ||||
|                * - array.at syntax | ||||
|                  - In-place syntax | ||||
|                * - ``x = x.at[idx].add(y)``  | ||||
|                  - ``x[idx] += y`` | ||||
|                * - ``x = x.at[idx].subtract(y)``  | ||||
|                  - ``x[idx] -= y`` | ||||
|                * - ``x = x.at[idx].multiply(y)``  | ||||
|                  - ``x[idx] *= y`` | ||||
|                * - ``x = x.at[idx].divide(y)``  | ||||
|                  - ``x[idx] /= y`` | ||||
|                * - ``x = x.at[idx].maximum(y)``  | ||||
|                  - ``x[idx] = mx.maximum(x[idx], y)`` | ||||
|                * - ``x = x.at[idx].minimum(y)``  | ||||
|                  - ``x[idx] = mx.minimum(x[idx], y)`` | ||||
|           )pbdoc") | ||||
|       .def( | ||||
|           "__len__", | ||||
|           [](const array& a) { | ||||
|   | ||||
| @@ -392,7 +392,7 @@ array mlx_get_item(const array& src, const py::object& obj) { | ||||
|   throw std::invalid_argument("Cannot index mlx array using the given type."); | ||||
| } | ||||
|  | ||||
| array mlx_set_item_int( | ||||
| std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int( | ||||
|     const array& src, | ||||
|     const py::int_& idx, | ||||
|     const array& update) { | ||||
| @@ -410,14 +410,14 @@ array mlx_set_item_int( | ||||
|       std::vector<int>(update.shape().begin() + s, update.shape().end()); | ||||
|   auto shape = src.shape(); | ||||
|   shape[0] = 1; | ||||
|   return scatter( | ||||
|       src, | ||||
|       get_int_index(idx, src.shape(0)), | ||||
|  | ||||
|   return { | ||||
|       {get_int_index(idx, src.shape(0))}, | ||||
|       broadcast_to(reshape(update, up_shape), shape), | ||||
|       0); | ||||
|       {0}}; | ||||
| } | ||||
|  | ||||
| array mlx_set_item_array( | ||||
| std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array( | ||||
|     const array& src, | ||||
|     const array& indices, | ||||
|     const array& update) { | ||||
| @@ -441,10 +441,10 @@ array mlx_set_item_array( | ||||
|   up_shape.insert(up_shape.begin() + indices.ndim(), 1); | ||||
|   up = reshape(up, up_shape); | ||||
|  | ||||
|   return scatter(src, indices, up, 0); | ||||
|   return {{indices}, up, {0}}; | ||||
| } | ||||
|  | ||||
| array mlx_set_item_slice( | ||||
| std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice( | ||||
|     const array& src, | ||||
|     const py::slice& in_slice, | ||||
|     const array& update) { | ||||
| @@ -462,7 +462,7 @@ array mlx_set_item_slice( | ||||
|       ; | ||||
|     auto up_shape = | ||||
|         std::vector<int>(update.shape().begin() + s, update.shape().end()); | ||||
|     return broadcast_to(reshape(update, up_shape), src.shape()); | ||||
|     return {{}, broadcast_to(reshape(update, up_shape), src.shape()), {}}; | ||||
|   } | ||||
|  | ||||
|   int start = 0; | ||||
| @@ -472,10 +472,11 @@ array mlx_set_item_slice( | ||||
|   // Check and update slice params | ||||
|   get_slice_params(start, end, stride, in_slice, end); | ||||
|  | ||||
|   return mlx_set_item_array(src, arange(start, end, stride, uint32), update); | ||||
|   return mlx_scatter_args_array( | ||||
|       src, arange(start, end, stride, uint32), update); | ||||
| } | ||||
|  | ||||
| array mlx_set_item_nd( | ||||
| std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd( | ||||
|     const array& src, | ||||
|     const py::tuple& entries, | ||||
|     const array& update) { | ||||
| @@ -537,7 +538,7 @@ array mlx_set_item_nd( | ||||
|  | ||||
|   // If no non-None indices return the broadcasted update | ||||
|   if (non_none_indices == 0) { | ||||
|     return broadcast_to(up, src.shape()); | ||||
|     return {{}, broadcast_to(up, src.shape()), {}}; | ||||
|   } | ||||
|  | ||||
|   unsigned long max_dim = 0; | ||||
| @@ -621,25 +622,108 @@ array mlx_set_item_nd( | ||||
|  | ||||
|   std::vector<int> axes(arr_indices.size(), 0); | ||||
|   std::iota(axes.begin(), axes.end(), 0); | ||||
|   return scatter(src, arr_indices, up, axes); | ||||
|  | ||||
|   return {arr_indices, up, axes}; | ||||
| } | ||||
|  | ||||
| // Translate a Python index expression into the (index arrays, updates, axes) | ||||
| // triple consumed by the scatter ops. `v` is first converted to an array of | ||||
| // `src`'s dtype. Callers treat an empty index list as "no scatter needed". | ||||
| // Throws std::invalid_argument when `obj` is not a slice, array, int, tuple | ||||
| // or None. | ||||
| std::tuple<std::vector<array>, array, std::vector<int>> | ||||
| mlx_compute_scatter_args( | ||||
|     const array& src, | ||||
|     const py::object& obj, | ||||
|     const ScalarOrArray& v) { | ||||
|   const auto update = to_array(v, src.dtype()); | ||||
|  | ||||
|   // Dispatch on the Python type of the index, one guard per supported kind. | ||||
|   if (py::isinstance<py::slice>(obj)) { | ||||
|     return mlx_scatter_args_slice(src, obj, update); | ||||
|   } | ||||
|   if (py::isinstance<array>(obj)) { | ||||
|     return mlx_scatter_args_array(src, py::cast<array>(obj), update); | ||||
|   } | ||||
|   if (py::isinstance<py::int_>(obj)) { | ||||
|     return mlx_scatter_args_int(src, obj, update); | ||||
|   } | ||||
|   if (py::isinstance<py::tuple>(obj)) { | ||||
|     return mlx_scatter_args_nd(src, obj, update); | ||||
|   } | ||||
|   if (obj.is_none()) { | ||||
|     // `None` selects everything: no indices, update broadcast to `src`. | ||||
|     return {{}, broadcast_to(update, src.shape()), {}}; | ||||
|   } | ||||
|  | ||||
|   throw std::invalid_argument("Cannot index mlx array using the given type."); | ||||
| } | ||||
|  | ||||
| void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) { | ||||
|   auto vals = to_array(v, src.dtype()); | ||||
|   auto impl = [&src, &obj, &vals]() { | ||||
|     if (py::isinstance<py::slice>(obj)) { | ||||
|       return mlx_set_item_slice(src, obj, vals); | ||||
|     } else if (py::isinstance<array>(obj)) { | ||||
|       return mlx_set_item_array(src, py::cast<array>(obj), vals); | ||||
|     } else if (py::isinstance<py::int_>(obj)) { | ||||
|       return mlx_set_item_int(src, obj, vals); | ||||
|     } else if (py::isinstance<py::tuple>(obj)) { | ||||
|       return mlx_set_item_nd(src, obj, vals); | ||||
|     } else if (obj.is_none()) { | ||||
|       return broadcast_to(vals, src.shape()); | ||||
|     } | ||||
|     throw std::invalid_argument("Cannot index mlx array using the given type."); | ||||
|   }; | ||||
|   auto out = impl(); | ||||
|   src.overwrite_descriptor(out); | ||||
|   auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v); | ||||
|   if (indices.size() > 0) { | ||||
|     auto out = scatter(src, indices, updates, axes); | ||||
|     src.overwrite_descriptor(out); | ||||
|   } else { | ||||
|     src.overwrite_descriptor(updates); | ||||
|   } | ||||
| } | ||||
|  | ||||
| // Out-of-place counterpart of `src[obj] += v`: returns a new array and leaves | ||||
| // `src` untouched. | ||||
| array mlx_add_item( | ||||
|     const array& src, | ||||
|     const py::object& obj, | ||||
|     const ScalarOrArray& v) { | ||||
|   auto [idx, upd, scatter_axes] = mlx_compute_scatter_args(src, obj, v); | ||||
|   // An empty index list means there is nothing to scatter: add element-wise. | ||||
|   if (idx.empty()) { | ||||
|     return src + upd; | ||||
|   } | ||||
|   return scatter_add(src, idx, upd, scatter_axes); | ||||
| } | ||||
|  | ||||
| // Out-of-place counterpart of `src[obj] -= v`: returns a new array and leaves | ||||
| // `src` untouched. | ||||
| array mlx_subtract_item( | ||||
|     const array& src, | ||||
|     const py::object& obj, | ||||
|     const ScalarOrArray& v) { | ||||
|   auto [idx, upd, scatter_axes] = mlx_compute_scatter_args(src, obj, v); | ||||
|   // An empty index list means there is nothing to scatter: subtract | ||||
|   // element-wise. | ||||
|   if (idx.empty()) { | ||||
|     return src - upd; | ||||
|   } | ||||
|   // Subtraction is expressed as a scatter-add of the negated updates. | ||||
|   return scatter_add(src, idx, -upd, scatter_axes); | ||||
| } | ||||
|  | ||||
| // Out-of-place counterpart of `src[obj] *= v`: returns a new array and leaves | ||||
| // `src` untouched. | ||||
| array mlx_multiply_item( | ||||
|     const array& src, | ||||
|     const py::object& obj, | ||||
|     const ScalarOrArray& v) { | ||||
|   auto [idx, upd, scatter_axes] = mlx_compute_scatter_args(src, obj, v); | ||||
|   // An empty index list means there is nothing to scatter: multiply | ||||
|   // element-wise. | ||||
|   if (idx.empty()) { | ||||
|     return src * upd; | ||||
|   } | ||||
|   return scatter_prod(src, idx, upd, scatter_axes); | ||||
| } | ||||
|  | ||||
| // Out-of-place counterpart of `src[obj] /= v`: returns a new array and leaves | ||||
| // `src` untouched. | ||||
| array mlx_divide_item( | ||||
|     const array& src, | ||||
|     const py::object& obj, | ||||
|     const ScalarOrArray& v) { | ||||
|   auto [idx, upd, scatter_axes] = mlx_compute_scatter_args(src, obj, v); | ||||
|   // An empty index list means there is nothing to scatter: divide | ||||
|   // element-wise. | ||||
|   if (idx.empty()) { | ||||
|     return src / upd; | ||||
|   } | ||||
|   // Division is expressed as a scatter-product with the reciprocal updates. | ||||
|   return scatter_prod(src, idx, reciprocal(upd), scatter_axes); | ||||
| } | ||||
|  | ||||
| // Out-of-place counterpart of `src[obj] = maximum(src[obj], v)`: returns a new | ||||
| // array and leaves `src` untouched. | ||||
| array mlx_maximum_item( | ||||
|     const array& src, | ||||
|     const py::object& obj, | ||||
|     const ScalarOrArray& v) { | ||||
|   auto [idx, upd, scatter_axes] = mlx_compute_scatter_args(src, obj, v); | ||||
|   // An empty index list means there is nothing to scatter: take the | ||||
|   // element-wise maximum. | ||||
|   if (idx.empty()) { | ||||
|     return maximum(src, upd); | ||||
|   } | ||||
|   return scatter_max(src, idx, upd, scatter_axes); | ||||
| } | ||||
|  | ||||
| // Out-of-place counterpart of `src[obj] = minimum(src[obj], v)`: returns a new | ||||
| // array and leaves `src` untouched. | ||||
| array mlx_minimum_item( | ||||
|     const array& src, | ||||
|     const py::object& obj, | ||||
|     const ScalarOrArray& v) { | ||||
|   auto [idx, upd, scatter_axes] = mlx_compute_scatter_args(src, obj, v); | ||||
|   // An empty index list means there is nothing to scatter: take the | ||||
|   // element-wise minimum. | ||||
|   if (idx.empty()) { | ||||
|     return minimum(src, upd); | ||||
|   } | ||||
|   return scatter_min(src, idx, upd, scatter_axes); | ||||
| } | ||||
|   | ||||
| @@ -12,3 +12,27 @@ using namespace mlx::core; | ||||
|  | ||||
| array mlx_get_item(const array& src, const py::object& obj); | ||||
| void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v); | ||||
| array mlx_add_item( | ||||
|     const array& src, | ||||
|     const py::object& obj, | ||||
|     const ScalarOrArray& v); | ||||
| array mlx_subtract_item( | ||||
|     const array& src, | ||||
|     const py::object& obj, | ||||
|     const ScalarOrArray& v); | ||||
| array mlx_multiply_item( | ||||
|     const array& src, | ||||
|     const py::object& obj, | ||||
|     const ScalarOrArray& v); | ||||
| array mlx_divide_item( | ||||
|     const array& src, | ||||
|     const py::object& obj, | ||||
|     const ScalarOrArray& v); | ||||
| array mlx_maximum_item( | ||||
|     const array& src, | ||||
|     const py::object& obj, | ||||
|     const ScalarOrArray& v); | ||||
| array mlx_minimum_item( | ||||
|     const array& src, | ||||
|     const py::object& obj, | ||||
|     const ScalarOrArray& v); | ||||
|   | ||||
| @@ -3310,7 +3310,6 @@ void init_ops(py::module_& m) { | ||||
|         Returns: | ||||
|           result (array): The tensor dot product. | ||||
|       )pbdoc"); | ||||
|  | ||||
|   m.def( | ||||
|       "inner", | ||||
|       &inner, | ||||
| @@ -3331,7 +3330,6 @@ void init_ops(py::module_& m) { | ||||
|       Returns: | ||||
|         result (array): The inner product. | ||||
|     )pbdoc"); | ||||
|  | ||||
|   m.def( | ||||
|       "outer", | ||||
|       &outer, | ||||
|   | ||||
| @@ -984,6 +984,53 @@ class TestArray(mlx_tests.MLXTestCase): | ||||
|         a[2:-2, 2:-2] = 4 | ||||
|         self.assertEqual(a[2, 2].item(), 4) | ||||
|  | ||||
|     def test_array_at(self): | ||||
|         a = mx.array(1) | ||||
|         a = a.at[None].add(1) | ||||
|         self.assertEqual(a.item(), 2) | ||||
|  | ||||
|         a = mx.array([0, 1, 2]) | ||||
|         a = a.at[1].add(2) | ||||
|         self.assertEqual(a.tolist(), [0, 3, 2]) | ||||
|  | ||||
|         a = a.at[mx.array([0, 0, 0, 0])].add(1) | ||||
|         self.assertEqual(a.tolist(), [4, 3, 2]) | ||||
|  | ||||
|         a = mx.zeros((10, 10)) | ||||
|         a = a.at[0].add(mx.arange(10)) | ||||
|         self.assertEqual(a[0].tolist(), list(range(10))) | ||||
|  | ||||
|         a = mx.zeros((10, 10)) | ||||
|         index_x = mx.array([0, 2, 3, 7]) | ||||
|         index_y = mx.array([3, 3, 1, 2]) | ||||
|         u = mx.random.uniform(shape=(4,)) | ||||
|         a = a.at[index_x, index_y].add(u) | ||||
|         self.assertEqual(a.sum().item(), u.sum().item()) | ||||
|         self.assertEqual(a[index_x, index_y].tolist(), u.tolist()) | ||||
|  | ||||
|         # Test all array.at ops | ||||
|         a = mx.random.uniform(shape=(10, 5, 2)) | ||||
|         idx_x = mx.array([0, 4]) | ||||
|         update = mx.ones((2, 5)) | ||||
|         a[idx_x, :, 0] = 0 | ||||
|         a = a.at[idx_x, :, 0].add(update) | ||||
|         self.assertEqualArray(a[idx_x, :, 0], update) | ||||
|         a = a.at[idx_x, :, 0].subtract(update) | ||||
|         self.assertEqualArray(a[idx_x, :, 0], mx.zeros_like(update)) | ||||
|         a = a.at[idx_x, :, 0].add(2 * update) | ||||
|         self.assertEqualArray(a[idx_x, :, 0], 2 * update) | ||||
|         a = a.at[idx_x, :, 0].multiply(2 * update) | ||||
|         self.assertEqualArray(a[idx_x, :, 0], 4 * update) | ||||
|         a = a.at[idx_x, :, 0].divide(3 * update) | ||||
|         self.assertEqualArray(a[idx_x, :, 0], (4 / 3) * update) | ||||
|         a[idx_x, :, 0] = 5 | ||||
|         update = mx.arange(10).reshape(2, 5) | ||||
|         a = a.at[idx_x, :, 0].maximum(update) | ||||
|         self.assertEqualArray(a[idx_x, :, 0], mx.maximum(a[idx_x, :, 0], update)) | ||||
|         a[idx_x, :, 0] = 5 | ||||
|         a = a.at[idx_x, :, 0].minimum(update) | ||||
|         self.assertEqualArray(a[idx_x, :, 0], mx.minimum(a[idx_x, :, 0], update)) | ||||
|  | ||||
|     def test_slice_negative_step(self): | ||||
|         a_np = np.arange(20) | ||||
|         a_mx = mx.array(a_np) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user