Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -24,13 +24,13 @@ void single_block_sort(
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
std::vector<size_t> in_nc_str = in.strides();
auto in_nc_str = in.strides();
in_nc_str.erase(in_nc_str.begin() + axis);
std::vector<size_t> out_nc_str = out.strides();
auto out_nc_str = out.strides();
out_nc_str.erase(out_nc_str.begin() + axis);
std::vector<int> nc_shape = in.shape();
auto nc_shape = in.shape();
nc_shape.erase(nc_shape.begin() + axis);
int nc_dim = nc_shape.size();
@@ -106,10 +106,10 @@ void multi_block_sort(
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
std::vector<size_t> nc_str = in.strides();
auto nc_str = in.strides();
nc_str.erase(nc_str.begin() + axis);
std::vector<int> nc_shape = in.shape();
auto nc_shape = in.shape();
nc_shape.erase(nc_shape.begin() + axis);
int nc_dim = nc_shape.size();