Add scatter_min VJP (#462)

This commit is contained in:
Tristan Bilot
2024-01-16 09:37:40 +01:00
committed by GitHub
parent 92a2fdd577
commit f44c132f4a
2 changed files with 31 additions and 3 deletions

View File

@@ -2130,10 +2130,11 @@ std::vector<array> Scatter::vjp(
case Scatter::None:
case Scatter::Sum:
case Scatter::Max:
case Scatter::Min:
break;
default:
throw std::runtime_error(
"[scatter] VJP implemented only for scatter and scatter_add");
"[scatter] VJP not implemented for scatter_prod");
}
const array& values = primals[0];
@@ -2145,6 +2146,8 @@ std::vector<array> Scatter::vjp(
switch (reduce_type_) {
case Scatter::Max:
return scatter_max(values, indices, updates, axes_, stream());
case Scatter::Min:
return scatter_min(values, indices, updates, axes_, stream());
default:
return array({});
}
@@ -2169,7 +2172,8 @@ std::vector<array> Scatter::vjp(
// The input array values are kept so they all get gradients
vjps.push_back(cotangents[0]);
break;
case Scatter::Max: {
case Scatter::Max:
case Scatter::Min: {
auto mask = where(result == values, array({1}), array({0}));
vjps.push_back(multiply(cotangents[0], mask));
break;
@@ -2191,7 +2195,8 @@ std::vector<array> Scatter::vjp(
gather(cotangents[0], indices, axes_, slice_sizes, stream()));
break;
}
case Scatter::Max: {
case Scatter::Max:
case Scatter::Min: {
auto slice_sizes = cotangents[0].shape();
for (auto ax : axes_) {
slice_sizes[ax] = 1;