More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -82,7 +82,7 @@ array send(
}
array recv(
std::vector<int> shape,
Shape shape,
Dtype dtype,
int src,
std::optional<Group> group_ /* = std::nullopt */,

View File

@@ -26,7 +26,7 @@ array send(
StreamOrDevice s = {});
array recv(
std::vector<int> shape,
Shape shape,
Dtype dtype,
int src,
std::optional<Group> group = std::nullopt,

View File

@@ -91,7 +91,7 @@ std::vector<array> AllGather::vjp(
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto g = group();
std::vector<int> starts(primals[0].ndim(), 0);
Shape starts(primals[0].ndim(), 0);
auto stops = primals[0].shape();
starts[0] = g.rank() * stops[0];
stops[0] += starts[0];