WIP (common)

This commit is contained in:
Ronan Collobert
2025-10-30 16:18:59 -07:00
parent 63d91557e0
commit 76ef1e98f3
2 changed files with 4 additions and 4 deletions

View File

@@ -28,7 +28,7 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) { ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything // The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() && if (x.size() == x.data_size() && std::ssize(axes) == x.ndim() &&
x.flags().contiguous) { x.flags().contiguous) {
return ContiguousAllReduce; return ContiguousAllReduce;
} }
@@ -38,7 +38,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Merge consecutive axes // Merge consecutive axes
Shape shape = {x.shape(axes[0])}; Shape shape = {x.shape(axes[0])};
Strides strides = {x.strides()[axes[0]]}; Strides strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) { for (int i = 1; i < std::ssize(axes); i++) {
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) { if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
shape.back() *= x.shape(axes[i]); shape.back() *= x.shape(axes[i]);
strides.back() = x.strides()[axes[i]]; strides.back() = x.strides()[axes[i]];

View File

@@ -28,7 +28,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
if (shape[0] != 1) { if (shape[0] != 1) {
to_collapse.push_back(0); to_collapse.push_back(0);
} }
size_t size = shape[0]; int64_t size = shape[0];
for (int i = 1; i < shape.size(); i++) { for (int i = 1; i < shape.size(); i++) {
bool contiguous = true; bool contiguous = true;
size *= shape[i]; size *= shape[i];
@@ -64,7 +64,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
current_shape *= shape[to_collapse[k]]; current_shape *= shape[to_collapse[k]];
} }
out_shape.push_back(current_shape); out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) { for (int j = 0; j < std::ssize(strides); j++) {
const auto& st = strides[j]; const auto& st = strides[j];
out_strides[j].push_back(st[to_collapse[k - 1]]); out_strides[j].push_back(st[to_collapse[k - 1]]);
} }