mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
Add scatter_min VJP (#462)
This commit is contained in:
@@ -2130,10 +2130,11 @@ std::vector<array> Scatter::vjp(
|
||||
case Scatter::None:
|
||||
case Scatter::Sum:
|
||||
case Scatter::Max:
|
||||
case Scatter::Min:
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[scatter] VJP implemented only for scatter and scatter_add");
|
||||
"[scatter] VJP not implemented for scatter_prod");
|
||||
}
|
||||
|
||||
const array& values = primals[0];
|
||||
@@ -2145,6 +2146,8 @@ std::vector<array> Scatter::vjp(
|
||||
switch (reduce_type_) {
|
||||
case Scatter::Max:
|
||||
return scatter_max(values, indices, updates, axes_, stream());
|
||||
case Scatter::Min:
|
||||
return scatter_min(values, indices, updates, axes_, stream());
|
||||
default:
|
||||
return array({});
|
||||
}
|
||||
@@ -2169,7 +2172,8 @@ std::vector<array> Scatter::vjp(
|
||||
// The input array values are kept so they all get gradients
|
||||
vjps.push_back(cotangents[0]);
|
||||
break;
|
||||
case Scatter::Max: {
|
||||
case Scatter::Max:
|
||||
case Scatter::Min: {
|
||||
auto mask = where(result == values, array({1}), array({0}));
|
||||
vjps.push_back(multiply(cotangents[0], mask));
|
||||
break;
|
||||
@@ -2191,7 +2195,8 @@ std::vector<array> Scatter::vjp(
|
||||
gather(cotangents[0], indices, axes_, slice_sizes, stream()));
|
||||
break;
|
||||
}
|
||||
case Scatter::Max: {
|
||||
case Scatter::Max:
|
||||
case Scatter::Min: {
|
||||
auto slice_sizes = cotangents[0].shape();
|
||||
for (auto ax : axes_) {
|
||||
slice_sizes[ax] = 1;
|
||||
|
Reference in New Issue
Block a user