MLX
binary.h
Go to the documentation of this file.
// Copyright © 2023 Apple Inc.

#pragma once
#include <cassert>

#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"

#include "mlx/backend/cpu/simd/simd.h"

13namespace mlx::core {
14
15template <typename Op>
17 Op op;
18
19 VectorScalar(Op op_) : op(op_) {}
20
21 template <typename T, typename U>
22 void operator()(const T* a, const T* b, U* dst, int size) {
23 T scalar = *b;
24 constexpr int N = simd::max_size<T>;
25 while (size >= N) {
27 dst += N;
28 a += N;
29 size -= N;
30 }
31 while (size-- > 0) {
32 *dst = op(*a, scalar);
33 dst++;
34 a++;
35 }
36 }
37};
38
39template <typename Op>
41 Op op;
42
43 ScalarVector(Op op_) : op(op_) {}
44
45 template <typename T, typename U>
46 void operator()(const T* a, const T* b, U* dst, int size) {
47 T scalar = *a;
48 constexpr int N = simd::max_size<T>;
49 while (size >= N) {
51 dst += N;
52 b += N;
53 size -= N;
54 }
55 while (size-- > 0) {
56 *dst = op(scalar, *b);
57 dst++;
58 b++;
59 }
60 }
61};
62
63template <typename Op>
65 Op op;
66
67 VectorVector(Op op_) : op(op_) {}
68
69 template <typename T, typename U>
70 void operator()(const T* a, const T* b, U* dst, int size) {
71 constexpr int N = simd::max_size<T>;
72 while (size >= N) {
74 dst += N;
75 a += N;
76 b += N;
77 size -= N;
78 }
79 while (size-- > 0) {
80 *dst = op(*a, *b);
81 dst++;
82 a++;
83 b++;
84 }
85 }
86};
87
88template <typename T, typename U, typename Op, int D, bool Strided>
90 const T* a,
91 const T* b,
92 U* out,
93 Op op,
94 const Shape& shape,
95 const Strides& a_strides,
96 const Strides& b_strides,
97 const Strides& out_strides,
98 int axis) {
99 auto stride_a = a_strides[axis];
100 auto stride_b = b_strides[axis];
101 auto stride_out = out_strides[axis];
102 auto N = shape[axis];
103
104 for (int i = 0; i < N; i++) {
105 if constexpr (D > 1) {
106 binary_op_dims<T, U, Op, D - 1, Strided>(
107 a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
108 } else {
109 if constexpr (Strided) {
110 op(a, b, out, stride_out);
111 } else {
112 *out = op(*a, *b);
113 }
114 }
115 out += stride_out;
116 a += stride_a;
117 b += stride_b;
118 }
119}
120
121template <typename T, typename U, bool Strided, typename Op>
123 const array& a,
124 const array& b,
125 array& out,
126 Op op,
127 int dim,
128 const Shape& shape,
129 const Strides& a_strides,
130 const Strides& b_strides,
131 const Strides& out_strides) {
132 const T* a_ptr = a.data<T>();
133 const T* b_ptr = b.data<T>();
134 U* out_ptr = out.data<U>();
135 switch (dim) {
136 case 1:
138 a_ptr,
139 b_ptr,
140 out_ptr,
141 op,
142 shape,
143 a_strides,
144 b_strides,
145 out_strides,
146 0);
147 return;
148 case 2:
150 a_ptr,
151 b_ptr,
152 out_ptr,
153 op,
154 shape,
155 a_strides,
156 b_strides,
157 out_strides,
158 0);
159 return;
160 case 3:
162 a_ptr,
163 b_ptr,
164 out_ptr,
165 op,
166 shape,
167 a_strides,
168 b_strides,
169 out_strides,
170 0);
171 return;
172 }
173
174 ContiguousIterator a_it(shape, a_strides, dim - 3);
175 ContiguousIterator b_it(shape, b_strides, dim - 3);
176 auto stride = out_strides[dim - 4];
177 for (int64_t elem = 0; elem < a.size(); elem += stride) {
179 a_ptr + a_it.loc,
180 b_ptr + b_it.loc,
181 out_ptr + elem,
182 op,
183 shape,
184 a_strides,
185 b_strides,
186 out_strides,
187 dim - 3);
188 a_it.step();
189 b_it.step();
190 }
191}
192
193template <typename T, typename U, typename Op>
194void binary_op(const array& a, const array& b, array& out, Op op) {
195 auto bopt = get_binary_op_type(a, b);
196 set_binary_op_output_data(a, b, out, bopt);
197
198 // The full computation is scalar scalar so call the base op once
199 if (bopt == BinaryOpType::ScalarScalar) {
200 *(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
201 return;
202 }
203
204 // The full computation is scalar vector so delegate to the op
205 if (bopt == BinaryOpType::ScalarVector) {
206 ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
207 return;
208 }
209
210 // The full computation is vector scalar so delegate to the op
211 if (bopt == BinaryOpType::VectorScalar) {
212 VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
213 return;
214 }
215
216 // The full computation is vector vector so delegate to the op
217 if (bopt == BinaryOpType::VectorVector) {
218 VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
219 return;
220 }
221
222 // General computation so let's try to optimize
223 auto [new_shape, new_strides] = collapse_contiguous_dims(
224 a.shape(), {a.strides(), b.strides(), out.strides()});
225 const auto& a_strides = new_strides[0];
226 const auto& b_strides = new_strides[1];
227 const auto& strides = new_strides[2];
228
229 // Get the left-most dim such that the array is row contiguous after
230 auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
231 int d = arr_strides.size() - 1;
232 for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
233 }
234 return d + 1;
235 };
236 auto a_rc_dim = leftmost_rc_dim(a_strides);
237 auto b_rc_dim = leftmost_rc_dim(b_strides);
238
239 // Get the left-most dim such that the array is a broadcasted "scalar" after
240 auto leftmost_s_dim = [](const auto& arr_strides) {
241 int d = arr_strides.size() - 1;
242 for (; d >= 0 && arr_strides[d] == 0; d--) {
243 }
244 return d + 1;
245 };
246 auto a_s_dim = leftmost_s_dim(a_strides);
247 auto b_s_dim = leftmost_s_dim(b_strides);
248
249 auto ndim = new_shape.size();
250
251 // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
252 int dim = ndim;
253 if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
255 dim = d;
256 // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
257 // contiguous
258 } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
260 dim = d;
261 // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
262 // contiguous
263 } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
265 dim = d;
266 }
267
268 // Can be sure dim > 0 since otherwise we would have used one of the fully
269 // contiguous methods above. Except for the case that the flags do not
270 // correspond to the underlying contiguity.
271 if (dim == 0 || strides[dim - 1] < 16) {
273 dim = ndim;
274 }
275
276 switch (bopt) {
279 a,
280 b,
281 out,
282 VectorVector{op},
283 dim,
284 new_shape,
285 a_strides,
286 b_strides,
287 strides);
288 break;
291 a,
292 b,
293 out,
294 VectorScalar{op},
295 dim,
296 new_shape,
297 a_strides,
298 b_strides,
299 strides);
300 break;
303 a,
304 b,
305 out,
306 ScalarVector{op},
307 dim,
308 new_shape,
309 a_strides,
310 b_strides,
311 strides);
312 break;
313 default:
315 a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
316 break;
317 }
318}
319
320template <typename T, typename Op>
321void binary_op(const array& a, const array& b, array& out, Op op) {
322 binary_op<T, T>(a, b, out, op);
323}
324
325template <typename Op>
326void binary(const array& a, const array& b, array& out, Op op) {
327 switch (out.dtype()) {
328 case bool_:
329 binary_op<bool>(a, b, out, op);
330 break;
331 case uint8:
332 binary_op<uint8_t>(a, b, out, op);
333 break;
334 case uint16:
335 binary_op<uint16_t>(a, b, out, op);
336 break;
337 case uint32:
338 binary_op<uint32_t>(a, b, out, op);
339 break;
340 case uint64:
341 binary_op<uint64_t>(a, b, out, op);
342 break;
343 case int8:
344 binary_op<int8_t>(a, b, out, op);
345 break;
346 case int16:
347 binary_op<int16_t>(a, b, out, op);
348 break;
349 case int32:
350 binary_op<int32_t>(a, b, out, op);
351 break;
352 case int64:
353 binary_op<int64_t>(a, b, out, op);
354 break;
355 case float16:
356 binary_op<float16_t>(a, b, out, op);
357 break;
358 case float32:
359 binary_op<float>(a, b, out, op);
360 break;
361 case float64:
362 binary_op<double>(a, b, out, op);
363 break;
364 case bfloat16:
365 binary_op<bfloat16_t>(a, b, out, op);
366 break;
367 case complex64:
368 binary_op<complex64_t>(a, b, out, op);
369 break;
370 }
371}
372
373} // namespace mlx::core
Definition array.h:24
const Shape & shape() const
The shape of the array as a vector of integers.
Definition array.h:103
size_t size() const
The number of elements in the array.
Definition array.h:88
T * data()
Definition array.h:354
Dtype dtype() const
Get the arrays data type.
Definition array.h:131
size_t data_size() const
The size (in elements) of the underlying buffer the array points to.
Definition array.h:332
Simd< T, N > load(const T *x)
Definition base_simd.h:28
static constexpr int max_size
Definition base_simd.h:14
void store(T *dst, Simd< T, N > x)
Definition base_simd.h:33
Definition allocator.h:7
constexpr Dtype bool_
Definition dtype.h:68
constexpr Dtype uint64
Definition dtype.h:73
BinaryOpType get_binary_op_type(const array &a, const array &b)
Definition binary.h:19
constexpr Dtype uint16
Definition dtype.h:71
constexpr Dtype float64
Definition dtype.h:82
std::tuple< Shape, std::vector< Strides > > collapse_contiguous_dims(const Shape &shape, const std::vector< Strides > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
constexpr Dtype bfloat16
Definition dtype.h:83
@ General
Definition binary.h:16
@ VectorVector
Definition binary.h:15
@ ScalarScalar
Definition binary.h:12
@ VectorScalar
Definition binary.h:14
@ ScalarVector
Definition binary.h:13
constexpr Dtype int32
Definition dtype.h:77
void binary_op_dispatch_dims(const array &a, const array &b, array &out, Op op, int dim, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &out_strides)
Definition binary.h:122
constexpr Dtype float32
Definition dtype.h:81
std::vector< ShapeElem > Shape
Definition array.h:21
void set_binary_op_output_data(const array &a, const array &b, array &out, BinaryOpType bopt, bool donate_with_move=false)
Definition binary.h:37
constexpr Dtype int16
Definition dtype.h:76
std::vector< int64_t > Strides
Definition array.h:22
void binary_op_dims(const T *a, const T *b, U *out, Op op, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &out_strides, int axis)
Definition binary.h:89
constexpr Dtype int8
Definition dtype.h:75
constexpr Dtype int64
Definition dtype.h:78
constexpr Dtype uint8
Definition dtype.h:70
void binary_op(const array &a, const array &b, array &out, Op op)
Definition binary.h:194
constexpr Dtype float16
Definition dtype.h:80
constexpr Dtype uint32
Definition dtype.h:72
void binary(const array &a, const array &b, array &out, Op op)
Definition binary.h:326
constexpr Dtype complex64
Definition dtype.h:84
struct ContiguousIterator
Definition utils.h:73
int64_t loc
Definition utils.h:126
void step()
Definition utils.h:74
struct ScalarVector
Definition binary.h:40
ScalarVector(Op op_)
Definition binary.h:43
void operator()(const T *a, const T *b, U *dst, int size)
Definition binary.h:46
Op op
Definition binary.h:41
struct VectorScalar
Definition binary.h:16
void operator()(const T *a, const T *b, U *dst, int size)
Definition binary.h:22
Op op
Definition binary.h:17
VectorScalar(Op op_)
Definition binary.h:19
struct VectorVector
Definition binary.h:64
VectorVector(Op op_)
Definition binary.h:67
Op op
Definition binary.h:65
void operator()(const T *a, const T *b, U *dst, int size)
Definition binary.h:70
Definition accelerate_simd.h:55