scatter_max vjp + bindings + tests (#431)

Co-authored-by: DjamelMesbah <djamel.mesbah@adservio.fr>
This commit is contained in:
Tristan Bilot
2024-01-14 23:12:15 +01:00
committed by GitHub
parent 4bc446be08
commit 6022d4129e
2 changed files with 53 additions and 0 deletions

View File

@@ -2129,6 +2129,7 @@ std::vector<array> Scatter::vjp(
switch (reduce_type_) {
case Scatter::None:
case Scatter::Sum:
case Scatter::Max:
break;
default:
throw std::runtime_error(
@@ -2139,6 +2140,17 @@ std::vector<array> Scatter::vjp(
const array& updates = primals.back();
const std::vector<array> indices(primals.begin() + 1, primals.end() - 1);
// Store result of scatter if needed for reuse in vjp
auto get_result = [&]() {
  // Only the Max reduction recomputes the forward scatter here; its output
  // is compared against the inputs when the gradients are assembled below.
  if (reduce_type_ == Scatter::Max) {
    return scatter_max(values, indices, updates, axes_, stream());
  }
  // Every other reduction type gets an array built from an empty
  // initializer list.
  return array({});
};
array result = get_result();
std::vector<array> vjps;
for (auto num : argnums) {
// Gradient wrt the input array
@@ -2157,6 +2169,11 @@ std::vector<array> Scatter::vjp(
// The input array values are kept so they all get gradients
vjps.push_back(cotangents[0]);
break;
case Scatter::Max: {
  // Positions where the input value equals the scatter_max output keep the
  // cotangent; every other position is multiplied by 0.
  // stream() is passed to each op (equal, where, multiply) so they match the
  // gather/scatter_max calls in this vjp, which all receive stream().
  auto mask = where(
      equal(result, values, stream()), array({1}), array({0}), stream());
  vjps.push_back(multiply(cotangents[0], mask, stream()));
  break;
}
default:
// Should never reach here
throw std::invalid_argument("");
@@ -2174,6 +2191,19 @@ std::vector<array> Scatter::vjp(
gather(cotangents[0], indices, axes_, slice_sizes, stream()));
break;
}
case Scatter::Max: {
  // Start from the cotangent's shape and request a slice of size 1 along
  // every axis listed in axes_.
  auto slice_sizes = cotangents[0].shape();
  for (auto ax : axes_) {
    slice_sizes[ax] = 1;
  }
  // Gather both the cotangent and the recomputed scatter_max output at the
  // scatter indices.
  auto gathered_cotan =
      gather(cotangents[0], indices, axes_, slice_sizes, stream());
  auto gathered_result =
      gather(result, indices, axes_, slice_sizes, stream());
  // Keep the gathered cotangent only where the gathered output equals the
  // update. The comparison uses equal(..., stream()) instead of operator==
  // so it receives stream() like the neighbouring ops.
  vjps.push_back(multiply(
      gathered_cotan, equal(gathered_result, updates, stream()), stream()));
  break;
}
default: {
// Should never reach here
throw std::invalid_argument("");