mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -82,7 +82,7 @@ array send(
|
||||
}
|
||||
|
||||
array recv(
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
int src,
|
||||
std::optional<Group> group_ /* = std::nullopt */,
|
||||
|
||||
@@ -26,7 +26,7 @@ array send(
|
||||
StreamOrDevice s = {});
|
||||
|
||||
array recv(
|
||||
std::vector<int> shape,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
int src,
|
||||
std::optional<Group> group = std::nullopt,
|
||||
|
||||
@@ -91,7 +91,7 @@ std::vector<array> AllGather::vjp(
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
auto g = group();
|
||||
std::vector<int> starts(primals[0].ndim(), 0);
|
||||
Shape starts(primals[0].ndim(), 0);
|
||||
auto stops = primals[0].shape();
|
||||
starts[0] = g.rank() * stops[0];
|
||||
stops[0] += starts[0];
|
||||
|
||||
Reference in New Issue
Block a user