diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 0077b11fa..4659c1fbd 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -2129,6 +2129,7 @@ std::vector Scatter::vjp( switch (reduce_type_) { case Scatter::None: case Scatter::Sum: + case Scatter::Max: break; default: throw std::runtime_error( @@ -2139,6 +2140,17 @@ std::vector Scatter::vjp( const array& updates = primals.back(); const std::vector indices(primals.begin() + 1, primals.end() - 1); + // Store result of scatter if needed for reuse in vjp + auto get_result = [&]() { + switch (reduce_type_) { + case Scatter::Max: + return scatter_max(values, indices, updates, axes_, stream()); + default: + return array({}); + } + }; + array result = get_result(); + std::vector vjps; for (auto num : argnums) { // Gradient wrt to the input array @@ -2157,6 +2169,11 @@ std::vector Scatter::vjp( // The input array values are kept so they all get gradients vjps.push_back(cotangents[0]); break; + case Scatter::Max: { + auto mask = where(result == values, array({1}), array({0})); + vjps.push_back(multiply(cotangents[0], mask)); + break; + } default: // Should never reach here throw std::invalid_argument(""); @@ -2174,6 +2191,19 @@ std::vector Scatter::vjp( gather(cotangents[0], indices, axes_, slice_sizes, stream())); break; } + case Scatter::Max: { + auto slice_sizes = cotangents[0].shape(); + for (auto ax : axes_) { + slice_sizes[ax] = 1; + } + auto gathered_cotan = + gather(cotangents[0], indices, axes_, slice_sizes, stream()); + auto gathered_result = + gather(result, indices, axes_, slice_sizes, stream()); + vjps.push_back( + multiply(gathered_cotan, gathered_result == updates, stream())); + break; + } default: { // Should never reach here throw std::invalid_argument(""); diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 946210e6e..901c89dfd 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -293,6 +293,29 @@ class TestAutograd(mlx_tests.MLXTestCase): self.assertTrue(mx.array_equal(dfdx, mx.array([1.0]))) self.assertEqual(dfdx.dtype, mx.float32) + def test_scatter_max_vjp(self): + def fun(src, updates): + x = src.at[1].maximum(updates) + return x + + cotan = mx.array([4.0, 5.0, 6.0]) + _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0]), mx.array([[3.0]])], [cotan]) + mx.eval(vjps) + + # Update larger than value + self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 0.0, 6.0]))) + self.assertTrue(mx.allclose(vjps[1], mx.array([5.0]))) + + cotan = mx.array([[4.0], [5.0], [6.0]]) + _, vjps = mx.vjp( + fun, [mx.array([[1.0], [2.0], [3.0]]), mx.array([[[2.0]]])], [cotan] + ) + mx.eval(vjps) + + # Update and value are equal + self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]]))) + self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]]))) + def test_vjp_types(self): def fun(x): return x