mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 07:18:29 +08:00
fix all_gather vjp (#2654)
This commit is contained in:
@@ -29,7 +29,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
|
|||||||
std::vector<array> AllReduce::jvp(
|
std::vector<array> AllReduce::jvp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& tangents,
|
const std::vector<array>& tangents,
|
||||||
const std::vector<int>& argnums) {
|
const std::vector<int>&) {
|
||||||
switch (reduce_type_) {
|
switch (reduce_type_) {
|
||||||
case Sum:
|
case Sum:
|
||||||
return {all_sum(tangents[0], group(), stream())};
|
return {all_sum(tangents[0], group(), stream())};
|
||||||
@@ -46,7 +46,7 @@ std::vector<array> AllReduce::jvp(
|
|||||||
std::vector<array> AllReduce::vjp(
|
std::vector<array> AllReduce::vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& cotangents,
|
const std::vector<array>& cotangents,
|
||||||
const std::vector<int>& argnums,
|
const std::vector<int>&,
|
||||||
const std::vector<array>& outputs) {
|
const std::vector<array>& outputs) {
|
||||||
return cotangents;
|
return cotangents;
|
||||||
}
|
}
|
||||||
@@ -60,21 +60,30 @@ std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
|
|||||||
std::vector<array> AllGather::jvp(
|
std::vector<array> AllGather::jvp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& tangents,
|
const std::vector<array>& tangents,
|
||||||
const std::vector<int>& argnums) {
|
const std::vector<int>&) {
|
||||||
return {all_gather(tangents[0], group(), stream())};
|
return {all_gather(tangents[0], group(), stream())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> AllGather::vjp(
|
std::vector<array> AllGather::vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& cotangents,
|
const std::vector<array>& cotangents,
|
||||||
const std::vector<int>& argnums,
|
const std::vector<int>&,
|
||||||
const std::vector<array>& outputs) {
|
const std::vector<array>&) {
|
||||||
auto g = group();
|
auto g = group();
|
||||||
|
auto ndim = primals[0].ndim();
|
||||||
Shape starts(primals[0].ndim(), 0);
|
Shape starts(primals[0].ndim(), 0);
|
||||||
auto stops = primals[0].shape();
|
auto stops = primals[0].shape();
|
||||||
|
if (ndim == 0) {
|
||||||
|
starts.push_back(0);
|
||||||
|
stops.push_back(1);
|
||||||
|
}
|
||||||
starts[0] = g.rank() * stops[0];
|
starts[0] = g.rank() * stops[0];
|
||||||
stops[0] += starts[0];
|
stops[0] += starts[0];
|
||||||
return {slice(cotangents[0], starts, stops)};
|
auto out = slice(cotangents[0], starts, stops);
|
||||||
|
if (ndim == 0) {
|
||||||
|
out = squeeze(out, 0);
|
||||||
|
}
|
||||||
|
return {out};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<int>> Send::vmap(
|
std::pair<std::vector<array>, std::vector<int>> Send::vmap(
|
||||||
|
@@ -129,6 +129,16 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
|
|||||||
self.assertTrue(mx.all(y == x[world.rank()]))
|
self.assertTrue(mx.all(y == x[world.rank()]))
|
||||||
self.assertTrue(mx.all(z == x[left]))
|
self.assertTrue(mx.all(z == x[left]))
|
||||||
|
|
||||||
|
def test_all_gather_vjp(self):
|
||||||
|
def fun(x):
|
||||||
|
return mx.distributed.all_gather(x)[0]
|
||||||
|
|
||||||
|
dfdx = mx.grad(fun)(mx.array(1.0))
|
||||||
|
if mx.distributed.init().rank() == 0:
|
||||||
|
self.assertEqual(dfdx.item(), 1.0)
|
||||||
|
else:
|
||||||
|
self.assertEqual(dfdx.item(), 0.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
Reference in New Issue
Block a user