Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -2953,16 +2953,16 @@ void init_ops(nb::module_& m) {
m.def(
"as_strided",
[](const array& a,
std::optional<std::vector<int>> shape,
std::optional<std::vector<size_t>> strides,
std::optional<Shape> shape,
std::optional<Strides> strides,
size_t offset,
StreamOrDevice s) {
std::vector<int> a_shape = (shape) ? *shape : a.shape();
std::vector<size_t> a_strides;
auto a_shape = (shape) ? *shape : a.shape();
Strides a_strides;
if (strides) {
a_strides = *strides;
} else {
a_strides = std::vector<size_t>(a_shape.size(), 1);
a_strides = Strides(a_shape.size(), 1);
for (int i = a_shape.size() - 1; i > 0; i--) {
a_strides[i - 1] = a_shape[i] * a_strides[i];
}