mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
@@ -38,13 +38,10 @@ enum ReductionOpType {
|
||||
|
||||
struct ReductionPlan {
|
||||
ReductionOpType type;
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
Shape shape;
|
||||
Strides strides;
|
||||
|
||||
ReductionPlan(
|
||||
ReductionOpType type_,
|
||||
std::vector<int> shape_,
|
||||
std::vector<size_t> strides_)
|
||||
ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
|
||||
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
|
||||
ReductionPlan(ReductionOpType type_) : type(type_) {}
|
||||
};
|
||||
@@ -55,10 +52,10 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
|
||||
// Should this be in utils?
|
||||
void nd_loop(
|
||||
std::function<void(int)> callback,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides);
|
||||
const Shape& shape,
|
||||
const Strides& strides);
|
||||
|
||||
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes);
|
||||
|
||||
@@ -113,9 +110,6 @@ void reduction_op(
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
|
||||
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
|
||||
int reduction_size = plan.shape[0];
|
||||
const T* x_ptr = x.data<T>();
|
||||
@@ -135,7 +129,7 @@ void reduction_op(
|
||||
U* out_ptr = out.data<U>();
|
||||
// Unrolling the following loop (and implementing it in order for
|
||||
// ContiguousReduce) should hold extra performance boost.
|
||||
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
@@ -181,7 +175,7 @@ void reduction_op(
|
||||
plan.strides.pop_back();
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
@@ -211,7 +205,7 @@ void reduction_op(
|
||||
if (plan.type == GeneralReduce) {
|
||||
const T* x_ptr = x.data<T>();
|
||||
U* out_ptr = out.data<U>();
|
||||
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
U val = init;
|
||||
|
||||
Reference in New Issue
Block a user