MLX
 
Loading...
Searching...
No Matches
unary.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include "mlx/allocator.h"
6#include "mlx/array.h"
10#include "mlx/utils.h"
11
12namespace mlx::core {
13
14void set_unary_output_data(const array& in, array& out) {
15 if (in.flags().contiguous) {
16 if (is_donatable(in, out)) {
17 out.copy_shared_buffer(in);
18 } else {
19 auto size = in.data_size();
20 out.set_data(
22 size,
23 in.strides(),
24 in.flags());
25 }
26 } else {
28 }
29}
30
/// Apply `Op` element-wise along a single strided run of input.
///
/// Reads `shape` elements starting at `a`, `stride` elements apart, and
/// writes the results contiguously to `out[0]` .. `out[shape - 1]`.
///
/// The stride is signed: array views may carry negative strides (reversed
/// slices). The previous `size_t` stride only worked through unsigned
/// wrap-around of pointer arithmetic. Indexing `a[i * stride]` directly also
/// avoids forming a pointer past either end of the buffer after the final
/// element. Callers that still pass a `size_t` convert implicitly.
///
/// @tparam T  Input element type.
/// @tparam U  Output element type (defaults to T).
/// @tparam Op Default-constructible functor applied as `Op{}(x)`.
/// @param a      First input element of the run.
/// @param out    Destination for `shape` contiguous results.
/// @param shape  Number of elements to process.
/// @param stride Distance, in elements, between consecutive inputs.
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, size_t shape, int64_t stride) {
  for (size_t i = 0; i < shape; ++i) {
    out[i] = Op{}(a[static_cast<int64_t>(i) * stride]);
  }
}
38
39template <typename T, typename U = T, typename Op>
40void unary_op(const array& a, array& out, Op) {
41 const T* src = a.data<T>();
42 U* dst = out.data<U>();
43 auto ndim = a.ndim();
44 if (a.flags().contiguous) {
45 auto size = a.data_size();
46 constexpr int N = simd::max_size<T>;
47 while (size >= N) {
48 simd::store(dst, Op{}(simd::load<T, N>(src)));
49 size -= N;
50 src += N;
51 dst += N;
52 }
53 while (size > 0) {
54 *dst = Op{}(*src);
55 size--;
56 dst++;
57 src++;
58 }
59 } else {
60 size_t shape = ndim > 0 ? a.shape().back() : 1;
61 size_t stride = ndim > 0 ? a.strides().back() : 1;
62 if (ndim <= 1) {
63 unary_op<T, U, Op>(src, dst, shape, stride);
64 return;
65 }
66 auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
67 for (size_t elem = 0; elem < a.size(); elem += shape) {
68 unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
69 it.step();
70 }
71 }
72}
73
74template <typename Op>
75void unary(const array& a, array& out, Op op, Stream stream) {
77 auto& encoder = cpu::get_command_encoder(stream);
78 encoder.set_input_array(a);
79 encoder.set_output_array(out);
80 encoder.dispatch([a = array::unsafe_weak_copy(a),
81 out = array::unsafe_weak_copy(out),
82 op = op]() mutable {
83 switch (out.dtype()) {
84 case bool_:
85 unary_op<bool>(a, out, op);
86 break;
87 case uint8:
88 unary_op<uint8_t>(a, out, op);
89 break;
90 case uint16:
91 unary_op<uint16_t>(a, out, op);
92 break;
93 case uint32:
94 unary_op<uint32_t>(a, out, op);
95 break;
96 case uint64:
97 unary_op<uint64_t>(a, out, op);
98 break;
99 case int8:
100 unary_op<int8_t>(a, out, op);
101 break;
102 case int16:
103 unary_op<int16_t>(a, out, op);
104 break;
105 case int32:
106 unary_op<int32_t>(a, out, op);
107 break;
108 case int64:
109 unary_op<int64_t>(a, out, op);
110 break;
111 case float16:
112 unary_op<float16_t>(a, out, op);
113 break;
114 case float32:
115 unary_op<float>(a, out, op);
116 break;
117 case float64:
118 unary_op<double>(a, out, op);
119 break;
120 case bfloat16:
121 unary_op<bfloat16_t>(a, out, op);
122 break;
123 case complex64:
124 unary_op<complex64_t>(a, out, op);
125 break;
126 }
127 });
128}
129
130template <typename Op>
131void unary_real_fp(const array& a, array& out, Op op, Stream stream) {
132 set_unary_output_data(a, out);
133 auto& encoder = cpu::get_command_encoder(stream);
134 encoder.set_input_array(a);
135 encoder.set_output_array(out);
136 encoder.dispatch([a = array::unsafe_weak_copy(a),
137 out = array::unsafe_weak_copy(out),
138 op = op]() mutable {
139 switch (out.dtype()) {
140 case bfloat16:
141 unary_op<bfloat16_t>(a, out, op);
142 break;
143 case float16:
144 unary_op<float16_t>(a, out, op);
145 break;
146 case float32:
147 unary_op<float>(a, out, op);
148 break;
149 case float64:
150 unary_op<double>(a, out, op);
151 break;
152 default:
153 std::ostringstream err;
154 err << "[unary_real] Does not support " << out.dtype();
155 throw std::runtime_error(err.str());
156 }
157 });
158}
159template <typename Op>
160void unary_fp(const array& a, array& out, Op op, Stream stream) {
161 set_unary_output_data(a, out);
162 auto& encoder = cpu::get_command_encoder(stream);
163 encoder.set_input_array(a);
164 encoder.set_output_array(out);
165 encoder.dispatch([a = array::unsafe_weak_copy(a),
166 out = array::unsafe_weak_copy(out),
167 op = op]() mutable {
168 switch (out.dtype()) {
169 case bfloat16:
170 unary_op<bfloat16_t>(a, out, op);
171 break;
172 case float16:
173 unary_op<float16_t>(a, out, op);
174 break;
175 case float32:
176 unary_op<float>(a, out, op);
177 break;
178 case float64:
179 unary_op<double>(a, out, op);
180 break;
181 case complex64:
182 unary_op<complex64_t>(a, out, op);
183 break;
184 default:
185 std::ostringstream err;
186 err << "[unary_fp] Does not support " << out.dtype();
187 throw std::runtime_error(err.str());
188 }
189 });
190}
191
192template <typename Op>
193void unary_signed(const array& a, array& out, Op op, Stream stream) {
194 set_unary_output_data(a, out);
195 auto& encoder = cpu::get_command_encoder(stream);
196 encoder.set_input_array(a);
197 encoder.set_output_array(out);
198 encoder.dispatch([a = array::unsafe_weak_copy(a),
199 out = array::unsafe_weak_copy(out),
200 op = op]() mutable {
201 switch (out.dtype()) {
202 case int8:
203 unary_op<int8_t>(a, out, op);
204 break;
205 case int16:
206 unary_op<int16_t>(a, out, op);
207 break;
208 case int32:
209 unary_op<int32_t>(a, out, op);
210 break;
211 case int64:
212 unary_op<int64_t>(a, out, op);
213 break;
214 case float16:
215 unary_op<float16_t>(a, out, op);
216 break;
217 case float32:
218 unary_op<float>(a, out, op);
219 break;
220 case float64:
221 unary_op<double>(a, out, op);
222 break;
223 case bfloat16:
224 unary_op<bfloat16_t>(a, out, op);
225 break;
226 case complex64:
227 unary_op<complex64_t>(a, out, op);
228 break;
229 default:
230 throw std::runtime_error("[Abs] Called on unsigned type");
231 }
232 });
233}
234
235template <typename Op>
236void unary_complex(const array& a, array& out, Op op, Stream stream) {
237 set_unary_output_data(a, out);
238 auto& encoder = cpu::get_command_encoder(stream);
239 encoder.set_input_array(a);
240 encoder.set_output_array(out);
241 encoder.dispatch([a = array::unsafe_weak_copy(a),
242 out = array::unsafe_weak_copy(out),
243 op = op]() mutable { unary_op<complex64_t>(a, out, op); });
244}
245
246template <typename Op>
247void unary_complex_to_float(const array& a, array& out, Op op, Stream stream) {
248 set_unary_output_data(a, out);
249 auto& encoder = cpu::get_command_encoder(stream);
250 encoder.set_input_array(a);
251 encoder.set_output_array(out);
252 encoder.dispatch(
254 out = array::unsafe_weak_copy(out),
255 op = op]() mutable { unary_op<complex64_t, float>(a, out, op); });
256}
257
258template <typename Op>
259void unary_int(const array& a, array& out, Op op, Stream stream) {
260 set_unary_output_data(a, out);
261 auto& encoder = cpu::get_command_encoder(stream);
262 encoder.set_input_array(a);
263 encoder.set_output_array(out);
264 encoder.dispatch([a = array::unsafe_weak_copy(a),
265 out = array::unsafe_weak_copy(out),
266 op = op]() mutable {
267 switch (out.dtype()) {
268 case uint8:
269 unary_op<uint8_t>(a, out, op);
270 break;
271 case uint16:
272 unary_op<uint16_t>(a, out, op);
273 break;
274 case uint32:
275 unary_op<uint32_t>(a, out, op);
276 break;
277 case uint64:
278 unary_op<uint64_t>(a, out, op);
279 break;
280 case int8:
281 unary_op<int8_t>(a, out, op);
282 break;
283 case int16:
284 unary_op<int16_t>(a, out, op);
285 break;
286 case int32:
287 unary_op<int32_t>(a, out, op);
288 break;
289 case int64:
290 unary_op<int64_t>(a, out, op);
291 break;
292 default:
293 std::ostringstream err;
294 err << "[unary_int] Does not support " << out.dtype();
295 throw std::runtime_error(err.str());
296 }
297 });
298}
299
300} // namespace mlx::core
Definition array.h:24
const Flags & flags() const
Get the Flags bit-field.
Definition array.h:313
const Shape & shape() const
The shape of the array as a vector of integers.
Definition array.h:103
const Strides & strides() const
The strides of the array.
Definition array.h:117
size_t nbytes() const
The number of bytes in the array.
Definition array.h:93
size_t ndim() const
The number of dimensions of the array.
Definition array.h:98
size_t size() const
The number of elements in the array.
Definition array.h:88
T * data()
Definition array.h:349
static array unsafe_weak_copy(const array &other)
Get a new array that refers to the same data as the input but with a non-owning pointer to it.
void copy_shared_buffer(const array &other, const Strides &strides, Flags flags, size_t data_size, size_t offset=0)
Dtype dtype() const
Get the arrays data type.
Definition array.h:131
size_t itemsize() const
The size of the array's datatype in bytes.
Definition array.h:83
void set_data(allocator::Buffer buffer, Deleter d=allocator::free)
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:327
Buffer malloc_or_wait(size_t size)
CommandEncoder & get_command_encoder(Stream stream)
Simd< T, N > load(const T *x)
Definition base_simd.h:28
static constexpr int max_size
Definition base_simd.h:14
void store(T *dst, Simd< T, N > x)
Definition base_simd.h:33
Definition allocator.h:7
void unary_complex_to_float(const array &a, array &out, Op op, Stream stream)
Definition unary.h:247
void unary(const array &a, array &out, Op op, Stream stream)
Definition unary.h:75
constexpr Dtype bool_
Definition dtype.h:68
constexpr Dtype uint64
Definition dtype.h:73
constexpr Dtype uint16
Definition dtype.h:71
constexpr Dtype float64
Definition dtype.h:82
void set_unary_output_data(const array &in, array &out)
Definition unary.h:14
void unary_int(const array &a, array &out, Op op, Stream stream)
Definition unary.h:259
constexpr Dtype bfloat16
Definition dtype.h:83
constexpr Dtype int32
Definition dtype.h:77
constexpr Dtype float32
Definition dtype.h:81
void unary_fp(const array &a, array &out, Op op, Stream stream)
Definition unary.h:160
constexpr Dtype int16
Definition dtype.h:76
constexpr Dtype int8
Definition dtype.h:75
constexpr Dtype int64
Definition dtype.h:78
void unary_real_fp(const array &a, array &out, Op op, Stream stream)
Definition unary.h:131
constexpr Dtype uint8
Definition dtype.h:70
void unary_complex(const array &a, array &out, Op op, Stream stream)
Definition unary.h:236
void unary_op(const T *a, U *out, size_t shape, size_t stride)
Definition unary.h:32
constexpr Dtype float16
Definition dtype.h:80
constexpr Dtype uint32
Definition dtype.h:72
void unary_signed(const array &a, array &out, Op op, Stream stream)
Definition unary.h:193
bool is_donatable(const array &in, const array &out)
Definition utils.h:155
constexpr Dtype complex64
Definition dtype.h:84
Definition utils.h:73
Definition stream.h:9
bool contiguous
Definition array.h:238