From f44c132f4a08826967321257afe1c579685493e9 Mon Sep 17 00:00:00 2001 From: Tristan Bilot <40337775+TristanBilot@users.noreply.github.com> Date: Tue, 16 Jan 2024 09:37:40 +0100 Subject: [PATCH] Add scatter_min VJP (#462) --- mlx/primitives.cpp | 11 ++++++++--- python/tests/test_autograd.py | 23 +++++++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 4659c1fbda..8eccd1f603 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -2130,10 +2130,11 @@ std::vector Scatter::vjp( case Scatter::None: case Scatter::Sum: case Scatter::Max: + case Scatter::Min: break; default: throw std::runtime_error( - "[scatter] VJP implemented only for scatter and scatter_add"); + "[scatter] VJP not implemented for scatter_prod"); } const array& values = primals[0]; @@ -2145,6 +2146,8 @@ std::vector Scatter::vjp( switch (reduce_type_) { case Scatter::Max: return scatter_max(values, indices, updates, axes_, stream()); + case Scatter::Min: + return scatter_min(values, indices, updates, axes_, stream()); default: return array({}); } @@ -2169,7 +2172,8 @@ std::vector Scatter::vjp( // The input array values are kept so they all get gradients vjps.push_back(cotangents[0]); break; - case Scatter::Max: { + case Scatter::Max: + case Scatter::Min: { auto mask = where(result == values, array({1}), array({0})); vjps.push_back(multiply(cotangents[0], mask)); break; @@ -2191,7 +2195,8 @@ std::vector Scatter::vjp( gather(cotangents[0], indices, axes_, slice_sizes, stream())); break; } - case Scatter::Max: { + case Scatter::Max: + case Scatter::Min: { auto slice_sizes = cotangents[0].shape(); for (auto ax : axes_) { slice_sizes[ax] = 1; diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py index 901c89dfdb..1279cfb40f 100644 --- a/python/tests/test_autograd.py +++ b/python/tests/test_autograd.py @@ -316,6 +316,29 @@ class TestAutograd(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]]))) self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]]))) + def test_scatter_min_vjp(self): + def fun(src, updates): + x = src.at[1].minimum(updates) + return x + + cotan = mx.array([4.0, 5.0, 6.0]) + _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0]), mx.array([[3.0]])], [cotan]) + mx.eval(vjps) + + # Update larger than value + self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0]))) + self.assertTrue(mx.allclose(vjps[1], mx.array([0.0]))) + + cotan = mx.array([[4.0], [5.0], [6.0]]) + _, vjps = mx.vjp( + fun, [mx.array([[1.0], [2.0], [3.0]]), mx.array([[[2.0]]])], [cotan] + ) + mx.eval(vjps) + + # Update and value are equal + self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]]))) + self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]]))) + def test_vjp_types(self): def fun(x): return x