mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
scatter_max vjp + bindings + tests (#431)
Co-authored-by: DjamelMesbah <djamel.mesbah@adservio.fr>
This commit is contained in:
@@ -293,6 +293,29 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.array_equal(dfdx, mx.array([1.0])))
|
||||
self.assertEqual(dfdx.dtype, mx.float32)
|
||||
|
||||
def test_scatter_max_vjp(self):
|
||||
def fun(src, updates):
|
||||
x = src.at[1].maximum(updates)
|
||||
return x
|
||||
|
||||
cotan = mx.array([4.0, 5.0, 6.0])
|
||||
_, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0]), mx.array([[3.0]])], [cotan])
|
||||
mx.eval(vjps)
|
||||
|
||||
# Update larger than value
|
||||
self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 0.0, 6.0])))
|
||||
self.assertTrue(mx.allclose(vjps[1], mx.array([5.0])))
|
||||
|
||||
cotan = mx.array([[4.0], [5.0], [6.0]])
|
||||
_, vjps = mx.vjp(
|
||||
fun, [mx.array([[1.0], [2.0], [3.0]]), mx.array([[[2.0]]])], [cotan]
|
||||
)
|
||||
mx.eval(vjps)
|
||||
|
||||
# Update and value are equal
|
||||
self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]])))
|
||||
self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]])))
|
||||
|
||||
def test_vjp_types(self):
|
||||
def fun(x):
|
||||
return x
|
||||
|
Reference in New Issue
Block a user