mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix all_gather vjp (#2654)
This commit is contained in:
@@ -29,7 +29,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
|
||||
std::vector<array> AllReduce::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>&) {
|
||||
switch (reduce_type_) {
|
||||
case Sum:
|
||||
return {all_sum(tangents[0], group(), stream())};
|
||||
@@ -46,7 +46,7 @@ std::vector<array> AllReduce::jvp(
|
||||
std::vector<array> AllReduce::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>& outputs) {
|
||||
return cotangents;
|
||||
}
|
||||
@@ -60,21 +60,30 @@ std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
|
||||
std::vector<array> AllGather::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<int>&) {
|
||||
return {all_gather(tangents[0], group(), stream())};
|
||||
}
|
||||
|
||||
std::vector<array> AllGather::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>&) {
|
||||
auto g = group();
|
||||
auto ndim = primals[0].ndim();
|
||||
Shape starts(primals[0].ndim(), 0);
|
||||
auto stops = primals[0].shape();
|
||||
if (ndim == 0) {
|
||||
starts.push_back(0);
|
||||
stops.push_back(1);
|
||||
}
|
||||
starts[0] = g.rank() * stops[0];
|
||||
stops[0] += starts[0];
|
||||
return {slice(cotangents[0], starts, stops)};
|
||||
auto out = slice(cotangents[0], starts, stops);
|
||||
if (ndim == 0) {
|
||||
out = squeeze(out, 0);
|
||||
}
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Send::vmap(
|
||||
|
||||
Reference in New Issue
Block a user