mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
WIP (common)
This commit is contained in:
@@ -28,7 +28,7 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
|
|||||||
|
|
||||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||||
// The data is all there and we are reducing over everything
|
// The data is all there and we are reducing over everything
|
||||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
if (x.size() == x.data_size() && std::ssize(axes) == x.ndim() &&
|
||||||
x.flags().contiguous) {
|
x.flags().contiguous) {
|
||||||
return ContiguousAllReduce;
|
return ContiguousAllReduce;
|
||||||
}
|
}
|
||||||
@@ -38,7 +38,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
|||||||
// Merge consecutive axes
|
// Merge consecutive axes
|
||||||
Shape shape = {x.shape(axes[0])};
|
Shape shape = {x.shape(axes[0])};
|
||||||
Strides strides = {x.strides()[axes[0]]};
|
Strides strides = {x.strides()[axes[0]]};
|
||||||
for (int i = 1; i < axes.size(); i++) {
|
for (int i = 1; i < std::ssize(axes); i++) {
|
||||||
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
|
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
|
||||||
shape.back() *= x.shape(axes[i]);
|
shape.back() *= x.shape(axes[i]);
|
||||||
strides.back() = x.strides()[axes[i]];
|
strides.back() = x.strides()[axes[i]];
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
|||||||
if (shape[0] != 1) {
|
if (shape[0] != 1) {
|
||||||
to_collapse.push_back(0);
|
to_collapse.push_back(0);
|
||||||
}
|
}
|
||||||
size_t size = shape[0];
|
int64_t size = shape[0];
|
||||||
for (int i = 1; i < shape.size(); i++) {
|
for (int i = 1; i < shape.size(); i++) {
|
||||||
bool contiguous = true;
|
bool contiguous = true;
|
||||||
size *= shape[i];
|
size *= shape[i];
|
||||||
@@ -64,7 +64,7 @@ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
|||||||
current_shape *= shape[to_collapse[k]];
|
current_shape *= shape[to_collapse[k]];
|
||||||
}
|
}
|
||||||
out_shape.push_back(current_shape);
|
out_shape.push_back(current_shape);
|
||||||
for (int j = 0; j < strides.size(); j++) {
|
for (int j = 0; j < std::ssize(strides); j++) {
|
||||||
const auto& st = strides[j];
|
const auto& st = strides[j];
|
||||||
out_strides[j].push_back(st[to_collapse[k - 1]]);
|
out_strides[j].push_back(st[to_collapse[k - 1]]);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user