fix all_gather vjp (#2654)

This commit is contained in:
Awni Hannun
2025-10-07 06:05:23 -07:00
committed by GitHub
parent 0073096dd1
commit 343e33b6d5
2 changed files with 25 additions and 6 deletions

View File

@@ -29,7 +29,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
std::vector<array> AllReduce::jvp( std::vector<array> AllReduce::jvp(
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& tangents, const std::vector<array>& tangents,
const std::vector<int>& argnums) { const std::vector<int>&) {
switch (reduce_type_) { switch (reduce_type_) {
case Sum: case Sum:
return {all_sum(tangents[0], group(), stream())}; return {all_sum(tangents[0], group(), stream())};
@@ -46,7 +46,7 @@ std::vector<array> AllReduce::jvp(
std::vector<array> AllReduce::vjp( std::vector<array> AllReduce::vjp(
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& cotangents, const std::vector<array>& cotangents,
const std::vector<int>& argnums, const std::vector<int>&,
const std::vector<array>& outputs) { const std::vector<array>& outputs) {
return cotangents; return cotangents;
} }
@@ -60,21 +60,30 @@ std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
std::vector<array> AllGather::jvp( std::vector<array> AllGather::jvp(
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& tangents, const std::vector<array>& tangents,
const std::vector<int>& argnums) { const std::vector<int>&) {
return {all_gather(tangents[0], group(), stream())}; return {all_gather(tangents[0], group(), stream())};
} }
std::vector<array> AllGather::vjp( std::vector<array> AllGather::vjp(
const std::vector<array>& primals, const std::vector<array>& primals,
const std::vector<array>& cotangents, const std::vector<array>& cotangents,
const std::vector<int>& argnums, const std::vector<int>&,
const std::vector<array>& outputs) { const std::vector<array>&) {
auto g = group(); auto g = group();
auto ndim = primals[0].ndim();
Shape starts(primals[0].ndim(), 0); Shape starts(primals[0].ndim(), 0);
auto stops = primals[0].shape(); auto stops = primals[0].shape();
if (ndim == 0) {
starts.push_back(0);
stops.push_back(1);
}
starts[0] = g.rank() * stops[0]; starts[0] = g.rank() * stops[0];
stops[0] += starts[0]; stops[0] += starts[0];
return {slice(cotangents[0], starts, stops)}; auto out = slice(cotangents[0], starts, stops);
if (ndim == 0) {
out = squeeze(out, 0);
}
return {out};
} }
std::pair<std::vector<array>, std::vector<int>> Send::vmap( std::pair<std::vector<array>, std::vector<int>> Send::vmap(

View File

@@ -129,6 +129,16 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
self.assertTrue(mx.all(y == x[world.rank()])) self.assertTrue(mx.all(y == x[world.rank()]))
self.assertTrue(mx.all(z == x[left])) self.assertTrue(mx.all(z == x[left]))
def test_all_gather_vjp(self):
def fun(x):
return mx.distributed.all_gather(x)[0]
dfdx = mx.grad(fun)(mx.array(1.0))
if mx.distributed.init().rank() == 0:
self.assertEqual(dfdx.item(), 1.0)
else:
self.assertEqual(dfdx.item(), 0.0)
if __name__ == "__main__": if __name__ == "__main__":
mlx_tests.MLXTestRunner() mlx_tests.MLXTestRunner()