mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Correct types for vjp + tests (#418)
* correct types for vjp + tests * fix build + comment
This commit is contained in:
@@ -274,6 +274,47 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
mx.eval(state)
|
||||
self.assertTrue(mx.allclose(state, mx.ones((2,))))
|
||||
|
||||
def test_scatter_vjp(self):
|
||||
def fun(x, idx):
|
||||
x[idx] = 2.0
|
||||
return x.sum()
|
||||
|
||||
dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
|
||||
self.assertTrue(mx.array_equal(dfdx, mx.array([1.0, 0.0, 1.0])))
|
||||
self.assertEqual(dfdx.dtype, mx.float32)
|
||||
|
||||
y = mx.array([0.0, 1.0, 2.0])
|
||||
|
||||
def fun(x, idx):
|
||||
y[idx] = x
|
||||
return y.sum()
|
||||
|
||||
dfdx = mx.grad(fun)(mx.array([2.0]), mx.array([1]))
|
||||
self.assertTrue(mx.array_equal(dfdx, mx.array([1.0])))
|
||||
self.assertEqual(dfdx.dtype, mx.float32)
|
||||
|
||||
def test_vjp_types(self):
|
||||
def fun(x):
|
||||
return x
|
||||
|
||||
for t in [mx.float16, mx.bfloat16, mx.float32]:
|
||||
out = mx.grad(fun)(mx.array(1.0, t))
|
||||
self.assertEqual(out.dtype, t)
|
||||
|
||||
def fun(x):
|
||||
return x.sum()
|
||||
|
||||
for t in [mx.float16, mx.bfloat16, mx.float32]:
|
||||
out = mx.grad(fun)(mx.array(1.0, t))
|
||||
self.assertEqual(out.dtype, t)
|
||||
|
||||
def fun(x, y):
|
||||
return (x + y).sum()
|
||||
|
||||
for t in [mx.float16, mx.bfloat16, mx.float32]:
|
||||
out = mx.grad(fun)(mx.array(1.0, t), mx.array(1.0, t))
|
||||
self.assertEqual(out.dtype, t)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user