Scatter vjp (#394)

* Add a first scatter vjp
* Implement the scatter_add vjp
* Add array.at to implement user friendly scatters
This commit is contained in:
Angelos Katharopoulos
2024-01-09 13:36:51 -08:00
committed by GitHub
parent e9ca65c939
commit 961435a243
7 changed files with 360 additions and 33 deletions

View File

@@ -392,7 +392,7 @@ array mlx_get_item(const array& src, const py::object& obj) {
throw std::invalid_argument("Cannot index mlx array using the given type.");
}
array mlx_set_item_int(
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
const array& src,
const py::int_& idx,
const array& update) {
@@ -410,14 +410,14 @@ array mlx_set_item_int(
std::vector<int>(update.shape().begin() + s, update.shape().end());
auto shape = src.shape();
shape[0] = 1;
return scatter(
src,
get_int_index(idx, src.shape(0)),
return {
{get_int_index(idx, src.shape(0))},
broadcast_to(reshape(update, up_shape), shape),
0);
{0}};
}
array mlx_set_item_array(
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
const array& src,
const array& indices,
const array& update) {
@@ -441,10 +441,10 @@ array mlx_set_item_array(
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
up = reshape(up, up_shape);
return scatter(src, indices, up, 0);
return {{indices}, up, {0}};
}
array mlx_set_item_slice(
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
const array& src,
const py::slice& in_slice,
const array& update) {
@@ -462,7 +462,7 @@ array mlx_set_item_slice(
;
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
return broadcast_to(reshape(update, up_shape), src.shape());
return {{}, broadcast_to(reshape(update, up_shape), src.shape()), {}};
}
int start = 0;
@@ -472,10 +472,11 @@ array mlx_set_item_slice(
// Check and update slice params
get_slice_params(start, end, stride, in_slice, end);
return mlx_set_item_array(src, arange(start, end, stride, uint32), update);
return mlx_scatter_args_array(
src, arange(start, end, stride, uint32), update);
}
array mlx_set_item_nd(
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
const array& src,
const py::tuple& entries,
const array& update) {
@@ -537,7 +538,7 @@ array mlx_set_item_nd(
// If no non-None indices return the broadcasted update
if (non_none_indices == 0) {
return broadcast_to(up, src.shape());
return {{}, broadcast_to(up, src.shape()), {}};
}
unsigned long max_dim = 0;
@@ -621,25 +622,108 @@ array mlx_set_item_nd(
std::vector<int> axes(arr_indices.size(), 0);
std::iota(axes.begin(), axes.end(), 0);
return scatter(src, arr_indices, up, axes);
return {arr_indices, up, axes};
}
std::tuple<std::vector<array>, array, std::vector<int>>
mlx_compute_scatter_args(
const array& src,
const py::object& obj,
const ScalarOrArray& v) {
auto vals = to_array(v, src.dtype());
if (py::isinstance<py::slice>(obj)) {
return mlx_scatter_args_slice(src, obj, vals);
} else if (py::isinstance<array>(obj)) {
return mlx_scatter_args_array(src, py::cast<array>(obj), vals);
} else if (py::isinstance<py::int_>(obj)) {
return mlx_scatter_args_int(src, obj, vals);
} else if (py::isinstance<py::tuple>(obj)) {
return mlx_scatter_args_nd(src, obj, vals);
} else if (obj.is_none()) {
return {{}, broadcast_to(vals, src.shape()), {}};
}
throw std::invalid_argument("Cannot index mlx array using the given type.");
}
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
auto vals = to_array(v, src.dtype());
auto impl = [&src, &obj, &vals]() {
if (py::isinstance<py::slice>(obj)) {
return mlx_set_item_slice(src, obj, vals);
} else if (py::isinstance<array>(obj)) {
return mlx_set_item_array(src, py::cast<array>(obj), vals);
} else if (py::isinstance<py::int_>(obj)) {
return mlx_set_item_int(src, obj, vals);
} else if (py::isinstance<py::tuple>(obj)) {
return mlx_set_item_nd(src, obj, vals);
} else if (obj.is_none()) {
return broadcast_to(vals, src.shape());
}
throw std::invalid_argument("Cannot index mlx array using the given type.");
};
auto out = impl();
src.overwrite_descriptor(out);
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
auto out = scatter(src, indices, updates, axes);
src.overwrite_descriptor(out);
} else {
src.overwrite_descriptor(updates);
}
}
array mlx_add_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
return scatter_add(src, indices, updates, axes);
} else {
return src + updates;
}
}
array mlx_subtract_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
return scatter_add(src, indices, -updates, axes);
} else {
return src - updates;
}
}
array mlx_multiply_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
return scatter_prod(src, indices, updates, axes);
} else {
return src * updates;
}
}
array mlx_divide_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
return scatter_prod(src, indices, reciprocal(updates), axes);
} else {
return src / updates;
}
}
array mlx_maximum_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
return scatter_max(src, indices, updates, axes);
} else {
return maximum(src, updates);
}
}
array mlx_minimum_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
return scatter_min(src, indices, updates, axes);
} else {
return minimum(src, updates);
}
}