MLX
 
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <vector>
6
7#include "mlx/array.h"
8
9namespace mlx::core {
10
11inline int64_t
12elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
13 int64_t loc = 0;
14 for (int i = shape.size() - 1; i >= 0; --i) {
15 auto q_and_r = ldiv(elem, shape[i]);
16 loc += q_and_r.rem * strides[i];
17 elem = q_and_r.quot;
18 }
19 return loc;
20}
21
22inline int64_t elem_to_loc(int elem, const array& a) {
23 if (a.flags().row_contiguous) {
24 return elem;
25 }
26 return elem_to_loc(elem, a.shape(), a.strides());
27}
28
30 Strides strides(shape.size(), 1);
31 for (int i = shape.size() - 1; i > 0; i--) {
32 strides[i - 1] = strides[i] * shape[i];
33 }
34 return strides;
35}
36
37// Collapse dims that are contiguous to possibly route to a better kernel
38// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
39// should return {{2, 4}, {{1, 2}}}.
40//
41// When multiple arrays are passed they should all have the same shape. The
42// collapsed axes are also the same so one shape is returned.
// `strides` holds one Strides entry per array, each with one stride per axis
// of `shape`.
// NOTE(review): `size_cap` (default: INT32_MAX) presumably bounds how large a
// merged dimension may grow — confirm against the implementation.
43std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
44 const Shape& shape,
45 const std::vector<Strides>& strides,
46 int64_t size_cap = std::numeric_limits<int32_t>::max());
47
48inline std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
49 const std::vector<array>& xs,
50 size_t size_cap = std::numeric_limits<int32_t>::max()) {
51 std::vector<Strides> strides;
52 for (auto& x : xs) {
53 strides.emplace_back(x.strides());
54 }
55 return collapse_contiguous_dims(xs[0].shape(), strides, size_cap);
56}
57
58template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
59inline auto collapse_contiguous_dims(Arrays&&... xs) {
61 std::vector<array>{std::forward<Arrays>(xs)...});
62}
63
64// The single array version of the above: collapses one shape/strides pair and
// returns the collapsed shape together with its matching strides.
// NOTE(review): the `array` overload presumably uses `a.shape()` and
// `a.strides()` — confirm against the implementation.
65std::pair<Shape, Strides> collapse_contiguous_dims(
66 const Shape& shape,
67 const Strides& strides,
68 int64_t size_cap = std::numeric_limits<int32_t>::max());
69std::pair<Shape, Strides> collapse_contiguous_dims(
70 const array& a,
71 int64_t size_cap = std::numeric_limits<int32_t>::max());
72
74 inline void step() {
75 int dims = shape_.size();
76 if (dims == 0) {
77 return;
78 }
79 int i = dims - 1;
80 while (pos_[i] == (shape_[i] - 1) && i > 0) {
81 pos_[i] = 0;
82 loc -= (shape_[i] - 1) * strides_[i];
83 i--;
84 }
85 pos_[i]++;
86 loc += strides_[i];
87 }
88
89 void seek(int64_t n) {
90 loc = 0;
91 for (int i = shape_.size() - 1; i >= 0; --i) {
92 auto q_and_r = ldiv(n, shape_[i]);
93 loc += q_and_r.rem * strides_[i];
94 pos_[i] = q_and_r.rem;
95 n = q_and_r.quot;
96 }
97 }
98
99 void reset() {
100 loc = 0;
101 std::fill(pos_.begin(), pos_.end(), 0);
102 }
103
105
106 explicit ContiguousIterator(const array& a)
107 : shape_(a.shape()), strides_(a.strides()) {
108 if (!shape_.empty()) {
109 std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
110 pos_ = Shape(shape_.size(), 0);
111 }
112 }
113
115 const Shape& shape,
116 const Strides& strides,
117 int dims)
118 : shape_(shape.begin(), shape.begin() + dims),
119 strides_(strides.begin(), strides.begin() + dims) {
120 if (!shape_.empty()) {
121 std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
122 pos_ = Shape(shape_.size(), 0);
123 }
124 }
125
126 int64_t loc{0};
127
128 private:
129 Shape shape_;
130 Strides strides_;
131 Shape pos_;
132};
133
134inline auto check_contiguity(const Shape& shape, const Strides& strides) {
135 size_t no_broadcast_data_size = 1;
136 int64_t f_stride = 1;
137 int64_t b_stride = 1;
138 bool is_row_contiguous = true;
139 bool is_col_contiguous = true;
140
141 for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
142 is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;
143 is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
144 f_stride *= shape[i];
145 b_stride *= shape[ri];
146 if (strides[i] > 0) {
147 no_broadcast_data_size *= shape[i];
148 }
149 }
150
151 return std::make_tuple(
152 no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
153}
154
155inline bool is_donatable(const array& in, const array& out) {
156 constexpr size_t donation_extra = 16384;
157
158 return in.is_donatable() && in.itemsize() == out.itemsize() &&
159 in.buffer_size() <= out.nbytes() + donation_extra;
160}
161
162void move_or_copy(const array& in, array& out);
164 const array& in,
165 array& out,
166 const Strides& strides,
167 array::Flags flags,
168 size_t data_size,
169 size_t offset = 0);
170
171std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
172
174 const array& in,
175 const Strides& out_strides,
176 array& out);
177} // namespace mlx::core
Definition array.h:24
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:318
const Shape & shape() const
The shape of the array as a vector of integers.
Definition array.h:103
const Strides & strides() const
The strides of the array.
Definition array.h:117
size_t nbytes() const
The number of bytes in the array.
Definition array.h:93
bool is_donatable() const
True indicates the array's buffer is safe to reuse.
Definition array.h:283
size_t buffer_size() const
Definition array.h:343
size_t itemsize() const
The size of the array's datatype in bytes.
Definition array.h:83
Definition allocator.h:7
Strides make_contiguous_strides(const Shape &shape)
Definition utils.h:29
std::tuple< Shape, std::vector< Strides > > collapse_contiguous_dims(const Shape &shape, const std::vector< Strides > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
int64_t elem_to_loc(int elem, const Shape &shape, const Strides &strides)
Definition utils.h:12
std::pair< bool, Strides > prepare_reshape(const array &in, const array &out)
std::vector< ShapeElem > Shape
Definition array.h:21
std::vector< int64_t > Strides
Definition array.h:22
void move_or_copy(const array &in, array &out)
void shared_buffer_reshape(const array &in, const Strides &out_strides, array &out)
auto check_contiguity(const Shape &shape, const Strides &strides)
Definition utils.h:134
bool is_donatable(const array &in, const array &out)
Definition utils.h:155
typename std::enable_if_t< is_arrays_v< T... > > enable_for_arrays_t
Definition array.h:630
int64_t loc
Definition utils.h:126
ContiguousIterator()
Definition utils.h:104
ContiguousIterator(const Shape &shape, const Strides &strides, int dims)
Definition utils.h:114
ContiguousIterator(const array &a)
Definition utils.h:106
void step()
Definition utils.h:74
void seek(int64_t n)
Definition utils.h:89
void reset()
Definition utils.h:99
Definition array.h:225
bool row_contiguous
Definition array.h:237