diff --git a/mlx/distributed/primitives.cpp b/mlx/distributed/primitives.cpp
index 576424cdd..5e8d5327a 100644
--- a/mlx/distributed/primitives.cpp
+++ b/mlx/distributed/primitives.cpp
@@ -29,7 +29,7 @@ std::pair<std::vector<array>, std::vector<array>> AllReduce::vmap(
 std::vector<array> AllReduce::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
-    const std::vector<int>& argnums) {
+    const std::vector<int>&) {
   switch (reduce_type_) {
     case Sum:
       return {all_sum(tangents[0], group(), stream())};
@@ -46,7 +46,7 @@ std::vector<array> AllReduce::jvp(
 std::vector<array> AllReduce::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
-    const std::vector<int>& argnums,
+    const std::vector<int>&,
     const std::vector<array>& outputs) {
   return cotangents;
 }
@@ -60,21 +60,30 @@ std::pair<std::vector<array>, std::vector<array>> AllGather::vmap(
 std::vector<array> AllGather::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
-    const std::vector<int>& argnums) {
+    const std::vector<int>&) {
   return {all_gather(tangents[0], group(), stream())};
 }
 
 std::vector<array> AllGather::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
-    const std::vector<int>& argnums,
-    const std::vector<array>& outputs) {
+    const std::vector<int>&,
+    const std::vector<array>&) {
   auto g = group();
+  auto ndim = primals[0].ndim();
   Shape starts(primals[0].ndim(), 0);
   auto stops = primals[0].shape();
+  if (ndim == 0) {
+    starts.push_back(0);
+    stops.push_back(1);
+  }
   starts[0] = g.rank() * stops[0];
   stops[0] += starts[0];
-  return {slice(cotangents[0], starts, stops)};
+  auto out = slice(cotangents[0], starts, stops);
+  if (ndim == 0) {
+    out = squeeze(out, 0);
+  }
+  return {out};
 }
 
 std::pair<std::vector<array>, std::vector<array>> Send::vmap(
diff --git a/python/tests/ring_test_distributed.py b/python/tests/ring_test_distributed.py
index b8d08b0f6..6721b0831 100644
--- a/python/tests/ring_test_distributed.py
+++ b/python/tests/ring_test_distributed.py
@@ -129,6 +129,16 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
         self.assertTrue(mx.all(y == x[world.rank()]))
         self.assertTrue(mx.all(z == x[left]))
 
+    def test_all_gather_vjp(self):
+        def fun(x):
+            return mx.distributed.all_gather(x)[0]
+
+        dfdx = mx.grad(fun)(mx.array(1.0))
+        if mx.distributed.init().rank() == 0:
+            self.assertEqual(dfdx.item(), 1.0)
+        else:
+            self.assertEqual(dfdx.item(), 0.0)
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()
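
Note on the change: AllGather::vjp splits the incoming cotangent back across ranks by slicing along axis 0 at an offset of rank * shape[0]. For a 0-dimensional input that axis does not exist, so the patch pads the slice bounds to a single element and squeezes the result back to a scalar. A minimal standalone sketch of the fixed behavior, mirroring the new test (the script name and launcher invocation are illustrative assumptions; any backend that initializes a distributed group works):

    # Run under a distributed launcher, e.g. `mlx.launch -n 2 sketch.py`
    # (invocation illustrative, not part of the patch).
    import mlx.core as mx

    group = mx.distributed.init()

    def fun(x):
        # all_gather on a scalar yields a 1-D array of length group.size(),
        # so only rank 0's input reaches the output below.
        return mx.distributed.all_gather(x)[0]

    dfdx = mx.grad(fun)(mx.array(1.0))
    # Expected: 1.0 on rank 0 and 0.0 on every other rank, because the vjp
    # slices the cotangent by rank offset and then squeezes the axis that
    # all_gather implicitly added for the 0-d input.
    print(group.rank(), dfdx.item())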