mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 23:08:11 +08:00
fix all_gather vjp (#2654)
This commit is contained in:
@@ -129,6 +129,16 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
self.assertTrue(mx.all(y == x[world.rank()]))
|
||||
self.assertTrue(mx.all(z == x[left]))
|
||||
|
||||
def test_all_gather_vjp(self):
|
||||
def fun(x):
|
||||
return mx.distributed.all_gather(x)[0]
|
||||
|
||||
dfdx = mx.grad(fun)(mx.array(1.0))
|
||||
if mx.distributed.init().rank() == 0:
|
||||
self.assertEqual(dfdx.item(), 1.0)
|
||||
else:
|
||||
self.assertEqual(dfdx.item(), 0.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
Reference in New Issue
Block a user