MLX
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
#include <algorithm>
#include <cstdlib>
#include <limits>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>

#include "mlx/array.h"
8
9namespace mlx::core {
10
/// Translate a linear, row-major element index into a memory offset.
///
/// The index is decomposed axis by axis starting from the innermost
/// (last) axis; each axis' coordinate is scaled by its stride and summed.
/// An empty `shape` yields offset 0.
template <typename StrideT>
inline StrideT elem_to_loc(
    int elem,
    const std::vector<int>& shape,
    const std::vector<StrideT>& strides) {
  StrideT offset = 0;
  int axis = static_cast<int>(shape.size());
  while (axis-- > 0) {
    // quot carries into the next-outer axis, rem is this axis' coordinate.
    const auto qr = ldiv(elem, shape[axis]);
    offset += qr.rem * strides[axis];
    elem = qr.quot;
  }
  return offset;
}
24
25inline size_t elem_to_loc(int elem, const array& a) {
26 if (a.flags().row_contiguous) {
27 return elem;
28 }
29 return elem_to_loc(elem, a.shape(), a.strides());
30}
31
/// Row-major (C-order) strides for a densely packed array of `shape`.
///
/// The last axis gets stride 1; every other axis' stride is the product of
/// the extents of all axes after it. An empty shape yields empty strides.
template <typename StrideT>
std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
  std::vector<StrideT> strides(shape.size(), 1);
  StrideT running = 1;
  for (size_t k = shape.size(); k > 1; --k) {
    running *= shape[k - 1];
    strides[k - 2] = running;
  }
  return strides;
}
40
41// Collapse dims that are contiguous to possibly route to a better kernel
42// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
43// should return {{2, 4}, {{1, 2}}}.
44//
45// When multiple arrays are passed they should all have the same shape. The
46// collapsed axes are also the same so one shape is returned.
47std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
49 const std::vector<int>& shape,
50 const std::vector<std::vector<int64_t>>& strides,
51 int64_t size_cap = std::numeric_limits<int32_t>::max());
52std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
54 const std::vector<int>& shape,
55 const std::vector<std::vector<size_t>>& strides,
56 size_t size_cap = std::numeric_limits<int32_t>::max());
57
58inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
60 const std::vector<array>& xs,
61 size_t size_cap = std::numeric_limits<int32_t>::max()) {
62 std::vector<std::vector<size_t>> strides;
63 for (auto& x : xs) {
64 strides.emplace_back(x.strides());
65 }
66 return collapse_contiguous_dims(xs[0].shape(), strides, size_cap);
67}
68
69template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
70inline auto collapse_contiguous_dims(Arrays&&... xs) {
72 std::vector<array>{std::forward<Arrays>(xs)...});
73}
74
75// The single array version of the above.
76std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
77 const std::vector<int>& shape,
78 const std::vector<int64_t>& strides,
79 int64_t size_cap = std::numeric_limits<int32_t>::max());
80std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
81 const std::vector<int>& shape,
82 const std::vector<size_t>& strides,
83 size_t size_cap = std::numeric_limits<int32_t>::max());
84std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
85 const array& a,
86 size_t size_cap = std::numeric_limits<int32_t>::max());
87
88template <typename StrideT>
90 inline void step() {
91 int dims = shape_.size();
92 if (dims == 0) {
93 return;
94 }
95 int i = dims - 1;
96 while (pos_[i] == (shape_[i] - 1) && i > 0) {
97 pos_[i] = 0;
98 loc -= (shape_[i] - 1) * strides_[i];
99 i--;
100 }
101 pos_[i]++;
102 loc += strides_[i];
103 }
104
105 void seek(StrideT n) {
106 loc = 0;
107 for (int i = shape_.size() - 1; i >= 0; --i) {
108 auto q_and_r = ldiv(n, shape_[i]);
109 loc += q_and_r.rem * strides_[i];
110 pos_[i] = q_and_r.rem;
111 n = q_and_r.quot;
112 }
113 }
114
115 void reset() {
116 loc = 0;
117 std::fill(pos_.begin(), pos_.end(), 0);
118 }
119
121
122 explicit ContiguousIterator(const array& a)
123 : shape_(a.shape()), strides_(a.strides()) {
124 if (!shape_.empty()) {
125 std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
126 pos_ = std::vector<int>(shape_.size(), 0);
127 }
128 }
129
131 const std::vector<int>& shape,
132 const std::vector<StrideT>& strides,
133 int dims)
134 : shape_(shape.begin(), shape.begin() + dims),
135 strides_(strides.begin(), strides.begin() + dims) {
136 if (!shape_.empty()) {
137 std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
138 pos_ = std::vector<int>(shape_.size(), 0);
139 }
140 }
141
142 StrideT loc{0};
143
144 private:
145 std::vector<int> shape_;
146 std::vector<StrideT> strides_;
147 std::vector<int> pos_;
148};
149
/// Classify a (shape, strides) layout in a single pass.
///
/// Returns std::tuple<size_t, bool, bool> holding:
///   0. the product of the extents of every axis whose stride is strictly
///      positive -- stride-0 (broadcast) axes contribute nothing, and for a
///      signed StrideT neither do negative strides;
///   1. whether the layout is row-major contiguous;
///   2. whether the layout is column-major contiguous.
/// Axes of extent 1 are ignored by both contiguity checks.
/// Requires strides.size() >= shape.size().
template <typename StrideT>
inline auto check_contiguity(
    const std::vector<int>& shape,
    const std::vector<StrideT>& strides) {
  size_t no_broadcast_data_size = 1;
  size_t f_stride = 1; // stride axis i must have for column-major order
  size_t b_stride = 1; // stride axis ri must have for row-major order
  bool is_row_contiguous = true;
  bool is_col_contiguous = true;

  // i walks forwards (column-major check) while ri walks backwards
  // (row-major check), so both orders are tested in one loop.
  for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
    is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;
    is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
    f_stride *= shape[i];
    b_stride *= shape[ri];
    if (strides[i] > 0) {
      no_broadcast_data_size *= shape[i];
    }
  }

  return std::make_tuple(
      no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
}
173
174inline bool is_donatable(const array& in, const array& out) {
175 constexpr size_t donation_extra = 16384;
176
177 return in.is_donatable() && in.itemsize() == out.itemsize() &&
178 in.buffer_size() <= out.nbytes() + donation_extra;
179}
180
181} // namespace mlx::core
Definition array.h:20
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:302
const std::vector< size_t > & strides() const
The strides of the array.
Definition array.h:113
size_t nbytes() const
The number of bytes in the array.
Definition array.h:89
bool is_donatable() const
True indicates the array's buffer is safe to reuse.
Definition array.h:267
const std::vector< int > & shape() const
The shape of the array as a vector of integers.
Definition array.h:99
size_t buffer_size() const
Definition array.h:327
size_t itemsize() const
The size of the array's datatype in bytes.
Definition array.h:79
Definition allocator.h:7
std::vector< StrideT > make_contiguous_strides(const std::vector< int > &shape)
Definition utils.h:33
std::tuple< std::vector< int >, std::vector< std::vector< int64_t > > > collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< int64_t > > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
auto check_contiguity(const std::vector< int > &shape, const std::vector< StrideT > &strides)
Definition utils.h:151
StrideT elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< StrideT > &strides)
Definition utils.h:12
bool is_donatable(const array &in, const array &out)
Definition utils.h:174
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:611
Definition utils.h:89
StrideT loc
Definition utils.h:142
ContiguousIterator(const std::vector< int > &shape, const std::vector< StrideT > &strides, int dims)
Definition utils.h:130
void seek(StrideT n)
Definition utils.h:105
void reset()
Definition utils.h:115
ContiguousIterator()
Definition utils.h:120
ContiguousIterator(const array &a)
Definition utils.h:122
void step()
Definition utils.h:90
bool row_contiguous
Definition array.h:233