Scatter vjp (#394)

* Add a first scatter vjp
* Implement the scatter_add vjp
* Add array.at to implement user friendly scatters
This commit is contained in:
Angelos Katharopoulos 2024-01-09 13:36:51 -08:00 committed by GitHub
parent e9ca65c939
commit 961435a243
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 360 additions and 33 deletions

View File

@ -2122,6 +2122,78 @@ bool Scatter::is_equivalent(const Primitive& other) const {
return reduce_type_ == s_other.reduce_type_ && axes_ == s_other.axes_;
}
std::vector<array> Scatter::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
switch (reduce_type_) {
case Scatter::None:
case Scatter::Sum:
break;
default:
throw std::runtime_error(
"[scatter] VJP implemented only for scatter and scatter_add");
}
const array& values = primals[0];
const array& updates = primals.back();
const std::vector<array> indices(primals.begin() + 1, primals.end() - 1);
std::vector<array> vjps;
for (auto num : argnums) {
// Gradient wrt to the input array
if (num == 0) {
switch (reduce_type_) {
case Scatter::None:
// Scatter 0s to the locations that were updated with the updates
vjps.push_back(scatter(
cotangents[0],
indices,
zeros_like(updates, stream()),
axes_,
stream()));
break;
case Scatter::Sum:
// The input array values are kept so they all get gradients
vjps.push_back(cotangents[0]);
break;
default:
// Should never reach here
throw std::invalid_argument("");
}
} else if (num == primals.size() - 1) {
switch (reduce_type_) {
case Scatter::None:
case Scatter::Sum: {
// Gather the values from the cotangent
auto slice_sizes = cotangents[0].shape();
for (auto ax : axes_) {
slice_sizes[ax] = 1;
}
vjps.push_back(
gather(cotangents[0], indices, axes_, slice_sizes, stream()));
break;
}
default: {
// Should never reach here
throw std::invalid_argument("");
}
}
} else {
throw std::invalid_argument(
"[scatter] Cannot calculate VJP with respect to indices.");
}
}
return vjps;
}
// Jacobian-vector product of Scatter: not supported yet.
// Parameter names are commented out since none of them are used.
std::vector<array> Scatter::jvp(
    const std::vector<array>& /* primals */,
    const std::vector<array>& /* tangents */,
    const std::vector<int>& /* argnums */) {
  throw std::runtime_error("[scatter] JVP not yet implemented");
}
std::vector<array> Sigmoid::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@ -1266,7 +1266,26 @@ class Scatter : public UnaryPrimitive {
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_PRINT(Scatter)
DEFINE_GRADS();
// Print the primitive name followed by the reduction kind.
// The None reduction prints no suffix (plain "Scatter").
void print(std::ostream& os) override {
  const char* suffix = "";
  switch (reduce_type_) {
    case Sum:
      suffix = " Sum";
      break;
    case Prod:
      suffix = " Prod";
      break;
    case Min:
      suffix = " Min";
      break;
    case Max:
      suffix = " Max";
      break;
    case None:
      break;
  }
  os << "Scatter" << suffix;
}
bool is_equivalent(const Primitive& other) const override;
private:

View File

@ -462,6 +462,37 @@ array create_array(array_init_type v, std::optional<Dtype> t) {
}
}
// Helper behind the Python `array.at` property. It captures the array and
// the indexing object, and each update method forwards to the matching
// scatter-based item operation, returning the updated array.
class ArrayAt {
 public:
  ArrayAt(array x) : x_(std::move(x)) {}

  // Record the indices to update at; returns *this so that
  // `at[idx].add(v)` chains naturally from __getitem__.
  ArrayAt& set_indices(py::object indices) {
    indices_ = std::move(indices);
    return *this;
  }

  array add(const ScalarOrArray& v) {
    return mlx_add_item(x_, indices_, v);
  }
  array subtract(const ScalarOrArray& v) {
    return mlx_subtract_item(x_, indices_, v);
  }
  array multiply(const ScalarOrArray& v) {
    return mlx_multiply_item(x_, indices_, v);
  }
  array divide(const ScalarOrArray& v) {
    return mlx_divide_item(x_, indices_, v);
  }
  array maximum(const ScalarOrArray& v) {
    return mlx_maximum_item(x_, indices_, v);
  }
  array minimum(const ScalarOrArray& v) {
    return mlx_minimum_item(x_, indices_, v);
  }

 private:
  array x_;           // the array being updated
  py::object indices_; // indexing object set via __getitem__
};
void init_array(py::module_& m) {
// Types
py::class_<Dtype>(
@ -501,6 +532,26 @@ void init_array(py::module_& m) {
m.attr("bfloat16") = py::cast(bfloat16);
m.attr("complex64") = py::cast(complex64);
py::class_<ArrayAt>(
m,
"_ArrayAt",
R"pbdoc(
A helper object to apply updates at specific indices.
)pbdoc")
.def(
py::init([](const array& x) { return ArrayAt(x); }),
"x"_a,
R"pbdoc(
__init__(self, x: array)
)pbdoc")
.def("__getitem__", &ArrayAt::set_indices, "indices"_a)
.def("add", &ArrayAt::add, "value"_a)
.def("subtract", &ArrayAt::subtract, "value"_a)
.def("multiply", &ArrayAt::multiply, "value"_a)
.def("divide", &ArrayAt::divide, "value"_a)
.def("maximum", &ArrayAt::maximum, "value"_a)
.def("minimum", &ArrayAt::minimum, "value"_a);
auto array_class = py::class_<array>(
m,
"array",
@ -610,6 +661,38 @@ void init_array(py::module_& m) {
)pbdoc")
.def("__getitem__", mlx_get_item)
.def("__setitem__", mlx_set_item)
.def_property_readonly(
"at",
[](const array& a) { return ArrayAt(a); },
R"pbdoc(
Used to apply updates at the given indices.
.. note::
Python in place updates for all array frameworks map to
assignment. For instance ``x[idx] += y`` maps to ``x[idx] =
x[idx] + y``. As a result, assigning to the same index ignores
all but one update. Using ``x.at[idx].add(y)`` will correctly
apply all the updates to all indices.
.. list-table::
:header-rows: 1
* - array.at syntax
- In-place syntax
* - ``x = x.at[idx].add(y)``
- ``x[idx] += y``
* - ``x = x.at[idx].subtract(y)``
- ``x[idx] -= y``
* - ``x = x.at[idx].multiply(y)``
- ``x[idx] *= y``
* - ``x = x.at[idx].divide(y)``
- ``x[idx] /= y``
* - ``x = x.at[idx].maximum(y)``
- ``x[idx] = mx.maximum(x[idx], y)``
* - ``x = x.at[idx].minimum(y)``
- ``x[idx] = mx.minimum(x[idx], y)``
)pbdoc")
.def(
"__len__",
[](const array& a) {

View File

@ -392,7 +392,7 @@ array mlx_get_item(const array& src, const py::object& obj) {
throw std::invalid_argument("Cannot index mlx array using the given type.");
}
array mlx_set_item_int(
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
const array& src,
const py::int_& idx,
const array& update) {
@ -410,14 +410,14 @@ array mlx_set_item_int(
std::vector<int>(update.shape().begin() + s, update.shape().end());
auto shape = src.shape();
shape[0] = 1;
return scatter(
src,
get_int_index(idx, src.shape(0)),
return {
{get_int_index(idx, src.shape(0))},
broadcast_to(reshape(update, up_shape), shape),
0);
{0}};
}
array mlx_set_item_array(
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
const array& src,
const array& indices,
const array& update) {
@ -441,10 +441,10 @@ array mlx_set_item_array(
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
up = reshape(up, up_shape);
return scatter(src, indices, up, 0);
return {{indices}, up, {0}};
}
array mlx_set_item_slice(
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
const array& src,
const py::slice& in_slice,
const array& update) {
@ -462,7 +462,7 @@ array mlx_set_item_slice(
;
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
return broadcast_to(reshape(update, up_shape), src.shape());
return {{}, broadcast_to(reshape(update, up_shape), src.shape()), {}};
}
int start = 0;
@ -472,10 +472,11 @@ array mlx_set_item_slice(
// Check and update slice params
get_slice_params(start, end, stride, in_slice, end);
return mlx_set_item_array(src, arange(start, end, stride, uint32), update);
return mlx_scatter_args_array(
src, arange(start, end, stride, uint32), update);
}
array mlx_set_item_nd(
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
const array& src,
const py::tuple& entries,
const array& update) {
@ -537,7 +538,7 @@ array mlx_set_item_nd(
// If no non-None indices return the broadcasted update
if (non_none_indices == 0) {
return broadcast_to(up, src.shape());
return {{}, broadcast_to(up, src.shape()), {}};
}
unsigned long max_dim = 0;
@ -621,25 +622,108 @@ array mlx_set_item_nd(
std::vector<int> axes(arr_indices.size(), 0);
std::iota(axes.begin(), axes.end(), 0);
return scatter(src, arr_indices, up, axes);
return {arr_indices, up, axes};
}
// Translate a Python indexing object into scatter arguments
// (indices, updates, axes) against `src`. An empty indices vector means
// the whole array is replaced by the (already broadcast) updates.
std::tuple<std::vector<array>, array, std::vector<int>>
mlx_compute_scatter_args(
    const array& src,
    const py::object& obj,
    const ScalarOrArray& v) {
  auto vals = to_array(v, src.dtype());

  // Dispatch on the index type; each branch returns directly.
  if (py::isinstance<py::slice>(obj)) {
    return mlx_scatter_args_slice(src, obj, vals);
  }
  if (py::isinstance<array>(obj)) {
    return mlx_scatter_args_array(src, py::cast<array>(obj), vals);
  }
  if (py::isinstance<py::int_>(obj)) {
    return mlx_scatter_args_int(src, obj, vals);
  }
  if (py::isinstance<py::tuple>(obj)) {
    return mlx_scatter_args_nd(src, obj, vals);
  }
  if (obj.is_none()) {
    // Indexing with None targets the full array.
    return {{}, broadcast_to(vals, src.shape()), {}};
  }

  throw std::invalid_argument("Cannot index mlx array using the given type.");
}
// Implements `src[obj] = v` by scattering the updates into `src` and
// overwriting it in place.
//
// NOTE(review): as shown, this function contains both the pre-refactor
// body (the local `impl` lambda) and the refactored body that uses
// mlx_compute_scatter_args — diff residue that makes it incoherent. The
// clean post-refactor implementation is kept below.
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
  if (indices.size() > 0) {
    // Scatter the updates at the computed indices and replace src.
    auto out = scatter(src, indices, updates, axes);
    src.overwrite_descriptor(out);
  } else {
    // No indices: the updates already cover the full array.
    src.overwrite_descriptor(updates);
  }
}
// `src[obj] += v` with accumulation at repeated indices (scatter_add).
array mlx_add_item(
    const array& src,
    const py::object& obj,
    const ScalarOrArray& v) {
  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
  if (indices.empty()) {
    // Full-array update: a plain elementwise add.
    return src + updates;
  }
  return scatter_add(src, indices, updates, axes);
}
// `src[obj] -= v` implemented as a scatter_add of the negated updates.
array mlx_subtract_item(
    const array& src,
    const py::object& obj,
    const ScalarOrArray& v) {
  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
  if (indices.empty()) {
    // Full-array update: a plain elementwise subtract.
    return src - updates;
  }
  return scatter_add(src, indices, -updates, axes);
}
// `src[obj] *= v` with accumulation at repeated indices (scatter_prod).
array mlx_multiply_item(
    const array& src,
    const py::object& obj,
    const ScalarOrArray& v) {
  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
  if (indices.empty()) {
    // Full-array update: a plain elementwise multiply.
    return src * updates;
  }
  return scatter_prod(src, indices, updates, axes);
}
// `src[obj] /= v` implemented as a scatter_prod of the reciprocals.
array mlx_divide_item(
    const array& src,
    const py::object& obj,
    const ScalarOrArray& v) {
  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
  if (indices.empty()) {
    // Full-array update: a plain elementwise divide.
    return src / updates;
  }
  return scatter_prod(src, indices, reciprocal(updates), axes);
}
// `src[obj] = maximum(src[obj], v)` via scatter_max.
array mlx_maximum_item(
    const array& src,
    const py::object& obj,
    const ScalarOrArray& v) {
  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
  if (indices.empty()) {
    // Full-array update: a plain elementwise maximum.
    return maximum(src, updates);
  }
  return scatter_max(src, indices, updates, axes);
}
// `src[obj] = minimum(src[obj], v)` via scatter_min.
array mlx_minimum_item(
    const array& src,
    const py::object& obj,
    const ScalarOrArray& v) {
  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
  if (indices.empty()) {
    // Full-array update: a plain elementwise minimum.
    return minimum(src, updates);
  }
  return scatter_min(src, indices, updates, axes);
}

View File

@ -12,3 +12,27 @@ using namespace mlx::core;
array mlx_get_item(const array& src, const py::object& obj);
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v);
array mlx_add_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v);
array mlx_subtract_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v);
array mlx_multiply_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v);
array mlx_divide_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v);
array mlx_maximum_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v);
array mlx_minimum_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v);

View File

@ -3310,7 +3310,6 @@ void init_ops(py::module_& m) {
Returns:
result (array): The tensor dot product.
)pbdoc");
m.def(
"inner",
&inner,
@ -3331,7 +3330,6 @@ void init_ops(py::module_& m) {
Returns:
result (array): The inner product.
)pbdoc");
m.def(
"outer",
&outer,

View File

@ -984,6 +984,53 @@ class TestArray(mlx_tests.MLXTestCase):
a[2:-2, 2:-2] = 4
self.assertEqual(a[2, 2].item(), 4)
def test_array_at(self):
    # Scalar array updated through a None (full-array) index.
    a = mx.array(1)
    a = a.at[None].add(1)
    self.assertEqual(a.item(), 2)

    # Single integer index.
    a = mx.array([0, 1, 2])
    a = a.at[1].add(2)
    self.assertEqual(a.tolist(), [0, 3, 2])

    # Repeated indices accumulate — unlike `a[idx] += 1`, which would
    # apply only one of the four updates.
    a = a.at[mx.array([0, 0, 0, 0])].add(1)
    self.assertEqual(a.tolist(), [4, 3, 2])

    # Row update broadcast against a 2D array.
    a = mx.zeros((10, 10))
    a = a.at[0].add(mx.arange(10))
    self.assertEqual(a[0].tolist(), list(range(10)))

    # Fancy (array, array) indexing: each (x, y) pair receives its own
    # update, so the total added equals the sum of the updates.
    a = mx.zeros((10, 10))
    index_x = mx.array([0, 2, 3, 7])
    index_y = mx.array([3, 3, 1, 2])
    u = mx.random.uniform(shape=(4,))
    a = a.at[index_x, index_y].add(u)
    self.assertEqual(a.sum().item(), u.sum().item())
    self.assertEqual(a[index_x, index_y].tolist(), u.tolist())

    # Test all array.at ops
    a = mx.random.uniform(shape=(10, 5, 2))
    idx_x = mx.array([0, 4])
    update = mx.ones((2, 5))
    # Start from a known zero slice, then round-trip add/subtract.
    a[idx_x, :, 0] = 0
    a = a.at[idx_x, :, 0].add(update)
    self.assertEqualArray(a[idx_x, :, 0], update)
    a = a.at[idx_x, :, 0].subtract(update)
    self.assertEqualArray(a[idx_x, :, 0], mx.zeros_like(update))
    # multiply and divide on a non-zero slice.
    a = a.at[idx_x, :, 0].add(2 * update)
    self.assertEqualArray(a[idx_x, :, 0], 2 * update)
    a = a.at[idx_x, :, 0].multiply(2 * update)
    self.assertEqualArray(a[idx_x, :, 0], 4 * update)
    a = a.at[idx_x, :, 0].divide(3 * update)
    self.assertEqualArray(a[idx_x, :, 0], (4 / 3) * update)
    # maximum/minimum against a constant slice of 5s.
    a[idx_x, :, 0] = 5
    update = mx.arange(10).reshape(2, 5)
    a = a.at[idx_x, :, 0].maximum(update)
    self.assertEqualArray(a[idx_x, :, 0], mx.maximum(a[idx_x, :, 0], update))
    a[idx_x, :, 0] = 5
    a = a.at[idx_x, :, 0].minimum(update)
    self.assertEqualArray(a[idx_x, :, 0], mx.minimum(a[idx_x, :, 0], update))
def test_slice_negative_step(self):
a_np = np.arange(20)
a_mx = mx.array(a_np)