Scatter vjp (#394)

* Add a first scatter vjp
* Implement the scatter_add vjp
* Add array.at to implement user friendly scatters
This commit is contained in:
Angelos Katharopoulos
2024-01-09 13:36:51 -08:00
committed by GitHub
parent e9ca65c939
commit 961435a243
7 changed files with 360 additions and 33 deletions

View File

@@ -2122,6 +2122,78 @@ bool Scatter::is_equivalent(const Primitive& other) const {
return reduce_type_ == s_other.reduce_type_ && axes_ == s_other.axes_;
}
std::vector<array> Scatter::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums) {
switch (reduce_type_) {
case Scatter::None:
case Scatter::Sum:
break;
default:
throw std::runtime_error(
"[scatter] VJP implemented only for scatter and scatter_add");
}
const array& values = primals[0];
const array& updates = primals.back();
const std::vector<array> indices(primals.begin() + 1, primals.end() - 1);
std::vector<array> vjps;
for (auto num : argnums) {
// Gradient wrt to the input array
if (num == 0) {
switch (reduce_type_) {
case Scatter::None:
// Scatter 0s to the locations that were updated with the updates
vjps.push_back(scatter(
cotangents[0],
indices,
zeros_like(updates, stream()),
axes_,
stream()));
break;
case Scatter::Sum:
// The input array values are kept so they all get gradients
vjps.push_back(cotangents[0]);
break;
default:
// Should never reach here
throw std::invalid_argument("");
}
} else if (num == primals.size() - 1) {
switch (reduce_type_) {
case Scatter::None:
case Scatter::Sum: {
// Gather the values from the cotangent
auto slice_sizes = cotangents[0].shape();
for (auto ax : axes_) {
slice_sizes[ax] = 1;
}
vjps.push_back(
gather(cotangents[0], indices, axes_, slice_sizes, stream()));
break;
}
default: {
// Should never reach here
throw std::invalid_argument("");
}
}
} else {
throw std::invalid_argument(
"[scatter] Cannot calculate VJP with respect to indices.");
}
}
return vjps;
}
std::vector<array> Scatter::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
throw std::runtime_error("[scatter] JVP not yet implemented");
}
std::vector<array> Sigmoid::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,

View File

@@ -1266,7 +1266,26 @@ class Scatter : public UnaryPrimitive {
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_PRINT(Scatter)
DEFINE_GRADS();
void print(std::ostream& os) override {
os << "Scatter";
switch (reduce_type_) {
case Sum:
os << " Sum";
break;
case Prod:
os << " Prod";
break;
case Min:
os << " Min";
break;
case Max:
os << " Max";
break;
case None:
break;
}
}
bool is_equivalent(const Primitive& other) const override;
private: