mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-16 14:18:12 +08:00
Primitive's VJP takes outputs as input (#475)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -71,3 +71,5 @@ class MLXTestCase(unittest.TestCase):
|
||||
elif not isinstance(expected, mx.array):
|
||||
expected = mx.array(expected)
|
||||
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
|
||||
else:
|
||||
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
|
||||
|
@@ -1005,7 +1005,8 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
index_y = mx.array([3, 3, 1, 2])
|
||||
u = mx.random.uniform(shape=(4,))
|
||||
a = a.at[index_x, index_y].add(u)
|
||||
self.assertEqual(a.sum().item(), u.sum().item())
|
||||
self.assertTrue(mx.allclose(a.sum(), u.sum()))
|
||||
self.assertEqualArray(a.sum(), u.sum(), atol=1e-6, rtol=1e-5)
|
||||
self.assertEqual(a[index_x, index_y].tolist(), u.tolist())
|
||||
|
||||
# Test all array.at ops
|
||||
|
Reference in New Issue
Block a user