mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
@@ -24,13 +24,13 @@ void single_block_sort(
|
||||
// Prepare shapes
|
||||
int n_rows = in.size() / in.shape(axis);
|
||||
|
||||
std::vector<size_t> in_nc_str = in.strides();
|
||||
auto in_nc_str = in.strides();
|
||||
in_nc_str.erase(in_nc_str.begin() + axis);
|
||||
|
||||
std::vector<size_t> out_nc_str = out.strides();
|
||||
auto out_nc_str = out.strides();
|
||||
out_nc_str.erase(out_nc_str.begin() + axis);
|
||||
|
||||
std::vector<int> nc_shape = in.shape();
|
||||
auto nc_shape = in.shape();
|
||||
nc_shape.erase(nc_shape.begin() + axis);
|
||||
|
||||
int nc_dim = nc_shape.size();
|
||||
@@ -106,10 +106,10 @@ void multi_block_sort(
|
||||
// Prepare shapes
|
||||
int n_rows = in.size() / in.shape(axis);
|
||||
|
||||
std::vector<size_t> nc_str = in.strides();
|
||||
auto nc_str = in.strides();
|
||||
nc_str.erase(nc_str.begin() + axis);
|
||||
|
||||
std::vector<int> nc_shape = in.shape();
|
||||
auto nc_shape = in.shape();
|
||||
nc_shape.erase(nc_shape.begin() + axis);
|
||||
|
||||
int nc_dim = nc_shape.size();
|
||||
|
||||
Reference in New Issue
Block a user