From 961435a243754fa7bc071aae6c84fa49a20fd410 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 9 Jan 2024 13:36:51 -0800
Subject: [PATCH] Scatter vjp (#394)

* Add a first scatter vjp

* Implement the scatter_add vjp

* Add array.at to implement user friendly scatters
---
 mlx/primitives.cpp         |  72 +++++++++++++++++++
 mlx/primitives.h           |  21 +++++-
 python/src/array.cpp       |  83 +++++++++++++++++++++
 python/src/indexing.cpp    | 144 +++++++++++++++++++++++++++++--------
 python/src/indexing.h      |  24 +++++++
 python/src/ops.cpp         |   2 -
 python/tests/test_array.py |  47 ++++++++++++
 7 files changed, 360 insertions(+), 33 deletions(-)

diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index bc2a2a036..0077b11fa 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -2122,6 +2122,78 @@ bool Scatter::is_equivalent(const Primitive& other) const {
   return reduce_type_ == s_other.reduce_type_ && axes_ == s_other.axes_;
 }
 
+std::vector<array> Scatter::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums) {
+  switch (reduce_type_) {
+    case Scatter::None:
+    case Scatter::Sum:
+      break;
+    default:
+      throw std::runtime_error(
+          "[scatter] VJP implemented only for scatter and scatter_add");
+  }
+
+  const array& values = primals[0];
+  const array& updates = primals.back();
+  const std::vector<array> indices(primals.begin() + 1, primals.end() - 1);
+
+  std::vector<array> vjps;
+  for (auto num : argnums) {
+    // Gradient wrt the input array
+    if (num == 0) {
+      switch (reduce_type_) {
+        case Scatter::None:
+          // Scatter 0s to the locations that were updated with the updates
+          vjps.push_back(scatter(
+              cotangents[0],
+              indices,
+              zeros_like(updates, stream()),
+              axes_,
+              stream()));
+          break;
+        case Scatter::Sum:
+          // The input array values are kept so they all get gradients
+          vjps.push_back(cotangents[0]);
+          break;
+        default:
+          // Should never reach here
+          throw std::invalid_argument("");
+      }
+    } else if (num == primals.size() - 1) {
+      switch (reduce_type_) {
+        case Scatter::None:
+        case Scatter::Sum: {
+          // Gather the values from the cotangent
+          auto slice_sizes = cotangents[0].shape();
+          for (auto ax : axes_) {
+            slice_sizes[ax] = 1;
+          }
+          vjps.push_back(
+              gather(cotangents[0], indices, axes_, slice_sizes, stream()));
+          break;
+        }
+        default: {
+          // Should never reach here
+          throw std::invalid_argument("");
+        }
+      }
+    } else {
+      throw std::invalid_argument(
+          "[scatter] Cannot calculate VJP with respect to indices.");
+    }
+  }
+  return vjps;
+}
+
+std::vector<array> Scatter::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  throw std::runtime_error("[scatter] JVP not yet implemented");
+}
+
 std::vector<array> Sigmoid::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 786da59d1..85ffbf25e 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1266,7 +1266,26 @@ class Scatter : public UnaryPrimitive {
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
 
-  DEFINE_PRINT(Scatter)
+  DEFINE_GRADS();
+  void print(std::ostream& os) override {
+    os << "Scatter";
+    switch (reduce_type_) {
+      case Sum:
+        os << " Sum";
+        break;
+      case Prod:
+        os << " Prod";
+        break;
+      case Min:
+        os << " Min";
+        break;
+      case Max:
+        os << " Max";
+        break;
+      case None:
+        break;
+    }
+  }
   bool is_equivalent(const Primitive& other) const override;
 
  private:
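The two VJP rules above can be sanity-checked from Python once the `array.at` bindings added below are in place: for `scatter_add`, the cotangent passes through unchanged to the input array, and it is gathered at the scattered locations for the updates. A minimal sketch, not part of the patch, assuming `mlx.core` is importable as `mx` and that `mx.grad` accepts a tuple for `argnums`:

    import mlx.core as mx

    def f(x, u):
        idx = mx.array([0, 0, 2])
        # x.at[idx].add(u) routes through a scatter_add along axis 0
        return (x.at[idx].add(u) * mx.array([1.0, 2.0, 3.0, 4.0])).sum()

    gx, gu = mx.grad(f, argnums=(0, 1))(mx.zeros((4,)), mx.zeros((3,)))
    print(gx.tolist())  # [1.0, 2.0, 3.0, 4.0] -> cotangent passed through to the input
    print(gu.tolist())  # [1.0, 1.0, 3.0]      -> cotangent gathered at idx, duplicates included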
diff --git a/python/src/array.cpp b/python/src/array.cpp
index 61b8a53f2..bf2b09a3c 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -462,6 +462,37 @@ array create_array(array_init_type v, std::optional<Dtype> t) {
   }
 }
 
+class ArrayAt {
+ public:
+  ArrayAt(array x) : x_(std::move(x)) {}
+  ArrayAt& set_indices(py::object indices) {
+    indices_ = indices;
+    return *this;
+  }
+  array add(const ScalarOrArray& v) {
+    return mlx_add_item(x_, indices_, v);
+  }
+  array subtract(const ScalarOrArray& v) {
+    return mlx_subtract_item(x_, indices_, v);
+  }
+  array multiply(const ScalarOrArray& v) {
+    return mlx_multiply_item(x_, indices_, v);
+  }
+  array divide(const ScalarOrArray& v) {
+    return mlx_divide_item(x_, indices_, v);
+  }
+  array maximum(const ScalarOrArray& v) {
+    return mlx_maximum_item(x_, indices_, v);
+  }
+  array minimum(const ScalarOrArray& v) {
+    return mlx_minimum_item(x_, indices_, v);
+  }
+
+ private:
+  array x_;
+  py::object indices_;
+};
+
 void init_array(py::module_& m) {
   // Types
   py::class_<Dtype>(
       m,
@@ -501,6 +532,26 @@ void init_array(py::module_& m) {
   m.attr("bfloat16") = py::cast(bfloat16);
   m.attr("complex64") = py::cast(complex64);
 
+  py::class_<ArrayAt>(
+      m,
+      "_ArrayAt",
+      R"pbdoc(
+      A helper object to apply updates at specific indices.
+      )pbdoc")
+      .def(
+          py::init([](const array& x) { return ArrayAt(x); }),
+          "x"_a,
+          R"pbdoc(
+            __init__(self, x: array)
+          )pbdoc")
+      .def("__getitem__", &ArrayAt::set_indices, "indices"_a)
+      .def("add", &ArrayAt::add, "value"_a)
+      .def("subtract", &ArrayAt::subtract, "value"_a)
+      .def("multiply", &ArrayAt::multiply, "value"_a)
+      .def("divide", &ArrayAt::divide, "value"_a)
+      .def("maximum", &ArrayAt::maximum, "value"_a)
+      .def("minimum", &ArrayAt::minimum, "value"_a);
+
   auto array_class = py::class_<array>(
       m,
       "array",
@@ -610,6 +661,38 @@ void init_array(py::module_& m) {
       )pbdoc")
      .def("__getitem__", mlx_get_item)
      .def("__setitem__", mlx_set_item)
+      .def_property_readonly(
+          "at",
+          [](const array& a) { return ArrayAt(a); },
+          R"pbdoc(
+            Used to apply updates at the given indices.
+
+            .. note::
+
+               Python in-place updates for all array frameworks map to
+               assignment. For instance, ``x[idx] += y`` maps to
+               ``x[idx] = x[idx] + y``. As a result, assigning to the same
+               index more than once keeps only one of the updates. Using
+               ``x.at[idx].add(y)`` will correctly apply every update to all
+               indices.
+
+            .. list-table::
+               :header-rows: 1
+
+               * - array.at syntax
+                 - In-place syntax
+               * - ``x = x.at[idx].add(y)``
+                 - ``x[idx] += y``
+               * - ``x = x.at[idx].subtract(y)``
+                 - ``x[idx] -= y``
+               * - ``x = x.at[idx].multiply(y)``
+                 - ``x[idx] *= y``
+               * - ``x = x.at[idx].divide(y)``
+                 - ``x[idx] /= y``
+               * - ``x = x.at[idx].maximum(y)``
+                 - ``x[idx] = mx.maximum(x[idx], y)``
+               * - ``x = x.at[idx].minimum(y)``
+                 - ``x[idx] = mx.minimum(x[idx], y)``
+          )pbdoc")
       .def(
           "__len__",
           [](const array& a) {
diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp
index 1454c5180..74fb6b695 100644
--- a/python/src/indexing.cpp
+++ b/python/src/indexing.cpp
@@ -392,7 +392,7 @@ array mlx_get_item(const array& src, const py::object& obj) {
   throw std::invalid_argument("Cannot index mlx array using the given type.");
 }
 
-array mlx_set_item_int(
+std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
     const array& src,
     const py::int_& idx,
     const array& update) {
@@ -410,14 +410,14 @@
       std::vector<int>(update.shape().begin() + s, update.shape().end());
   auto shape = src.shape();
   shape[0] = 1;
-  return scatter(
-      src,
-      get_int_index(idx, src.shape(0)),
+
+  return {
+      {get_int_index(idx, src.shape(0))},
       broadcast_to(reshape(update, up_shape), shape),
-      0);
+      {0}};
 }
 
-array mlx_set_item_array(
+std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
     const array& src,
     const array& indices,
     const array& update) {
@@ -441,10 +441,10 @@
   up_shape.insert(up_shape.begin() + indices.ndim(), 1);
   up = reshape(up, up_shape);
 
-  return scatter(src, indices, up, 0);
+  return {{indices}, up, {0}};
 }
 
-array mlx_set_item_slice(
+std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
     const array& src,
     const py::slice& in_slice,
     const array& update) {
@@ -462,7 +462,7 @@
       ;
     auto up_shape =
         std::vector<int>(update.shape().begin() + s, update.shape().end());
-    return broadcast_to(reshape(update, up_shape), src.shape());
+    return {{}, broadcast_to(reshape(update, up_shape), src.shape()), {}};
   }
 
   int start = 0;
   int end = src.shape(0);
   int stride = 1;
 
   // Check and update slice params
   get_slice_params(start, end, stride, in_slice, end);
-  return mlx_set_item_array(src, arange(start, end, stride, uint32), update);
+  return mlx_scatter_args_array(
+      src, arange(start, end, stride, uint32), update);
 }
 
-array mlx_set_item_nd(
+std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
     const array& src,
     const py::tuple& entries,
     const array& update) {
@@ -537,7 +538,7 @@
 
   // If no non-None indices return the broadcasted update
   if (non_none_indices == 0) {
-    return broadcast_to(up, src.shape());
+    return {{}, broadcast_to(up, src.shape()), {}};
   }
 
   unsigned long max_dim = 0;
@@ -621,25 +622,108 @@
   std::vector<int> axes(arr_indices.size(), 0);
   std::iota(axes.begin(), axes.end(), 0);
-  return scatter(src, arr_indices, up, axes);
+
+  return {arr_indices, up, axes};
+}
+
+std::tuple<std::vector<array>, array, std::vector<int>>
+mlx_compute_scatter_args(
+    const array& src,
+    const py::object& obj,
+    const ScalarOrArray& v) {
+  auto vals = to_array(v, src.dtype());
+  if (py::isinstance<py::slice>(obj)) {
+    return mlx_scatter_args_slice(src, obj, vals);
+  } else if (py::isinstance<array>(obj)) {
+    return mlx_scatter_args_array(src, py::cast<array>(obj), vals);
+  } else if (py::isinstance<py::int_>(obj)) {
+    return mlx_scatter_args_int(src, obj, vals);
+  } else if (py::isinstance<py::tuple>(obj)) {
+    return mlx_scatter_args_nd(src, obj, vals);
+  } else if (obj.is_none()) {
+    return {{}, broadcast_to(vals, src.shape()), {}};
+  }
+
   throw std::invalid_argument("Cannot index mlx array using the given type.");
 }
 
 void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
-  auto vals = to_array(v, src.dtype());
-  auto impl = [&src, &obj, &vals]() {
-    if (py::isinstance<py::slice>(obj)) {
-      return mlx_set_item_slice(src, obj, vals);
-    } else if (py::isinstance<array>(obj)) {
-      return mlx_set_item_array(src, py::cast<array>(obj), vals);
-    } else if (py::isinstance<py::int_>(obj)) {
-      return mlx_set_item_int(src, obj, vals);
-    } else if (py::isinstance<py::tuple>(obj)) {
-      return mlx_set_item_nd(src, obj, vals);
-    } else if (obj.is_none()) {
-      return broadcast_to(vals, src.shape());
-    }
-    throw std::invalid_argument("Cannot index mlx array using the given type.");
-  };
-  auto out = impl();
-  src.overwrite_descriptor(out);
+  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
+  if (indices.size() > 0) {
+    auto out = scatter(src, indices, updates, axes);
+    src.overwrite_descriptor(out);
+  } else {
+    src.overwrite_descriptor(updates);
+  }
+}
+
+array mlx_add_item(
+    const array& src,
+    const py::object& obj,
+    const ScalarOrArray& v) {
+  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
+  if (indices.size() > 0) {
+    return scatter_add(src, indices, updates, axes);
+  } else {
+    return src + updates;
+  }
+}
+
+array mlx_subtract_item(
+    const array& src,
+    const py::object& obj,
+    const ScalarOrArray& v) {
+  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
+  if (indices.size() > 0) {
+    return scatter_add(src, indices, -updates, axes);
+  } else {
+    return src - updates;
+  }
+}
+
+array mlx_multiply_item(
+    const array& src,
+    const py::object& obj,
+    const ScalarOrArray& v) {
+  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
+  if (indices.size() > 0) {
+    return scatter_prod(src, indices, updates, axes);
+  } else {
+    return src * updates;
+  }
+}
+
+array mlx_divide_item(
+    const array& src,
+    const py::object& obj,
+    const ScalarOrArray& v) {
+  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
+  if (indices.size() > 0) {
+    return scatter_prod(src, indices, reciprocal(updates), axes);
+  } else {
+    return src / updates;
+  }
+}
+
+array mlx_maximum_item(
+    const array& src,
+    const py::object& obj,
+    const ScalarOrArray& v) {
+  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
+  if (indices.size() > 0) {
+    return scatter_max(src, indices, updates, axes);
+  } else {
+    return maximum(src, updates);
+  }
+}
+
+array mlx_minimum_item(
+    const array& src,
+    const py::object& obj,
+    const ScalarOrArray& v) {
+  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
+  if (indices.size() > 0) {
+    return scatter_min(src, indices, updates, axes);
+  } else {
+    return minimum(src, updates);
+  }
+}
diff --git a/python/src/indexing.h b/python/src/indexing.h
index 0eaaac93d..0ddea859e 100644
--- a/python/src/indexing.h
+++ b/python/src/indexing.h
@@ -12,3 +12,27 @@ using namespace mlx::core;
 
 array mlx_get_item(const array& src, const py::object& obj);
 void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v);
+array mlx_add_item(
+    const array& src,
+    const py::object& obj,
+    const ScalarOrArray& v);
+array mlx_subtract_item(
+    const array& src,
+    const py::object& obj,
+    const ScalarOrArray& v);
+array mlx_multiply_item(
+    const array& src,
+    const py::object& obj,
+    const ScalarOrArray& v);
+array mlx_divide_item(
+    const array& src,
+    const py::object& obj,
+    const ScalarOrArray& v);
+array mlx_maximum_item(
+    const array& src,
+    const py::object& obj,
+    const ScalarOrArray& v);
+array mlx_minimum_item(
+    const array& src,
+    const py::object& obj,
+    const ScalarOrArray& v);
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 540c6e99a..90be116c4 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -3310,7 +3310,6 @@ void init_ops(py::module_& m) {
       Returns:
           result (array): The tensor dot product.
       )pbdoc");
-
   m.def(
       "inner",
       &inner,
@@ -3331,7 +3330,6 @@ void init_ops(py::module_& m) {
       Returns:
          result (array): The inner product.
       )pbdoc");
-
   m.def(
       "outer",
       &outer,
diff --git a/python/tests/test_array.py b/python/tests/test_array.py
index 44775a11a..a227c8eb1 100644
--- a/python/tests/test_array.py
+++ b/python/tests/test_array.py
@@ -984,6 +984,53 @@ class TestArray(mlx_tests.MLXTestCase):
         a[2:-2, 2:-2] = 4
         self.assertEqual(a[2, 2].item(), 4)
 
+    def test_array_at(self):
+        a = mx.array(1)
+        a = a.at[None].add(1)
+        self.assertEqual(a.item(), 2)
+
+        a = mx.array([0, 1, 2])
+        a = a.at[1].add(2)
+        self.assertEqual(a.tolist(), [0, 3, 2])
+
+        a = a.at[mx.array([0, 0, 0, 0])].add(1)
+        self.assertEqual(a.tolist(), [4, 3, 2])
+
+        a = mx.zeros((10, 10))
+        a = a.at[0].add(mx.arange(10))
+        self.assertEqual(a[0].tolist(), list(range(10)))
+
+        a = mx.zeros((10, 10))
+        index_x = mx.array([0, 2, 3, 7])
+        index_y = mx.array([3, 3, 1, 2])
+        u = mx.random.uniform(shape=(4,))
+        a = a.at[index_x, index_y].add(u)
+        self.assertEqual(a.sum().item(), u.sum().item())
+        self.assertEqual(a[index_x, index_y].tolist(), u.tolist())
+
+        # Test all array.at ops
+        a = mx.random.uniform(shape=(10, 5, 2))
+        idx_x = mx.array([0, 4])
+        update = mx.ones((2, 5))
+        a[idx_x, :, 0] = 0
+        a = a.at[idx_x, :, 0].add(update)
+        self.assertEqualArray(a[idx_x, :, 0], update)
+        a = a.at[idx_x, :, 0].subtract(update)
+        self.assertEqualArray(a[idx_x, :, 0], mx.zeros_like(update))
+        a = a.at[idx_x, :, 0].add(2 * update)
+        self.assertEqualArray(a[idx_x, :, 0], 2 * update)
+        a = a.at[idx_x, :, 0].multiply(2 * update)
+        self.assertEqualArray(a[idx_x, :, 0], 4 * update)
+        a = a.at[idx_x, :, 0].divide(3 * update)
+        self.assertEqualArray(a[idx_x, :, 0], (4 / 3) * update)
+        a[idx_x, :, 0] = 5
+        update = mx.arange(10).reshape(2, 5)
+        a = a.at[idx_x, :, 0].maximum(update)
+        self.assertEqualArray(a[idx_x, :, 0], mx.maximum(a[idx_x, :, 0], update))
+        a[idx_x, :, 0] = 5
+        a = a.at[idx_x, :, 0].minimum(update)
+        self.assertEqualArray(a[idx_x, :, 0], mx.minimum(a[idx_x, :, 0], update))
+
     def test_slice_negative_step(self):
         a_np = np.arange(20)
         a_mx = mx.array(a_np)
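The behavioral difference called out in the ``at`` docstring above is easiest to see with duplicate indices. A short usage sketch, not part of the patch, assuming `mlx.core` is imported as `mx`:

    import mlx.core as mx

    idx = mx.array([0, 0, 0, 0])

    a = mx.array([0, 1, 2])
    a[idx] += 1            # in-place syntax: becomes a[idx] = a[idx] + 1
    print(a.tolist())      # [1, 1, 2] -- the four updates collapse into one

    b = mx.array([0, 1, 2])
    b = b.at[idx].add(1)   # scatter_add applies every update
    print(b.tolist())      # [4, 1, 2]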