mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Scatter vjp (#394)
* Add a first scatter vjp * Implement the scatter_add vjp * Add array.at to implement user friendly scatters
This commit is contained in:

committed by
GitHub

parent
e9ca65c939
commit
961435a243
@@ -2122,6 +2122,78 @@ bool Scatter::is_equivalent(const Primitive& other) const {
|
||||
return reduce_type_ == s_other.reduce_type_ && axes_ == s_other.axes_;
|
||||
}
|
||||
|
||||
std::vector<array> Scatter::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums) {
|
||||
switch (reduce_type_) {
|
||||
case Scatter::None:
|
||||
case Scatter::Sum:
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[scatter] VJP implemented only for scatter and scatter_add");
|
||||
}
|
||||
|
||||
const array& values = primals[0];
|
||||
const array& updates = primals.back();
|
||||
const std::vector<array> indices(primals.begin() + 1, primals.end() - 1);
|
||||
|
||||
std::vector<array> vjps;
|
||||
for (auto num : argnums) {
|
||||
// Gradient wrt to the input array
|
||||
if (num == 0) {
|
||||
switch (reduce_type_) {
|
||||
case Scatter::None:
|
||||
// Scatter 0s to the locations that were updated with the updates
|
||||
vjps.push_back(scatter(
|
||||
cotangents[0],
|
||||
indices,
|
||||
zeros_like(updates, stream()),
|
||||
axes_,
|
||||
stream()));
|
||||
break;
|
||||
case Scatter::Sum:
|
||||
// The input array values are kept so they all get gradients
|
||||
vjps.push_back(cotangents[0]);
|
||||
break;
|
||||
default:
|
||||
// Should never reach here
|
||||
throw std::invalid_argument("");
|
||||
}
|
||||
} else if (num == primals.size() - 1) {
|
||||
switch (reduce_type_) {
|
||||
case Scatter::None:
|
||||
case Scatter::Sum: {
|
||||
// Gather the values from the cotangent
|
||||
auto slice_sizes = cotangents[0].shape();
|
||||
for (auto ax : axes_) {
|
||||
slice_sizes[ax] = 1;
|
||||
}
|
||||
vjps.push_back(
|
||||
gather(cotangents[0], indices, axes_, slice_sizes, stream()));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
// Should never reach here
|
||||
throw std::invalid_argument("");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[scatter] Cannot calculate VJP with respect to indices.");
|
||||
}
|
||||
}
|
||||
return vjps;
|
||||
}
|
||||
|
||||
// Jacobian-vector product of a scatter.
//
// Not implemented: every call throws std::runtime_error, whatever the
// arguments are.
std::vector<array> Scatter::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  throw std::runtime_error("[scatter] JVP not yet implemented");
}
|
||||
|
||||
std::vector<array> Sigmoid::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
|
@@ -1266,7 +1266,26 @@ class Scatter : public UnaryPrimitive {
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_PRINT(Scatter)
|
||||
DEFINE_GRADS();
|
||||
void print(std::ostream& os) override {
|
||||
os << "Scatter";
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
os << " Sum";
|
||||
break;
|
||||
case Prod:
|
||||
os << " Prod";
|
||||
break;
|
||||
case Min:
|
||||
os << " Min";
|
||||
break;
|
||||
case Max:
|
||||
os << " Max";
|
||||
break;
|
||||
case None:
|
||||
break;
|
||||
}
|
||||
}
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
|
Reference in New Issue
Block a user