mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-27 03:11:16 +08:00
scatter_max vjp + bindings + tests (#431)
Co-authored-by: DjamelMesbah <djamel.mesbah@adservio.fr>
This commit is contained in:
parent
4bc446be08
commit
6022d4129e
@ -2129,6 +2129,7 @@ std::vector<array> Scatter::vjp(
|
|||||||
switch (reduce_type_) {
|
switch (reduce_type_) {
|
||||||
case Scatter::None:
|
case Scatter::None:
|
||||||
case Scatter::Sum:
|
case Scatter::Sum:
|
||||||
|
case Scatter::Max:
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
@ -2139,6 +2140,17 @@ std::vector<array> Scatter::vjp(
|
|||||||
const array& updates = primals.back();
|
const array& updates = primals.back();
|
||||||
const std::vector<array> indices(primals.begin() + 1, primals.end() - 1);
|
const std::vector<array> indices(primals.begin() + 1, primals.end() - 1);
|
||||||
|
|
||||||
|
// Store result of scatter if needed for reuse in vjp
|
||||||
|
auto get_result = [&]() {
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case Scatter::Max:
|
||||||
|
return scatter_max(values, indices, updates, axes_, stream());
|
||||||
|
default:
|
||||||
|
return array({});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
array result = get_result();
|
||||||
|
|
||||||
std::vector<array> vjps;
|
std::vector<array> vjps;
|
||||||
for (auto num : argnums) {
|
for (auto num : argnums) {
|
||||||
// Gradient wrt to the input array
|
// Gradient wrt to the input array
|
||||||
@ -2157,6 +2169,11 @@ std::vector<array> Scatter::vjp(
|
|||||||
// The input array values are kept so they all get gradients
|
// The input array values are kept so they all get gradients
|
||||||
vjps.push_back(cotangents[0]);
|
vjps.push_back(cotangents[0]);
|
||||||
break;
|
break;
|
||||||
|
case Scatter::Max: {
|
||||||
|
auto mask = where(result == values, array({1}), array({0}));
|
||||||
|
vjps.push_back(multiply(cotangents[0], mask));
|
||||||
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
// Should never reach here
|
// Should never reach here
|
||||||
throw std::invalid_argument("");
|
throw std::invalid_argument("");
|
||||||
@ -2174,6 +2191,19 @@ std::vector<array> Scatter::vjp(
|
|||||||
gather(cotangents[0], indices, axes_, slice_sizes, stream()));
|
gather(cotangents[0], indices, axes_, slice_sizes, stream()));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case Scatter::Max: {
|
||||||
|
auto slice_sizes = cotangents[0].shape();
|
||||||
|
for (auto ax : axes_) {
|
||||||
|
slice_sizes[ax] = 1;
|
||||||
|
}
|
||||||
|
auto gathered_cotan =
|
||||||
|
gather(cotangents[0], indices, axes_, slice_sizes, stream());
|
||||||
|
auto gathered_result =
|
||||||
|
gather(result, indices, axes_, slice_sizes, stream());
|
||||||
|
vjps.push_back(
|
||||||
|
multiply(gathered_cotan, gathered_result == updates, stream()));
|
||||||
|
break;
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
// Should never reach here
|
// Should never reach here
|
||||||
throw std::invalid_argument("");
|
throw std::invalid_argument("");
|
||||||
|
@ -293,6 +293,29 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
|||||||
self.assertTrue(mx.array_equal(dfdx, mx.array([1.0])))
|
self.assertTrue(mx.array_equal(dfdx, mx.array([1.0])))
|
||||||
self.assertEqual(dfdx.dtype, mx.float32)
|
self.assertEqual(dfdx.dtype, mx.float32)
|
||||||
|
|
||||||
|
def test_scatter_max_vjp(self):
|
||||||
|
def fun(src, updates):
|
||||||
|
x = src.at[1].maximum(updates)
|
||||||
|
return x
|
||||||
|
|
||||||
|
cotan = mx.array([4.0, 5.0, 6.0])
|
||||||
|
_, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0]), mx.array([[3.0]])], [cotan])
|
||||||
|
mx.eval(vjps)
|
||||||
|
|
||||||
|
# Update larger than value
|
||||||
|
self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 0.0, 6.0])))
|
||||||
|
self.assertTrue(mx.allclose(vjps[1], mx.array([5.0])))
|
||||||
|
|
||||||
|
cotan = mx.array([[4.0], [5.0], [6.0]])
|
||||||
|
_, vjps = mx.vjp(
|
||||||
|
fun, [mx.array([[1.0], [2.0], [3.0]]), mx.array([[[2.0]]])], [cotan]
|
||||||
|
)
|
||||||
|
mx.eval(vjps)
|
||||||
|
|
||||||
|
# Update and value are equal
|
||||||
|
self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]])))
|
||||||
|
self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]])))
|
||||||
|
|
||||||
def test_vjp_types(self):
|
def test_vjp_types(self):
|
||||||
def fun(x):
|
def fun(x):
|
||||||
return x
|
return x
|
||||||
|
Loading…
Reference in New Issue
Block a user