41 const std::vector<int>& shape,
42 const std::vector<std::vector<stride_t>> strides) {
45 std::vector<int> to_collapse;
46 if (shape.size() > 0) {
47 to_collapse.push_back(0);
48 for (
int i = 1; i < shape.size(); i++) {
49 bool contiguous =
true;
50 for (
const std::vector<stride_t>& st : strides) {
51 if (st[i] * shape[i] != st[i - 1]) {
59 to_collapse.push_back(-1);
61 to_collapse.push_back(i);
63 to_collapse.push_back(-1);
66 std::vector<int> out_shape;
67 std::vector<std::vector<stride_t>> out_strides(strides.size());
68 for (
int i = 0; i < to_collapse.size(); i++) {
69 int current_shape = shape[to_collapse[i]];
70 while (to_collapse[++i] != -1) {
71 current_shape *= shape[to_collapse[i]];
73 out_shape.push_back(current_shape);
74 for (
int j = 0; j < strides.size(); j++) {
75 const std::vector<stride_t>& st = strides[j];
76 out_strides[j].push_back(st[to_collapse[i - 1]]);
80 return std::make_tuple(out_shape, out_strides);
100 const std::vector<int>& shape,
101 const std::vector<stride_t>& strides) {
102 size_t data_size = 1;
105 bool is_row_contiguous =
true;
106 bool is_col_contiguous =
true;
108 for (
int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
109 is_row_contiguous &= strides[i] == f_stride || shape[i] == 1;
110 is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
111 f_stride *= shape[i];
112 b_stride *= shape[ri];
113 if (strides[i] > 0) {
114 data_size *= shape[i];
118 return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);