Correct types for vjp + tests (#418)

* correct types for vjp + tests

* fix build + comment
This commit is contained in:
Awni Hannun
2024-01-10 13:32:37 -08:00
committed by GitHub
parent b7f905787e
commit 3b4f066dac
4 changed files with 75 additions and 4 deletions

View File

@@ -274,6 +274,47 @@ class TestAutograd(mlx_tests.MLXTestCase):
mx.eval(state)
self.assertTrue(mx.allclose(state, mx.ones((2,))))
def test_scatter_vjp(self):
def fun(x, idx):
x[idx] = 2.0
return x.sum()
dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
self.assertTrue(mx.array_equal(dfdx, mx.array([1.0, 0.0, 1.0])))
self.assertEqual(dfdx.dtype, mx.float32)
y = mx.array([0.0, 1.0, 2.0])
def fun(x, idx):
y[idx] = x
return y.sum()
dfdx = mx.grad(fun)(mx.array([2.0]), mx.array([1]))
self.assertTrue(mx.array_equal(dfdx, mx.array([1.0])))
self.assertEqual(dfdx.dtype, mx.float32)
def test_vjp_types(self):
def fun(x):
return x
for t in [mx.float16, mx.bfloat16, mx.float32]:
out = mx.grad(fun)(mx.array(1.0, t))
self.assertEqual(out.dtype, t)
def fun(x):
return x.sum()
for t in [mx.float16, mx.bfloat16, mx.float32]:
out = mx.grad(fun)(mx.array(1.0, t))
self.assertEqual(out.dtype, t)
def fun(x, y):
return (x + y).sum()
for t in [mx.float16, mx.bfloat16, mx.float32]:
out = mx.grad(fun)(mx.array(1.0, t), mx.array(1.0, t))
self.assertEqual(out.dtype, t)
if __name__ == "__main__":
unittest.main()