Add scatter_min VJP (#462)

This commit is contained in:
Tristan Bilot
2024-01-16 09:37:40 +01:00
committed by GitHub
parent 92a2fdd577
commit f44c132f4a
2 changed files with 31 additions and 3 deletions

View File

@@ -316,6 +316,29 @@ class TestAutograd(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]])))
self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]])))
def test_scatter_min_vjp(self):
def fun(src, updates):
x = src.at[1].minimum(updates)
return x
cotan = mx.array([4.0, 5.0, 6.0])
_, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0]), mx.array([[3.0]])], [cotan])
mx.eval(vjps)
# Update larger than value
self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0])))
self.assertTrue(mx.allclose(vjps[1], mx.array([0.0])))
cotan = mx.array([[4.0], [5.0], [6.0]])
_, vjps = mx.vjp(
fun, [mx.array([[1.0], [2.0], [3.0]]), mx.array([[[2.0]]])], [cotan]
)
mx.eval(vjps)
# Update and value are equal
self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]])))
self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]])))
def test_vjp_types(self):
def fun(x):
return x