mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-12 04:06:39 +08:00
Add scatter_min VJP (#462)
This commit is contained in:
parent
92a2fdd577
commit
f44c132f4a
@ -2130,10 +2130,11 @@ std::vector<array> Scatter::vjp(
|
|||||||
case Scatter::None:
|
case Scatter::None:
|
||||||
case Scatter::Sum:
|
case Scatter::Sum:
|
||||||
case Scatter::Max:
|
case Scatter::Max:
|
||||||
|
case Scatter::Min:
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[scatter] VJP implemented only for scatter and scatter_add");
|
"[scatter] VJP not implemented for scatter_prod");
|
||||||
}
|
}
|
||||||
|
|
||||||
const array& values = primals[0];
|
const array& values = primals[0];
|
||||||
@ -2145,6 +2146,8 @@ std::vector<array> Scatter::vjp(
|
|||||||
switch (reduce_type_) {
|
switch (reduce_type_) {
|
||||||
case Scatter::Max:
|
case Scatter::Max:
|
||||||
return scatter_max(values, indices, updates, axes_, stream());
|
return scatter_max(values, indices, updates, axes_, stream());
|
||||||
|
case Scatter::Min:
|
||||||
|
return scatter_min(values, indices, updates, axes_, stream());
|
||||||
default:
|
default:
|
||||||
return array({});
|
return array({});
|
||||||
}
|
}
|
||||||
@ -2169,7 +2172,8 @@ std::vector<array> Scatter::vjp(
|
|||||||
// The input array values are kept so they all get gradients
|
// The input array values are kept so they all get gradients
|
||||||
vjps.push_back(cotangents[0]);
|
vjps.push_back(cotangents[0]);
|
||||||
break;
|
break;
|
||||||
case Scatter::Max: {
|
case Scatter::Max:
|
||||||
|
case Scatter::Min: {
|
||||||
auto mask = where(result == values, array({1}), array({0}));
|
auto mask = where(result == values, array({1}), array({0}));
|
||||||
vjps.push_back(multiply(cotangents[0], mask));
|
vjps.push_back(multiply(cotangents[0], mask));
|
||||||
break;
|
break;
|
||||||
@ -2191,7 +2195,8 @@ std::vector<array> Scatter::vjp(
|
|||||||
gather(cotangents[0], indices, axes_, slice_sizes, stream()));
|
gather(cotangents[0], indices, axes_, slice_sizes, stream()));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Scatter::Max: {
|
case Scatter::Max:
|
||||||
|
case Scatter::Min: {
|
||||||
auto slice_sizes = cotangents[0].shape();
|
auto slice_sizes = cotangents[0].shape();
|
||||||
for (auto ax : axes_) {
|
for (auto ax : axes_) {
|
||||||
slice_sizes[ax] = 1;
|
slice_sizes[ax] = 1;
|
||||||
|
@ -316,6 +316,29 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]])))
|
self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]])))
|
||||||
self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]])))
|
self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]])))
|
||||||
|
|
||||||
|
def test_scatter_min_vjp(self):
|
||||||
|
def fun(src, updates):
|
||||||
|
x = src.at[1].minimum(updates)
|
||||||
|
return x
|
||||||
|
|
||||||
|
cotan = mx.array([4.0, 5.0, 6.0])
|
||||||
|
_, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0]), mx.array([[3.0]])], [cotan])
|
||||||
|
mx.eval(vjps)
|
||||||
|
|
||||||
|
# Update larger than value
|
||||||
|
self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0])))
|
||||||
|
self.assertTrue(mx.allclose(vjps[1], mx.array([0.0])))
|
||||||
|
|
||||||
|
cotan = mx.array([[4.0], [5.0], [6.0]])
|
||||||
|
_, vjps = mx.vjp(
|
||||||
|
fun, [mx.array([[1.0], [2.0], [3.0]]), mx.array([[[2.0]]])], [cotan]
|
||||||
|
)
|
||||||
|
mx.eval(vjps)
|
||||||
|
|
||||||
|
# Update and value are equal
|
||||||
|
self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]])))
|
||||||
|
self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]])))
|
||||||
|
|
||||||
def test_vjp_types(self):
|
def test_vjp_types(self):
|
||||||
def fun(x):
|
def fun(x):
|
||||||
return x
|
return x
|
||||||
|
Loading…
Reference in New Issue
Block a user