Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -38,13 +38,10 @@ enum ReductionOpType {
struct ReductionPlan {
ReductionOpType type;
std::vector<int> shape;
std::vector<size_t> strides;
Shape shape;
Strides strides;
ReductionPlan(
ReductionOpType type_,
std::vector<int> shape_,
std::vector<size_t> strides_)
ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
@@ -55,10 +52,10 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
// Should this be in utils?
void nd_loop(
std::function<void(int)> callback,
const std::vector<int>& shape,
const std::vector<size_t>& strides);
const Shape& shape,
const Strides& strides);
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes);
@@ -113,9 +110,6 @@ void reduction_op(
return;
}
std::vector<int> shape;
std::vector<size_t> strides;
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0];
const T* x_ptr = x.data<T>();
@@ -135,7 +129,7 @@ void reduction_op(
U* out_ptr = out.data<U>();
// Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should hold extra performance boost.
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
@@ -181,7 +175,7 @@ void reduction_op(
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
@@ -211,7 +205,7 @@ void reduction_op(
if (plan.type == GeneralReduce) {
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;