MLX
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <vector>
6
7#include "mlx/array.h"
8
9namespace mlx::core {
10
// Convert a flat element index into a memory offset for the layout described
// by `shape` and `strides`.
//
// The index is split into per-axis coordinates starting from the last axis
// (so the last axis varies fastest); each coordinate is multiplied by that
// axis's stride and the products are summed. An empty shape yields 0.
template <typename stride_t>
inline stride_t elem_to_loc(
    int elem,
    const std::vector<int>& shape,
    const std::vector<stride_t>& strides) {
  stride_t offset = 0;
  size_t axis = shape.size();
  while (axis > 0) {
    --axis;
    const auto split = ldiv(elem, shape[axis]);
    offset += split.rem * strides[axis];
    elem = split.quot;
  }
  return offset;
}
24
25inline size_t elem_to_loc(int elem, const array& a) {
26 if (a.flags().row_contiguous) {
27 return elem;
28 }
29 return elem_to_loc(elem, a.shape(), a.strides());
30}
31
// Strides of a densely packed row-major layout for `shape`: the last axis has
// stride 1 and every earlier axis's stride is the product of the extents of
// all axes after it. An empty shape yields an empty stride vector.
template <typename stride_t>
std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
  const size_t ndim = shape.size();
  std::vector<stride_t> strides(ndim, 1);
  // Walk from the innermost axis outward, accumulating the extent product.
  for (size_t axis = ndim; axis > 1; --axis) {
    strides[axis - 2] = strides[axis - 1] * shape[axis - 1];
  }
  return strides;
}
40
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
//
// Axis i is merged into axis i - 1 only when, for every stride vector st,
// st[i] * shape[i] == st[i - 1]. A merged group's extent is the product of its
// axes' extents and its stride is the stride of the group's last axis.
// Each entry of `strides` is assumed to have shape.size() elements (not
// checked here). An empty shape returns an empty shape and one empty stride
// vector per input.
template <typename stride_t>
inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
collapse_contiguous_dims(
    const std::vector<int>& shape,
    const std::vector<std::vector<stride_t>>& strides) {
  std::vector<int> out_shape;
  std::vector<std::vector<stride_t>> out_strides(strides.size());
  if (shape.empty()) {
    return std::make_tuple(std::move(out_shape), std::move(out_strides));
  }

  // Append the group of axes that ends at `last_axis` with total `extent`.
  auto close_group = [&](int extent, size_t last_axis) {
    out_shape.push_back(extent);
    for (size_t j = 0; j < strides.size(); ++j) {
      out_strides[j].push_back(strides[j][last_axis]);
    }
  };

  int extent = shape[0];
  for (size_t i = 1; i < shape.size(); ++i) {
    bool contiguous = true;
    for (const std::vector<stride_t>& st : strides) {
      if (st[i] * shape[i] != st[i - 1]) {
        contiguous = false;
        break;
      }
    }
    if (contiguous) {
      extent *= shape[i];
    } else {
      close_group(extent, i - 1);
      extent = shape[i];
    }
  }
  close_group(extent, shape.size() - 1);

  return std::make_tuple(std::move(out_shape), std::move(out_strides));
}
91
92inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
93collapse_contiguous_dims(const std::vector<array>& xs) {
94 std::vector<std::vector<size_t>> strides;
95 for (auto& x : xs) {
96 strides.emplace_back(x.strides());
97 }
98 return collapse_contiguous_dims(xs[0].shape(), strides);
99}
100
101template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
102inline auto collapse_contiguous_dims(Arrays&&... xs) {
104 std::vector<array>{std::forward<Arrays>(xs)...});
105}
106
// Inspect the layout described by `shape` and `strides`.
//
// Returns a tuple (data_size, is_row_contiguous, is_col_contiguous):
//  - data_size: product of the extents of every axis whose stride is > 0, so
//    zero-stride (e.g. broadcast) axes do not contribute.
//  - is_row_contiguous: strides equal densely packed row-major strides, i.e.
//    strides[ndim - 1] == 1 and strides[i] == prod(shape[i + 1:]).
//  - is_col_contiguous: strides equal densely packed column-major strides,
//    i.e. strides[0] == 1 and strides[i] == prod(shape[:i]).
// Axes of extent 1 are skipped by both contiguity checks: their only index is
// 0, so their stride never affects an address.
template <typename stride_t>
inline auto check_contiguity(
    const std::vector<int>& shape,
    const std::vector<stride_t>& strides) {
  size_t data_size = 1;
  size_t f_stride = 1; // expected column-major stride of axis i
  size_t b_stride = 1; // expected row-major stride of axis ri
  bool is_row_contiguous = true;
  bool is_col_contiguous = true;

  for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
    // Column-major strides grow from the first axis forward ...
    is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;
    // ... row-major strides grow from the last axis backward.
    is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
    f_stride *= shape[i];
    b_stride *= shape[ri];
    if (strides[i] > 0) {
      data_size *= shape[i];
    }
  }

  return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
}
129
130} // namespace mlx::core
Definition array.h:20
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:290
const std::vector< size_t > & strides() const
The strides of the array.
Definition array.h:113
const std::vector< int > & shape() const
The shape of the array as a vector of integers.
Definition array.h:99
Definition allocator.h:7
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
auto check_contiguity(const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:108
std::tuple< std::vector< int >, std::vector< std::vector< stride_t > > > collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< stride_t > > strides)
Definition utils.h:49
std::vector< stride_t > make_contiguous_strides(const std::vector< int > &shape)
Definition utils.h:33
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:566
bool row_contiguous
Definition array.h:226