fix all_gather vjp (#2654)

This commit is contained in:
Awni Hannun
2025-10-07 06:05:23 -07:00
committed by GitHub
parent 0073096dd1
commit 343e33b6d5
2 changed files with 25 additions and 6 deletions

View File

@@ -129,6 +129,16 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
self.assertTrue(mx.all(y == x[world.rank()]))
self.assertTrue(mx.all(z == x[left]))
def test_all_gather_vjp(self):
def fun(x):
return mx.distributed.all_gather(x)[0]
dfdx = mx.grad(fun)(mx.array(1.0))
if mx.distributed.init().rank() == 0:
self.assertEqual(dfdx.item(), 1.0)
else:
self.assertEqual(dfdx.item(), 0.0)
if __name__ == "__main__":
mlx_tests.MLXTestRunner()