fix all_gather vjp (#2654)

This commit is contained in:
Awni Hannun
2025-10-07 06:05:23 -07:00
committed by GitHub
parent 0073096dd1
commit 343e33b6d5
2 changed files with 25 additions and 6 deletions

View File

@@ -29,7 +29,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
std::vector<array> AllReduce::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
const std::vector<int>&) {
switch (reduce_type_) {
case Sum:
return {all_sum(tangents[0], group(), stream())};
@@ -46,7 +46,7 @@ std::vector<array> AllReduce::jvp(
// Vector-Jacobian product for AllReduce: hands the incoming cotangents back
// unchanged; none of the other arguments are read.
// NOTE(review): this listing is a diff rendered without +/- markers, so the
// named `argnums` line and the unnamed line after it appear to be the
// before/after versions of the same parameter — confirm against the repository.
std::vector<array> AllReduce::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<int>&,
const std::vector<array>& outputs) {
return cotangents;
}
@@ -60,21 +60,30 @@ std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
// Jacobian-vector product for AllGather: applies all_gather to the first
// tangent on this primitive's group() and stream() and returns it as a
// one-element vector. `primals` is not read.
// NOTE(review): this listing is a diff rendered without +/- markers, so the
// named `argnums` line and the unnamed line after it appear to be the
// before/after versions of the same parameter — confirm against the repository.
std::vector<array> AllGather::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
const std::vector<int>&) {
return {all_gather(tangents[0], group(), stream())};
}
// Vector-Jacobian product for AllGather: returns the part of cotangents[0]
// that lines up with this rank's input, i.e. the range
// [rank * n, rank * n + n) along axis 0, where n is the leading extent of
// primals[0] (or 1 when primals[0] is 0-d).
// NOTE(review): presumably all_gather concatenates per-rank inputs along
// axis 0 in rank order, which is what makes this slice the right one — confirm.
// NOTE(review): this listing is a diff rendered without +/- markers, so
// superseded lines and their replacements appear together below — confirm
// against the repository before treating it as compilable code.
std::vector<array> AllGather::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
// Appear to be the pre-change parameter lines (named but unused).
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
// Appear to be the replacement lines: the unused parameters are left unnamed.
const std::vector<int>&,
const std::vector<array>&) {
auto g = group();
auto ndim = primals[0].ndim();
// Slice bounds start out covering the whole primal: a zero start per
// dimension (presumably Shape(count, value), vector-style) and the primal's
// shape as the stop.
Shape starts(primals[0].ndim(), 0);
auto stops = primals[0].shape();
// A 0-d primal has no axis 0 to offset, so give the bounds a single
// length-1 axis; the matching squeeze below removes it again.
if (ndim == 0) {
starts.push_back(0);
stops.push_back(1);
}
// Shift the axis-0 window to this rank's segment. stops[0] still holds the
// segment length here, so the start must be computed before the stop is
// advanced.
starts[0] = g.rank() * stops[0];
stops[0] += starts[0];
// Appears to be the pre-change return, superseded by the lines that follow.
return {slice(cotangents[0], starts, stops)};
auto out = slice(cotangents[0], starts, stops);
// Drop the axis added for the 0-d case so the result is 0-d like the primal.
if (ndim == 0) {
out = squeeze(out, 0);
}
return {out};
}
std::pair<std::vector<array>, std::vector<int>> Send::vmap(