mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 07:18:29 +08:00
fix all_gather vjp (#2654)
This commit is contained in:
@@ -29,7 +29,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
|
||||
std::vector<array> AllReduce::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>&) {
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
return {all_sum(tangents[0], group(), stream())};
|
||||
@@ -46,7 +46,7 @@ std::vector<array> AllReduce::jvp(
|
||||
std::vector<array> AllReduce::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>& outputs) {
|
||||
return cotangents;
|
||||
}
|
||||
@@ -60,21 +60,30 @@ std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
|
||||
std::vector<array> AllGather::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>&) {
|
||||
return {all_gather(tangents[0], group(), stream())};
|
||||
}
|
||||
|
||||
std::vector<array> AllGather::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>&) {
|
||||
auto g = group();
|
||||
auto ndim = primals[0].ndim();
|
||||
Shape starts(primals[0].ndim(), 0);
|
||||
auto stops = primals[0].shape();
|
||||
if (ndim == 0) {
|
||||
starts.push_back(0);
|
||||
stops.push_back(1);
|
||||
}
|
||||
starts[0] = g.rank() * stops[0];
|
||||
stops[0] += starts[0];
|
||||
return {slice(cotangents[0], starts, stops)};
|
||||
auto out = slice(cotangents[0], starts, stops);
|
||||
if (ndim == 0) {
|
||||
out = squeeze(out, 0);
|
||||
}
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Send::vmap(
|
||||
|
@@ -129,6 +129,16 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
||||
self.assertTrue(mx.all(y == x[world.rank()]))
|
||||
self.assertTrue(mx.all(z == x[left]))
|
||||
|
||||
def test_all_gather_vjp(self):
|
||||
def fun(x):
|
||||
return mx.distributed.all_gather(x)[0]
|
||||
|
||||
dfdx = mx.grad(fun)(mx.array(1.0))
|
||||
if mx.distributed.init().rank() == 0:
|
||||
self.assertEqual(dfdx.item(), 1.0)
|
||||
else:
|
||||
self.assertEqual(dfdx.item(), 0.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
Reference in New Issue
Block a user