Primitive's VJP takes outputs as input (#475)

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-01-16 19:03:53 -08:00
committed by GitHub
parent d8fabaa12b
commit a2bf7693dd
5 changed files with 205 additions and 136 deletions

View File

@@ -71,3 +71,5 @@ class MLXTestCase(unittest.TestCase):
elif not isinstance(expected, mx.array):
expected = mx.array(expected)
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
else:
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))