mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-25 03:31:17 +08:00
Scatter vjp (#394)
* Add a first scatter vjp * Implement the scatter_add vjp * Add array.at to implement user friendly scatters
This commit is contained in:
parent
e9ca65c939
commit
961435a243
@ -2122,6 +2122,78 @@ bool Scatter::is_equivalent(const Primitive& other) const {
|
|||||||
return reduce_type_ == s_other.reduce_type_ && axes_ == s_other.axes_;
|
return reduce_type_ == s_other.reduce_type_ && axes_ == s_other.axes_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<array> Scatter::vjp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& cotangents,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case Scatter::None:
|
||||||
|
case Scatter::Sum:
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[scatter] VJP implemented only for scatter and scatter_add");
|
||||||
|
}
|
||||||
|
|
||||||
|
const array& values = primals[0];
|
||||||
|
const array& updates = primals.back();
|
||||||
|
const std::vector<array> indices(primals.begin() + 1, primals.end() - 1);
|
||||||
|
|
||||||
|
std::vector<array> vjps;
|
||||||
|
for (auto num : argnums) {
|
||||||
|
// Gradient wrt to the input array
|
||||||
|
if (num == 0) {
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case Scatter::None:
|
||||||
|
// Scatter 0s to the locations that were updated with the updates
|
||||||
|
vjps.push_back(scatter(
|
||||||
|
cotangents[0],
|
||||||
|
indices,
|
||||||
|
zeros_like(updates, stream()),
|
||||||
|
axes_,
|
||||||
|
stream()));
|
||||||
|
break;
|
||||||
|
case Scatter::Sum:
|
||||||
|
// The input array values are kept so they all get gradients
|
||||||
|
vjps.push_back(cotangents[0]);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
// Should never reach here
|
||||||
|
throw std::invalid_argument("");
|
||||||
|
}
|
||||||
|
} else if (num == primals.size() - 1) {
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case Scatter::None:
|
||||||
|
case Scatter::Sum: {
|
||||||
|
// Gather the values from the cotangent
|
||||||
|
auto slice_sizes = cotangents[0].shape();
|
||||||
|
for (auto ax : axes_) {
|
||||||
|
slice_sizes[ax] = 1;
|
||||||
|
}
|
||||||
|
vjps.push_back(
|
||||||
|
gather(cotangents[0], indices, axes_, slice_sizes, stream()));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
// Should never reach here
|
||||||
|
throw std::invalid_argument("");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[scatter] Cannot calculate VJP with respect to indices.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return vjps;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> Scatter::jvp(
|
||||||
|
const std::vector<array>& primals,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>& argnums) {
|
||||||
|
throw std::runtime_error("[scatter] JVP not yet implemented");
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<array> Sigmoid::vjp(
|
std::vector<array> Sigmoid::vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& cotangents,
|
const std::vector<array>& cotangents,
|
||||||
|
@ -1266,7 +1266,26 @@ class Scatter : public UnaryPrimitive {
|
|||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
DEFINE_PRINT(Scatter)
|
DEFINE_GRADS();
|
||||||
|
void print(std::ostream& os) override {
|
||||||
|
os << "Scatter";
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case Sum:
|
||||||
|
os << " Sum";
|
||||||
|
break;
|
||||||
|
case Prod:
|
||||||
|
os << " Prod";
|
||||||
|
break;
|
||||||
|
case Min:
|
||||||
|
os << " Min";
|
||||||
|
break;
|
||||||
|
case Max:
|
||||||
|
os << " Max";
|
||||||
|
break;
|
||||||
|
case None:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -462,6 +462,37 @@ array create_array(array_init_type v, std::optional<Dtype> t) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class ArrayAt {
|
||||||
|
public:
|
||||||
|
ArrayAt(array x) : x_(std::move(x)) {}
|
||||||
|
ArrayAt& set_indices(py::object indices) {
|
||||||
|
indices_ = indices;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
array add(const ScalarOrArray& v) {
|
||||||
|
return mlx_add_item(x_, indices_, v);
|
||||||
|
}
|
||||||
|
array subtract(const ScalarOrArray& v) {
|
||||||
|
return mlx_subtract_item(x_, indices_, v);
|
||||||
|
}
|
||||||
|
array multiply(const ScalarOrArray& v) {
|
||||||
|
return mlx_multiply_item(x_, indices_, v);
|
||||||
|
}
|
||||||
|
array divide(const ScalarOrArray& v) {
|
||||||
|
return mlx_divide_item(x_, indices_, v);
|
||||||
|
}
|
||||||
|
array maximum(const ScalarOrArray& v) {
|
||||||
|
return mlx_maximum_item(x_, indices_, v);
|
||||||
|
}
|
||||||
|
array minimum(const ScalarOrArray& v) {
|
||||||
|
return mlx_minimum_item(x_, indices_, v);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
array x_;
|
||||||
|
py::object indices_;
|
||||||
|
};
|
||||||
|
|
||||||
void init_array(py::module_& m) {
|
void init_array(py::module_& m) {
|
||||||
// Types
|
// Types
|
||||||
py::class_<Dtype>(
|
py::class_<Dtype>(
|
||||||
@ -501,6 +532,26 @@ void init_array(py::module_& m) {
|
|||||||
m.attr("bfloat16") = py::cast(bfloat16);
|
m.attr("bfloat16") = py::cast(bfloat16);
|
||||||
m.attr("complex64") = py::cast(complex64);
|
m.attr("complex64") = py::cast(complex64);
|
||||||
|
|
||||||
|
py::class_<ArrayAt>(
|
||||||
|
m,
|
||||||
|
"_ArrayAt",
|
||||||
|
R"pbdoc(
|
||||||
|
A helper object to apply updates at specific indices.
|
||||||
|
)pbdoc")
|
||||||
|
.def(
|
||||||
|
py::init([](const array& x) { return ArrayAt(x); }),
|
||||||
|
"x"_a,
|
||||||
|
R"pbdoc(
|
||||||
|
__init__(self, x: array)
|
||||||
|
)pbdoc")
|
||||||
|
.def("__getitem__", &ArrayAt::set_indices, "indices"_a)
|
||||||
|
.def("add", &ArrayAt::add, "value"_a)
|
||||||
|
.def("subtract", &ArrayAt::subtract, "value"_a)
|
||||||
|
.def("multiply", &ArrayAt::multiply, "value"_a)
|
||||||
|
.def("divide", &ArrayAt::divide, "value"_a)
|
||||||
|
.def("maximum", &ArrayAt::maximum, "value"_a)
|
||||||
|
.def("minimum", &ArrayAt::minimum, "value"_a);
|
||||||
|
|
||||||
auto array_class = py::class_<array>(
|
auto array_class = py::class_<array>(
|
||||||
m,
|
m,
|
||||||
"array",
|
"array",
|
||||||
@ -610,6 +661,38 @@ void init_array(py::module_& m) {
|
|||||||
)pbdoc")
|
)pbdoc")
|
||||||
.def("__getitem__", mlx_get_item)
|
.def("__getitem__", mlx_get_item)
|
||||||
.def("__setitem__", mlx_set_item)
|
.def("__setitem__", mlx_set_item)
|
||||||
|
.def_property_readonly(
|
||||||
|
"at",
|
||||||
|
[](const array& a) { return ArrayAt(a); },
|
||||||
|
R"pbdoc(
|
||||||
|
Used to apply updates at the given indices.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Python in place updates for all array frameworks map to
|
||||||
|
assignment. For instance ``x[idx] += y`` maps to ``x[idx] =
|
||||||
|
x[idx] + y``. As a result, assigning to the same index ignores
|
||||||
|
all but one updates. Using ``x.at[idx].add(y)`` will correctly
|
||||||
|
apply all the updates to all indices.
|
||||||
|
|
||||||
|
.. list-table::
|
||||||
|
:header-rows: 1
|
||||||
|
|
||||||
|
* - array.at syntax
|
||||||
|
- In-place syntax
|
||||||
|
* - ``x = x.at[idx].add(y)``
|
||||||
|
- ``x[idx] += y``
|
||||||
|
* - ``x = x.at[idx].subtract(y)``
|
||||||
|
- ``x[idx] -= y``
|
||||||
|
* - ``x = x.at[idx].multiply(y)``
|
||||||
|
- ``x[idx] *= y``
|
||||||
|
* - ``x = x.at[idx].divide(y)``
|
||||||
|
- ``x[idx] /= y``
|
||||||
|
* - ``x = x.at[idx].maximum(y)``
|
||||||
|
- ``x[idx] = mx.maximum(x[idx], y)``
|
||||||
|
* - ``x = x.at[idx].minimum(y)``
|
||||||
|
- ``x[idx] = mx.minimum(x[idx], y)``
|
||||||
|
)pbdoc")
|
||||||
.def(
|
.def(
|
||||||
"__len__",
|
"__len__",
|
||||||
[](const array& a) {
|
[](const array& a) {
|
||||||
|
@ -392,7 +392,7 @@ array mlx_get_item(const array& src, const py::object& obj) {
|
|||||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_set_item_int(
|
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
|
||||||
const array& src,
|
const array& src,
|
||||||
const py::int_& idx,
|
const py::int_& idx,
|
||||||
const array& update) {
|
const array& update) {
|
||||||
@ -410,14 +410,14 @@ array mlx_set_item_int(
|
|||||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||||
auto shape = src.shape();
|
auto shape = src.shape();
|
||||||
shape[0] = 1;
|
shape[0] = 1;
|
||||||
return scatter(
|
|
||||||
src,
|
return {
|
||||||
get_int_index(idx, src.shape(0)),
|
{get_int_index(idx, src.shape(0))},
|
||||||
broadcast_to(reshape(update, up_shape), shape),
|
broadcast_to(reshape(update, up_shape), shape),
|
||||||
0);
|
{0}};
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_set_item_array(
|
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
|
||||||
const array& src,
|
const array& src,
|
||||||
const array& indices,
|
const array& indices,
|
||||||
const array& update) {
|
const array& update) {
|
||||||
@ -441,10 +441,10 @@ array mlx_set_item_array(
|
|||||||
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
|
up_shape.insert(up_shape.begin() + indices.ndim(), 1);
|
||||||
up = reshape(up, up_shape);
|
up = reshape(up, up_shape);
|
||||||
|
|
||||||
return scatter(src, indices, up, 0);
|
return {{indices}, up, {0}};
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_set_item_slice(
|
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
|
||||||
const array& src,
|
const array& src,
|
||||||
const py::slice& in_slice,
|
const py::slice& in_slice,
|
||||||
const array& update) {
|
const array& update) {
|
||||||
@ -462,7 +462,7 @@ array mlx_set_item_slice(
|
|||||||
;
|
;
|
||||||
auto up_shape =
|
auto up_shape =
|
||||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||||
return broadcast_to(reshape(update, up_shape), src.shape());
|
return {{}, broadcast_to(reshape(update, up_shape), src.shape()), {}};
|
||||||
}
|
}
|
||||||
|
|
||||||
int start = 0;
|
int start = 0;
|
||||||
@ -472,10 +472,11 @@ array mlx_set_item_slice(
|
|||||||
// Check and update slice params
|
// Check and update slice params
|
||||||
get_slice_params(start, end, stride, in_slice, end);
|
get_slice_params(start, end, stride, in_slice, end);
|
||||||
|
|
||||||
return mlx_set_item_array(src, arange(start, end, stride, uint32), update);
|
return mlx_scatter_args_array(
|
||||||
|
src, arange(start, end, stride, uint32), update);
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_set_item_nd(
|
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||||
const array& src,
|
const array& src,
|
||||||
const py::tuple& entries,
|
const py::tuple& entries,
|
||||||
const array& update) {
|
const array& update) {
|
||||||
@ -537,7 +538,7 @@ array mlx_set_item_nd(
|
|||||||
|
|
||||||
// If no non-None indices return the broadcasted update
|
// If no non-None indices return the broadcasted update
|
||||||
if (non_none_indices == 0) {
|
if (non_none_indices == 0) {
|
||||||
return broadcast_to(up, src.shape());
|
return {{}, broadcast_to(up, src.shape()), {}};
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned long max_dim = 0;
|
unsigned long max_dim = 0;
|
||||||
@ -621,25 +622,108 @@ array mlx_set_item_nd(
|
|||||||
|
|
||||||
std::vector<int> axes(arr_indices.size(), 0);
|
std::vector<int> axes(arr_indices.size(), 0);
|
||||||
std::iota(axes.begin(), axes.end(), 0);
|
std::iota(axes.begin(), axes.end(), 0);
|
||||||
return scatter(src, arr_indices, up, axes);
|
|
||||||
|
return {arr_indices, up, axes};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<std::vector<array>, array, std::vector<int>>
|
||||||
|
mlx_compute_scatter_args(
|
||||||
|
const array& src,
|
||||||
|
const py::object& obj,
|
||||||
|
const ScalarOrArray& v) {
|
||||||
|
auto vals = to_array(v, src.dtype());
|
||||||
|
if (py::isinstance<py::slice>(obj)) {
|
||||||
|
return mlx_scatter_args_slice(src, obj, vals);
|
||||||
|
} else if (py::isinstance<array>(obj)) {
|
||||||
|
return mlx_scatter_args_array(src, py::cast<array>(obj), vals);
|
||||||
|
} else if (py::isinstance<py::int_>(obj)) {
|
||||||
|
return mlx_scatter_args_int(src, obj, vals);
|
||||||
|
} else if (py::isinstance<py::tuple>(obj)) {
|
||||||
|
return mlx_scatter_args_nd(src, obj, vals);
|
||||||
|
} else if (obj.is_none()) {
|
||||||
|
return {{}, broadcast_to(vals, src.shape()), {}};
|
||||||
|
}
|
||||||
|
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
|
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
|
||||||
auto vals = to_array(v, src.dtype());
|
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||||
auto impl = [&src, &obj, &vals]() {
|
if (indices.size() > 0) {
|
||||||
if (py::isinstance<py::slice>(obj)) {
|
auto out = scatter(src, indices, updates, axes);
|
||||||
return mlx_set_item_slice(src, obj, vals);
|
src.overwrite_descriptor(out);
|
||||||
} else if (py::isinstance<array>(obj)) {
|
} else {
|
||||||
return mlx_set_item_array(src, py::cast<array>(obj), vals);
|
src.overwrite_descriptor(updates);
|
||||||
} else if (py::isinstance<py::int_>(obj)) {
|
}
|
||||||
return mlx_set_item_int(src, obj, vals);
|
}
|
||||||
} else if (py::isinstance<py::tuple>(obj)) {
|
|
||||||
return mlx_set_item_nd(src, obj, vals);
|
array mlx_add_item(
|
||||||
} else if (obj.is_none()) {
|
const array& src,
|
||||||
return broadcast_to(vals, src.shape());
|
const py::object& obj,
|
||||||
}
|
const ScalarOrArray& v) {
|
||||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||||
};
|
if (indices.size() > 0) {
|
||||||
auto out = impl();
|
return scatter_add(src, indices, updates, axes);
|
||||||
src.overwrite_descriptor(out);
|
} else {
|
||||||
|
return src + updates;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array mlx_subtract_item(
|
||||||
|
const array& src,
|
||||||
|
const py::object& obj,
|
||||||
|
const ScalarOrArray& v) {
|
||||||
|
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||||
|
if (indices.size() > 0) {
|
||||||
|
return scatter_add(src, indices, -updates, axes);
|
||||||
|
} else {
|
||||||
|
return src - updates;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array mlx_multiply_item(
|
||||||
|
const array& src,
|
||||||
|
const py::object& obj,
|
||||||
|
const ScalarOrArray& v) {
|
||||||
|
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||||
|
if (indices.size() > 0) {
|
||||||
|
return scatter_prod(src, indices, updates, axes);
|
||||||
|
} else {
|
||||||
|
return src * updates;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array mlx_divide_item(
|
||||||
|
const array& src,
|
||||||
|
const py::object& obj,
|
||||||
|
const ScalarOrArray& v) {
|
||||||
|
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||||
|
if (indices.size() > 0) {
|
||||||
|
return scatter_prod(src, indices, reciprocal(updates), axes);
|
||||||
|
} else {
|
||||||
|
return src / updates;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array mlx_maximum_item(
|
||||||
|
const array& src,
|
||||||
|
const py::object& obj,
|
||||||
|
const ScalarOrArray& v) {
|
||||||
|
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||||
|
if (indices.size() > 0) {
|
||||||
|
return scatter_max(src, indices, updates, axes);
|
||||||
|
} else {
|
||||||
|
return maximum(src, updates);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array mlx_minimum_item(
|
||||||
|
const array& src,
|
||||||
|
const py::object& obj,
|
||||||
|
const ScalarOrArray& v) {
|
||||||
|
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||||
|
if (indices.size() > 0) {
|
||||||
|
return scatter_min(src, indices, updates, axes);
|
||||||
|
} else {
|
||||||
|
return minimum(src, updates);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -12,3 +12,27 @@ using namespace mlx::core;
|
|||||||
|
|
||||||
array mlx_get_item(const array& src, const py::object& obj);
|
array mlx_get_item(const array& src, const py::object& obj);
|
||||||
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v);
|
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v);
|
||||||
|
array mlx_add_item(
|
||||||
|
const array& src,
|
||||||
|
const py::object& obj,
|
||||||
|
const ScalarOrArray& v);
|
||||||
|
array mlx_subtract_item(
|
||||||
|
const array& src,
|
||||||
|
const py::object& obj,
|
||||||
|
const ScalarOrArray& v);
|
||||||
|
array mlx_multiply_item(
|
||||||
|
const array& src,
|
||||||
|
const py::object& obj,
|
||||||
|
const ScalarOrArray& v);
|
||||||
|
array mlx_divide_item(
|
||||||
|
const array& src,
|
||||||
|
const py::object& obj,
|
||||||
|
const ScalarOrArray& v);
|
||||||
|
array mlx_maximum_item(
|
||||||
|
const array& src,
|
||||||
|
const py::object& obj,
|
||||||
|
const ScalarOrArray& v);
|
||||||
|
array mlx_minimum_item(
|
||||||
|
const array& src,
|
||||||
|
const py::object& obj,
|
||||||
|
const ScalarOrArray& v);
|
||||||
|
@ -3310,7 +3310,6 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
result (array): The tensor dot product.
|
result (array): The tensor dot product.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"inner",
|
"inner",
|
||||||
&inner,
|
&inner,
|
||||||
@ -3331,7 +3330,6 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
result (array): The inner product.
|
result (array): The inner product.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"outer",
|
"outer",
|
||||||
&outer,
|
&outer,
|
||||||
|
@ -984,6 +984,53 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
a[2:-2, 2:-2] = 4
|
a[2:-2, 2:-2] = 4
|
||||||
self.assertEqual(a[2, 2].item(), 4)
|
self.assertEqual(a[2, 2].item(), 4)
|
||||||
|
|
||||||
|
def test_array_at(self):
|
||||||
|
a = mx.array(1)
|
||||||
|
a = a.at[None].add(1)
|
||||||
|
self.assertEqual(a.item(), 2)
|
||||||
|
|
||||||
|
a = mx.array([0, 1, 2])
|
||||||
|
a = a.at[1].add(2)
|
||||||
|
self.assertEqual(a.tolist(), [0, 3, 2])
|
||||||
|
|
||||||
|
a = a.at[mx.array([0, 0, 0, 0])].add(1)
|
||||||
|
self.assertEqual(a.tolist(), [4, 3, 2])
|
||||||
|
|
||||||
|
a = mx.zeros((10, 10))
|
||||||
|
a = a.at[0].add(mx.arange(10))
|
||||||
|
self.assertEqual(a[0].tolist(), list(range(10)))
|
||||||
|
|
||||||
|
a = mx.zeros((10, 10))
|
||||||
|
index_x = mx.array([0, 2, 3, 7])
|
||||||
|
index_y = mx.array([3, 3, 1, 2])
|
||||||
|
u = mx.random.uniform(shape=(4,))
|
||||||
|
a = a.at[index_x, index_y].add(u)
|
||||||
|
self.assertEqual(a.sum().item(), u.sum().item())
|
||||||
|
self.assertEqual(a[index_x, index_y].tolist(), u.tolist())
|
||||||
|
|
||||||
|
# Test all array.at ops
|
||||||
|
a = mx.random.uniform(shape=(10, 5, 2))
|
||||||
|
idx_x = mx.array([0, 4])
|
||||||
|
update = mx.ones((2, 5))
|
||||||
|
a[idx_x, :, 0] = 0
|
||||||
|
a = a.at[idx_x, :, 0].add(update)
|
||||||
|
self.assertEqualArray(a[idx_x, :, 0], update)
|
||||||
|
a = a.at[idx_x, :, 0].subtract(update)
|
||||||
|
self.assertEqualArray(a[idx_x, :, 0], mx.zeros_like(update))
|
||||||
|
a = a.at[idx_x, :, 0].add(2 * update)
|
||||||
|
self.assertEqualArray(a[idx_x, :, 0], 2 * update)
|
||||||
|
a = a.at[idx_x, :, 0].multiply(2 * update)
|
||||||
|
self.assertEqualArray(a[idx_x, :, 0], 4 * update)
|
||||||
|
a = a.at[idx_x, :, 0].divide(3 * update)
|
||||||
|
self.assertEqualArray(a[idx_x, :, 0], (4 / 3) * update)
|
||||||
|
a[idx_x, :, 0] = 5
|
||||||
|
update = mx.arange(10).reshape(2, 5)
|
||||||
|
a = a.at[idx_x, :, 0].maximum(update)
|
||||||
|
self.assertEqualArray(a[idx_x, :, 0], mx.maximum(a[idx_x, :, 0], update))
|
||||||
|
a[idx_x, :, 0] = 5
|
||||||
|
a = a.at[idx_x, :, 0].minimum(update)
|
||||||
|
self.assertEqualArray(a[idx_x, :, 0], mx.minimum(a[idx_x, :, 0], update))
|
||||||
|
|
||||||
def test_slice_negative_step(self):
|
def test_slice_negative_step(self):
|
||||||
a_np = np.arange(20)
|
a_np = np.arange(20)
|
||||||
a_mx = mx.array(a_np)
|
a_mx = mx.array(a_np)
|
||||||
|
Loading…
Reference in New Issue
Block a user