mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
Scatter vjp (#394)
* Add a first scatter vjp * Implement the scatter_add vjp * Add array.at to implement user-friendly scatters
This commit is contained in:
parent
e9ca65c939
commit
961435a243
@ -2122,6 +2122,78 @@ bool Scatter::is_equivalent(const Primitive& other) const {
|
||||
return reduce_type_ == s_other.reduce_type_ && axes_ == s_other.axes_;
|
||||
}
|
||||
|
||||
std::vector<array> Scatter::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
switch (reduce_type_) {
|
||||
case Scatter::None:
|
||||
case Scatter::Sum:
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[scatter] VJP implemented only for scatter and scatter_add");
|
||||
}
|
||||
|
||||
const array& values = primals[0];
|
||||
const array& updates = primals.back();
|
||||
const std::vector<array> indices(primals.begin() + 1, primals.end() - 1);
|
||||
|
||||
std::vector<array> vjps;
|
||||
for (auto num : argnums) {
|
||||
// Gradient wrt to the input array
|
||||
if (num == 0) {
|
||||
switch (reduce_type_) {
|
||||
case Scatter::None:
|
||||
// Scatter 0s to the locations that were updated with the updates
|
||||
vjps.push_back(scatter(
|
||||
cotangents[0],
|
||||
indices,
|
||||
zeros_like(updates, stream()),
|
||||
axes_,
|
||||
stream()));
|
||||
break;
|
||||
case Scatter::Sum:
|
||||
// The input array values are kept so they all get gradients
|
||||
vjps.push_back(cotangents[0]);
|
||||
break;
|
||||
default:
|
||||
// Should never reach here
|
||||
throw std::invalid_argument("");
|
||||
}
|
||||
} else if (num == primals.size() - 1) {
|
||||
switch (reduce_type_) {
|
||||
case Scatter::None:
|
||||
case Scatter::Sum: {
|
||||
// Gather the values from the cotangent
|
||||
auto slice_sizes = cotangents[0].shape();
|
||||
for (auto ax : axes_) {
|
||||
slice_sizes[ax] = 1;
|
||||
}
|
||||
vjps.push_back(
|
||||
gather(cotangents[0], indices, axes_, slice_sizes, stream()));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
// Should never reach here
|
||||
throw std::invalid_argument("");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[scatter] Cannot calculate VJP with respect to indices.");
|
||||
}
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
// Forward-mode differentiation is not available for scatter: this always
// throws std::runtime_error, so none of the arguments are inspected.
std::vector<array> Scatter::jvp(
    const std::vector<array>& /* primals */,
    const std::vector<array>& /* tangents */,
    const std::vector<int>& /* argnums */) {
  throw std::runtime_error("[scatter] JVP not yet implemented");
}
|
||||
|
||||
std::vector<array> Sigmoid::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
@ -1266,7 +1266,26 @@ class Scatter : public UnaryPrimitive {
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_PRINT(Scatter)
|
||||
DEFINE_GRADS();
|
||||
void print(std::ostream& os) override {
|
||||
os << "Scatter";
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
os << " Sum";
|
||||
break;
|
||||
case Prod:
|
||||
os << " Prod";
|
||||
break;
|
||||
case Min:
|
||||
os << " Min";
|
||||
break;
|
||||
case Max:
|
||||
os << " Max";
|
||||
break;
|
||||
case None:
|
||||
break;
|
||||
}
|
||||
}
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
|
@ -462,6 +462,37 @@ array create_array(array_init_type v, std::optional<Dtype> t) {
|
||||
}
|
||||
}
|
||||
|
||||
class ArrayAt {
|
||||
public:
|
||||
ArrayAt(array x) : x_(std::move(x)) {}
|
||||
ArrayAt& set_indices(py::object indices) {
|
||||
indices_ = indices;
|
||||
return *this;
|
||||
}
|
||||
array add(const ScalarOrArray& v) {
|
||||
return mlx_add_item(x_, indices_, v);
|
||||
}
|
||||
array subtract(const ScalarOrArray& v) {
|
||||
return mlx_subtract_item(x_, indices_, v);
|
||||
}
|
||||
array multiply(const ScalarOrArray& v) {
|
||||
return mlx_multiply_item(x_, indices_, v);
|
||||
}
|
||||
array divide(const ScalarOrArray& v) {
|
||||
return mlx_divide_item(x_, indices_, v);
|
||||
}
|
||||
array maximum(const ScalarOrArray& v) {
|
||||
return mlx_maximum_item(x_, indices_, v);
|
||||
}
|
||||
array minimum(const ScalarOrArray& v) {
|
||||
return mlx_minimum_item(x_, indices_, v);
|
||||
}
|
||||
|
||||
private:
|
||||
array x_;
|
||||
py::object indices_;
|
||||
};
|
||||
|
||||
void init_array(py::module_& m) {
|
||||
// Types
|
||||
py::class_<Dtype>(
|
||||
@ -501,6 +532,26 @@ void init_array(py::module_& m) {
|
||||
m.attr("bfloat16") = py::cast(bfloat16);
|
||||
m.attr("complex64") = py::cast(complex64);
|
||||
|
||||
py::class_<ArrayAt>(
|
||||
m,
|
||||
"_ArrayAt",
|
||||
R"pbdoc(
|
||||
A helper object to apply updates at specific indices.
|
||||
)pbdoc")
|
||||
.def(
|
||||
py::init([](const array& x) { return ArrayAt(x); }),
|
||||
"x"_a,
|
||||
R"pbdoc(
|
||||
__init__(self, x: array)
|
||||
)pbdoc")
|
||||
.def("__getitem__", &ArrayAt::set_indices, "indices"_a)
|
||||
.def("add", &ArrayAt::add, "value"_a)
|
||||
.def("subtract", &ArrayAt::subtract, "value"_a)
|
||||
.def("multiply", &ArrayAt::multiply, "value"_a)
|
||||
.def("divide", &ArrayAt::divide, "value"_a)
|
||||
.def("maximum", &ArrayAt::maximum, "value"_a)
|
||||
.def("minimum", &ArrayAt::minimum, "value"_a);
|
||||
|
||||
auto array_class = py::class_<array>(
|
||||
m,
|
||||
"array",
|
||||
@ -610,6 +661,38 @@ void init_array(py::module_& m) {
|
||||
)pbdoc")
|
||||
.def("__getitem__", mlx_get_item)
|
||||
.def("__setitem__", mlx_set_item)
|
||||
.def_property_readonly(
|
||||
"at",
|
||||
[](const array& a) { return ArrayAt(a); },
|
||||
R"pbdoc(
|
||||
Used to apply updates at the given indices.
|
||||
|
||||
.. note::
|
||||
|
||||
Python in place updates for all array frameworks map to
|
||||
assignment. For instance ``x[idx] += y`` maps to ``x[idx] =
|
||||
x[idx] + y``. As a result, assigning to the same index ignores
|
||||
all but one updates. Using ``x.at[idx].add(y)`` will correctly
|
||||
apply all the updates to all indices.
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - array.at syntax
|
||||
- In-place syntax
|
||||
* - ``x = x.at[idx].add(y)``
|
||||
- ``x[idx] += y``
|
||||
* - ``x = x.at[idx].subtract(y)``
|
||||
- ``x[idx] -= y``
|
||||
* - ``x = x.at[idx].multiply(y)``
|
||||
- ``x[idx] *= y``
|
||||
* - ``x = x.at[idx].divide(y)``
|
||||
- ``x[idx] /= y``
|
||||
* - ``x = x.at[idx].maximum(y)``
|
||||
- ``x[idx] = mx.maximum(x[idx], y)``
|
||||
* - ``x = x.at[idx].minimum(y)``
|
||||
- ``x[idx] = mx.minimum(x[idx], y)``
|
||||
)pbdoc")
|
||||
.def(
|
||||
"__len__",
|
||||
[](const array& a) {
|
||||
|
@ -392,7 +392,7 @@ array mlx_get_item(const array& src, const py::object& obj) {
|
||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||
}
|
||||
|
||||
array mlx_set_item_int(
|
||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
|
||||
const array& src,
|
||||
const py::int_& idx,
|
||||
const array& update) {
|
||||
@ -410,14 +410,14 @@ array mlx_set_item_int(
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
auto shape = src.shape();
|
||||
shape[0] = 1;
|
||||
return scatter(
|
||||
src,
|
||||
get_int_index(idx, src.shape(0)),
|
||||
|
||||
return {
|
||||
{get_int_index(idx, src.shape(0))},
|
||||
broadcast_to(reshape(update, up_shape), shape),
|
||||
0);
|
||||
{0}};
|
||||
}
|
||||
|
||||
array mlx_set_item_array(
|
||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
|
||||
const array& src,
|
||||
const array& indices,
|
||||
const array& update) {
|
||||
@ -441,10 +441,10 @@ array mlx_set_item_array(
|
||||
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
|
||||
up = reshape(up, up_shape);
|
||||
|
||||
return scatter(src, indices, up, 0);
|
||||
return {{indices}, up, {0}};
|
||||
}
|
||||
|
||||
array mlx_set_item_slice(
|
||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
|
||||
const array& src,
|
||||
const py::slice& in_slice,
|
||||
const array& update) {
|
||||
@ -462,7 +462,7 @@ array mlx_set_item_slice(
|
||||
;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
return broadcast_to(reshape(update, up_shape), src.shape());
|
||||
return {{}, broadcast_to(reshape(update, up_shape), src.shape()), {}};
|
||||
}
|
||||
|
||||
int start = 0;
|
||||
@ -472,10 +472,11 @@ array mlx_set_item_slice(
|
||||
// Check and update slice params
|
||||
get_slice_params(start, end, stride, in_slice, end);
|
||||
|
||||
return mlx_set_item_array(src, arange(start, end, stride, uint32), update);
|
||||
return mlx_scatter_args_array(
|
||||
src, arange(start, end, stride, uint32), update);
|
||||
}
|
||||
|
||||
array mlx_set_item_nd(
|
||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
const array& src,
|
||||
const py::tuple& entries,
|
||||
const array& update) {
|
||||
@ -537,7 +538,7 @@ array mlx_set_item_nd(
|
||||
|
||||
// If no non-None indices return the broadcasted update
|
||||
if (non_none_indices == 0) {
|
||||
return broadcast_to(up, src.shape());
|
||||
return {{}, broadcast_to(up, src.shape()), {}};
|
||||
}
|
||||
|
||||
unsigned long max_dim = 0;
|
||||
@ -621,25 +622,108 @@ array mlx_set_item_nd(
|
||||
|
||||
std::vector<int> axes(arr_indices.size(), 0);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
return scatter(src, arr_indices, up, axes);
|
||||
|
||||
return {arr_indices, up, axes};
|
||||
}
|
||||
|
||||
std::tuple<std::vector<array>, array, std::vector<int>>
|
||||
mlx_compute_scatter_args(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto vals = to_array(v, src.dtype());
|
||||
if (py::isinstance<py::slice>(obj)) {
|
||||
return mlx_scatter_args_slice(src, obj, vals);
|
||||
} else if (py::isinstance<array>(obj)) {
|
||||
return mlx_scatter_args_array(src, py::cast<array>(obj), vals);
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
return mlx_scatter_args_int(src, obj, vals);
|
||||
} else if (py::isinstance<py::tuple>(obj)) {
|
||||
return mlx_scatter_args_nd(src, obj, vals);
|
||||
} else if (obj.is_none()) {
|
||||
return {{}, broadcast_to(vals, src.shape()), {}};
|
||||
}
|
||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||
}
|
||||
|
||||
// Implements `src[obj] = v` by overwriting `src`'s descriptor with the
// updated array.
//
// The indexing object and value are turned into scatter arguments by
// mlx_compute_scatter_args. When that yields no index arrays, the returned
// updates replace `src` directly; otherwise they are scattered into `src`.
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
  auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
  if (indices.empty()) {
    src.overwrite_descriptor(updates);
  } else {
    auto out = scatter(src, indices, updates, axes);
    src.overwrite_descriptor(out);
  }
}
|
||||
|
||||
// Returns `src` with `v` added at the locations selected by `obj`.
// With no index arrays this is a plain elementwise addition; otherwise the
// updates are applied with scatter_add.
array mlx_add_item(
    const array& src,
    const py::object& obj,
    const ScalarOrArray& v) {
  auto [idx_arrays, upd, scatter_axes] = mlx_compute_scatter_args(src, obj, v);
  if (idx_arrays.empty()) {
    return src + upd;
  }
  return scatter_add(src, idx_arrays, upd, scatter_axes);
}
|
||||
|
||||
// Returns `src` with `v` subtracted at the locations selected by `obj`.
// With no index arrays this is a plain elementwise subtraction; otherwise
// the negated updates are applied with scatter_add.
array mlx_subtract_item(
    const array& src,
    const py::object& obj,
    const ScalarOrArray& v) {
  auto [idx_arrays, upd, scatter_axes] = mlx_compute_scatter_args(src, obj, v);
  if (idx_arrays.empty()) {
    return src - upd;
  }
  return scatter_add(src, idx_arrays, -upd, scatter_axes);
}
|
||||
|
||||
// Returns `src` multiplied by `v` at the locations selected by `obj`.
// With no index arrays this is a plain elementwise product; otherwise the
// updates are applied with scatter_prod.
array mlx_multiply_item(
    const array& src,
    const py::object& obj,
    const ScalarOrArray& v) {
  auto [idx_arrays, upd, scatter_axes] = mlx_compute_scatter_args(src, obj, v);
  if (idx_arrays.empty()) {
    return src * upd;
  }
  return scatter_prod(src, idx_arrays, upd, scatter_axes);
}
|
||||
|
||||
// Returns `src` divided by `v` at the locations selected by `obj`.
// With no index arrays this is a plain elementwise division; otherwise the
// division is expressed as scatter_prod with the reciprocal of the updates.
array mlx_divide_item(
    const array& src,
    const py::object& obj,
    const ScalarOrArray& v) {
  auto [idx_arrays, upd, scatter_axes] = mlx_compute_scatter_args(src, obj, v);
  if (idx_arrays.empty()) {
    return src / upd;
  }
  return scatter_prod(src, idx_arrays, reciprocal(upd), scatter_axes);
}
|
||||
|
||||
// Returns `src` with the elementwise maximum of `src` and `v` taken at the
// locations selected by `obj`. With no index arrays this is `maximum` over
// the whole array; otherwise the updates are applied with scatter_max.
array mlx_maximum_item(
    const array& src,
    const py::object& obj,
    const ScalarOrArray& v) {
  auto [idx_arrays, upd, scatter_axes] = mlx_compute_scatter_args(src, obj, v);
  if (idx_arrays.empty()) {
    return maximum(src, upd);
  }
  return scatter_max(src, idx_arrays, upd, scatter_axes);
}
|
||||
|
||||
// Returns `src` with the elementwise minimum of `src` and `v` taken at the
// locations selected by `obj`. With no index arrays this is `minimum` over
// the whole array; otherwise the updates are applied with scatter_min.
array mlx_minimum_item(
    const array& src,
    const py::object& obj,
    const ScalarOrArray& v) {
  auto [idx_arrays, upd, scatter_axes] = mlx_compute_scatter_args(src, obj, v);
  if (idx_arrays.empty()) {
    return minimum(src, upd);
  }
  return scatter_min(src, idx_arrays, upd, scatter_axes);
}
|
||||
|
@ -12,3 +12,27 @@ using namespace mlx::core;
|
||||
|
||||
array mlx_get_item(const array& src, const py::object& obj);
|
||||
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v);
|
||||
array mlx_add_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
array mlx_subtract_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
array mlx_multiply_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
array mlx_divide_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
array mlx_maximum_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
array mlx_minimum_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const ScalarOrArray& v);
|
||||
|
@ -3310,7 +3310,6 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
result (array): The tensor dot product.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"inner",
|
||||
&inner,
|
||||
@ -3331,7 +3330,6 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
result (array): The inner product.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"outer",
|
||||
&outer,
|
||||
|
@ -984,6 +984,53 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
a[2:-2, 2:-2] = 4
|
||||
self.assertEqual(a[2, 2].item(), 4)
|
||||
|
||||
def test_array_at(self):
    """Exercise ``array.at[...]`` updates for every supported operation."""
    # A ``None`` index applies the update to the whole (scalar) array.
    a = mx.array(1)
    a = a.at[None].add(1)
    self.assertEqual(a.item(), 2)

    # Integer index.
    a = mx.array([0, 1, 2])
    a = a.at[1].add(2)
    self.assertEqual(a.tolist(), [0, 3, 2])

    # Repeated indices must all contribute to the result.
    a = a.at[mx.array([0, 0, 0, 0])].add(1)
    self.assertEqual(a.tolist(), [4, 3, 2])

    # Adding a whole row.
    a = mx.zeros((10, 10))
    a = a.at[0].add(mx.arange(10))
    self.assertEqual(a[0].tolist(), list(range(10)))

    # Multiple index arrays.
    a = mx.zeros((10, 10))
    index_x = mx.array([0, 2, 3, 7])
    index_y = mx.array([3, 3, 1, 2])
    u = mx.random.uniform(shape=(4,))
    a = a.at[index_x, index_y].add(u)
    self.assertEqual(a.sum().item(), u.sum().item())
    self.assertEqual(a[index_x, index_y].tolist(), u.tolist())

    # Test all array.at ops
    a = mx.random.uniform(shape=(10, 5, 2))
    idx_x = mx.array([0, 4])
    update = mx.ones((2, 5))
    a[idx_x, :, 0] = 0
    a = a.at[idx_x, :, 0].add(update)
    self.assertEqualArray(a[idx_x, :, 0], update)
    a = a.at[idx_x, :, 0].subtract(update)
    self.assertEqualArray(a[idx_x, :, 0], mx.zeros_like(update))
    a = a.at[idx_x, :, 0].add(2 * update)
    self.assertEqualArray(a[idx_x, :, 0], 2 * update)
    a = a.at[idx_x, :, 0].multiply(2 * update)
    self.assertEqualArray(a[idx_x, :, 0], 4 * update)
    a = a.at[idx_x, :, 0].divide(3 * update)
    self.assertEqualArray(a[idx_x, :, 0], (4 / 3) * update)

    # For maximum/minimum the expected result is computed from the values
    # held *before* the update, so the check pins the actual outcome and
    # not merely that re-applying the operation changes nothing.
    a[idx_x, :, 0] = 5
    update = mx.arange(10).reshape(2, 5)
    expected = mx.maximum(a[idx_x, :, 0], update)
    a = a.at[idx_x, :, 0].maximum(update)
    self.assertEqualArray(a[idx_x, :, 0], expected)

    a[idx_x, :, 0] = 5
    expected = mx.minimum(a[idx_x, :, 0], update)
    a = a.at[idx_x, :, 0].minimum(update)
    self.assertEqualArray(a[idx_x, :, 0], expected)
|
||||
|
||||
def test_slice_negative_step(self):
|
||||
a_np = np.arange(20)
|
||||
a_mx = mx.array(a_np)
|
||||
|
Loading…
Reference in New Issue
Block a user