MLX
binary.h
// Copyright © 2023 Apple Inc.

#pragma once
#include <cassert>

#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

namespace {

enum class BinaryOpType {
  ScalarScalar,
  ScalarVector,
  VectorScalar,
  VectorVector,
  General,
};

BinaryOpType get_binary_op_type(const array& a, const array& b) {
  BinaryOpType bopt;
  if (a.data_size() == 1 && b.data_size() == 1) {
    bopt = BinaryOpType::ScalarScalar;
  } else if (a.data_size() == 1 && b.flags().contiguous) {
    bopt = BinaryOpType::ScalarVector;
  } else if (b.data_size() == 1 && a.flags().contiguous) {
    bopt = BinaryOpType::VectorScalar;
  } else if (
      (a.flags().row_contiguous && b.flags().row_contiguous) ||
      (a.flags().col_contiguous && b.flags().col_contiguous)) {
    bopt = BinaryOpType::VectorVector;
  } else {
    bopt = BinaryOpType::General;
  }
  return bopt;
}
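
// [Annotation added to this listing, not in the original header.] A sketch of
// how the classification above plays out, assuming x and y are hypothetical
// row-contiguous arrays of the same shape and s holds a single element
// (data_size() == 1):
//
//   get_binary_op_type(s, s); // ScalarScalar
//   get_binary_op_type(s, x); // ScalarVector
//   get_binary_op_type(x, s); // VectorScalar
//   get_binary_op_type(x, y); // VectorVector (matching contiguous layouts)
//   // Anything else, e.g. x against a strided transpose, yields General.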

void set_binary_op_output_data(
    const array& a,
    const array& b,
    array& out,
    BinaryOpType bopt,
    bool donate_with_move = false) {
  bool b_donatable = is_donatable(b, out);
  bool a_donatable = is_donatable(a, out);
  switch (bopt) {
    case BinaryOpType::ScalarScalar:
      out.set_data(
          allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
      break;
    case BinaryOpType::ScalarVector:
      if (b_donatable) {
        if (donate_with_move) {
          out.move_shared_buffer(b);
        } else {
          out.copy_shared_buffer(b);
        }
      } else {
        out.set_data(
            allocator::malloc_or_wait(b.data_size() * out.itemsize()),
            b.data_size(),
            b.strides(),
            b.flags());
      }
      break;
    case BinaryOpType::VectorScalar:
      if (a_donatable) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
        } else {
          out.copy_shared_buffer(a);
        }
      } else {
        out.set_data(
            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
            a.data_size(),
            a.strides(),
            a.flags());
      }
      break;
    case BinaryOpType::VectorVector:
      if (a_donatable) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
        } else {
          out.copy_shared_buffer(a);
        }
      } else if (b_donatable) {
        if (donate_with_move) {
          out.move_shared_buffer(b);
        } else {
          out.copy_shared_buffer(b);
        }
      } else {
        out.set_data(
            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
            a.data_size(),
            a.strides(),
            a.flags());
      }
      break;
    case BinaryOpType::General:
      if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
        } else {
          out.copy_shared_buffer(a);
        }
      } else if (
          b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
        if (donate_with_move) {
          out.move_shared_buffer(b);
        } else {
          out.copy_shared_buffer(b);
        }
      } else {
        out.set_data(allocator::malloc_or_wait(out.nbytes()));
      }
      break;
  }
}
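
// [Annotation, not in the original header.] Donation lets the output alias an
// input buffer when that input can no longer be observed elsewhere, avoiding
// a fresh allocation. A sketch of the VectorScalar case, assuming
// is_donatable(a, out) holds:
//
//   set_binary_op_output_data(a, b, out, BinaryOpType::VectorScalar);
//   // out now shares a's buffer (copy_shared_buffer), or takes ownership of
//   // it when donate_with_move is true (move_shared_buffer), so the kernel
//   // writes results in place over a's elements. If neither input is
//   // donatable, a buffer of a.data_size() * out.itemsize() bytes is
//   // allocated instead.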

struct UseDefaultBinaryOp {};

template <typename T, typename U, typename Op>
struct DefaultVectorScalar {
  Op op;

  DefaultVectorScalar(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *b;
    while (size-- > 0) {
      *dst = op(*a, scalar);
      dst++;
      a++;
    }
  }
};

template <typename T, typename U, typename Op>
struct DefaultScalarVector {
  Op op;

  DefaultScalarVector(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *a;
    while (size-- > 0) {
      *dst = op(scalar, *b);
      dst++;
      b++;
    }
  }
};

template <typename T, typename U, typename Op>
struct DefaultVectorVector {
  Op op;

  DefaultVectorVector(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    while (size-- > 0) {
      *dst = op(*a, *b);
      dst++;
      a++;
      b++;
    }
  }
};
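
// [Annotation, not in the original header.] All three functors above share the
// signature (const T*, const T*, U*, int), so a custom fast path only has to
// match it. A hypothetical hand-rolled vector-vector add over floats:
//
//   struct AddVV {
//     void operator()(const float* a, const float* b, float* dst, int size) {
//       // A real fast path might use SIMD intrinsics here.
//       for (int i = 0; i < size; ++i) {
//         dst[i] = a[i] + b[i];
//       }
//     }
//   };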

template <typename T, typename U, typename Op, int D, bool Strided>
void binary_op_dims(
    const T* a,
    const T* b,
    U* out,
    Op op,
    const Shape& shape,
    const Strides& a_strides,
    const Strides& b_strides,
    const Strides& out_strides,
    int axis) {
  auto stride_a = a_strides[axis];
  auto stride_b = b_strides[axis];
  auto stride_out = out_strides[axis];
  auto N = shape[axis];

  for (int i = 0; i < N; i++) {
    if constexpr (D > 1) {
      binary_op_dims<T, U, Op, D - 1, Strided>(
          a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
    } else {
      if constexpr (Strided) {
        op(a, b, out, stride_out);
      } else {
        *out = op(*a, *b);
      }
    }
    out += stride_out;
    a += stride_a;
    b += stride_b;
  }
}
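
// [Annotation, not in the original header.] D counts the axes left to
// iterate; one is peeled off per recursion level at compile time, so the loop
// nest is fully unrolled. For D == 2 with Strided == false the recursion is
// equivalent to:
//
//   for (int i = 0; i < shape[axis]; ++i) {
//     for (int j = 0; j < shape[axis + 1]; ++j) {
//       out[i * out_strides[axis] + j * out_strides[axis + 1]] =
//           op(a[i * a_strides[axis] + j * a_strides[axis + 1]],
//              b[i * b_strides[axis] + j * b_strides[axis + 1]]);
//     }
//   }
//
// With Strided == true the innermost level instead hands op a contiguous run
// of stride_out elements (the collapsed trailing dims).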

template <typename T, typename U, bool Strided, typename Op>
void binary_op_dispatch_dims(
    const array& a,
    const array& b,
    array& out,
    Op op,
    int dim,
    const Shape& shape,
    const Strides& a_strides,
    const Strides& b_strides,
    const Strides& out_strides) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* out_ptr = out.data<U>();
  switch (dim) {
    case 1:
      binary_op_dims<T, U, Op, 1, Strided>(
          a_ptr,
          b_ptr,
          out_ptr,
          op,
          shape,
          a_strides,
          b_strides,
          out_strides,
          0);
      return;
    case 2:
      binary_op_dims<T, U, Op, 2, Strided>(
          a_ptr,
          b_ptr,
          out_ptr,
          op,
          shape,
          a_strides,
          b_strides,
          out_strides,
          0);
      return;
    case 3:
      binary_op_dims<T, U, Op, 3, Strided>(
          a_ptr,
          b_ptr,
          out_ptr,
          op,
          shape,
          a_strides,
          b_strides,
          out_strides,
          0);
      return;
  }

  ContiguousIterator a_it(shape, a_strides, dim - 3);
  ContiguousIterator b_it(shape, b_strides, dim - 3);
  auto stride = out_strides[dim - 4];
  for (int64_t elem = 0; elem < a.size(); elem += stride) {
    binary_op_dims<T, U, Op, 3, Strided>(
        a_ptr + a_it.loc,
        b_ptr + b_it.loc,
        out_ptr + elem,
        op,
        shape,
        a_strides,
        b_strides,
        out_strides,
        dim - 3);
    a_it.step();
    b_it.step();
  }
}
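
// [Annotation, not in the original header.] For dim > 3 only the inner three
// axes use the unrolled template; the outer axes are walked with
// ContiguousIterator. Assuming step() advances one position through the outer
// axes, the tail loop above for dim == 4 is equivalent to:
//
//   for (int i = 0; i < shape[0]; ++i) {
//     binary_op_dims<T, U, Op, 3, Strided>(
//         a_ptr + i * a_strides[0],
//         b_ptr + i * b_strides[0],
//         out_ptr + i * out_strides[0],
//         op, shape, a_strides, b_strides, out_strides, 1);
//   }
//
// where out_strides[0] is the element count of one 3-D block in the
// row-contiguous output.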

template <
    typename T,
    typename U,
    typename Op,
    typename OpSV,
    typename OpVS,
    typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    array& out,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);

  // The full computation is scalar scalar so call the base op once
  if (bopt == BinaryOpType::ScalarScalar) {
    *(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
    return;
  }

  // The full computation is scalar vector so delegate to the op
  if (bopt == BinaryOpType::ScalarVector) {
    opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
    return;
  }

  // The full computation is vector scalar so delegate to the op
  if (bopt == BinaryOpType::VectorScalar) {
    opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
    return;
  }

  // The full computation is vector vector so delegate to the op
  if (bopt == BinaryOpType::VectorVector) {
    opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
    return;
  }

  // General computation so let's try to optimize
  auto [new_shape, new_strides] = collapse_contiguous_dims(
      a.shape(), {a.strides(), b.strides(), out.strides()});
  const auto& a_strides = new_strides[0];
  const auto& b_strides = new_strides[1];
  const auto& strides = new_strides[2];

  // Get the left-most dim such that the array is row contiguous after
  auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
    int d = arr_strides.size() - 1;
    for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
    }
    return d + 1;
  };
  auto a_rc_dim = leftmost_rc_dim(a_strides);
  auto b_rc_dim = leftmost_rc_dim(b_strides);

  // Get the left-most dim such that the array is a broadcasted "scalar" after
  auto leftmost_s_dim = [](const auto& arr_strides) {
    int d = arr_strides.size() - 1;
    for (; d >= 0 && arr_strides[d] == 0; d--) {
    }
    return d + 1;
  };
  auto a_s_dim = leftmost_s_dim(a_strides);
  auto b_s_dim = leftmost_s_dim(b_strides);

  auto ndim = new_shape.size();

  // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
  int dim = ndim;
  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::VectorVector;
    dim = d;
    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
    bopt = BinaryOpType::VectorScalar;
    dim = d;
    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::ScalarVector;
    dim = d;
  }

  // We can be sure dim > 0 since otherwise we would have used one of the fully
  // contiguous methods above, except when the flags do not correspond to the
  // underlying contiguity.
  if (dim == 0 || strides[dim - 1] < 16) {
    bopt = BinaryOpType::General;
    dim = ndim;
  }

  switch (bopt) {
    case BinaryOpType::VectorVector:
      binary_op_dispatch_dims<T, U, true>(
          a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
      break;
    case BinaryOpType::VectorScalar:
      binary_op_dispatch_dims<T, U, true>(
          a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
      break;
    case BinaryOpType::ScalarVector:
      binary_op_dispatch_dims<T, U, true>(
          a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
      break;
    default:
      binary_op_dispatch_dims<T, U, false>(
          a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
      break;
  }
}
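
// [Annotation, not in the original header.] A worked instance of the Case 1-3
// logic, assuming a has shape (8, 1024) with strides (1024, 1) and b was
// broadcast from (1, 1024), i.e. strides (0, 1): a matches the output strides
// everywhere (a_rc_dim == 0) while b matches only from axis 1 on
// (b_rc_dim == 1). max(a_rc_dim, b_rc_dim) == 1 < ndim == 2, so bopt becomes
// VectorVector with dim == 1: the general machinery loops only over the
// leading axis and opvv consumes contiguous rows of 1024 elements, with b's
// single row reused (stride 0) on every iteration.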

template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    array& out,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  // TODO: The following mess of constexpr evaluations can probably be achieved
  // with template specializations and overloading. Would it be simpler?

  if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
    if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
      if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
        // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
        binary_op<T, T>(
            a,
            b,
            out,
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            DefaultVectorVector<T, T, Op>(op));
      } else {
        // opsv and opvs were UseDefaultBinaryOp
        binary_op<T, T>(
            a,
            b,
            out,
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            opvv);
      }
    } else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
                             value) {
      // opsv and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          out,
          op,
          DefaultScalarVector<T, T, Op>(op),
          opvs,
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opsv was UseDefaultBinaryOp
      binary_op<T, T>(
          a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
    }
  } else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
                           value) {
    if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
      // opvs and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          out,
          op,
          opsv,
          DefaultVectorScalar<T, T, Op>(op),
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opvs was UseDefaultBinaryOp
      binary_op<T, T>(
          a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
    }
  } else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
                           value) {
    // opvv was UseDefaultBinaryOp
    binary_op<T, T>(
        a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
  } else {
    // All ops provided
    binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
  }
}
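
// [Annotation, not in the original header.] This overload lets a caller
// provide fast paths only for the cases it cares about and pass
// UseDefaultBinaryOp{} for the rest. A sketch using the hypothetical AddVV
// functor from the earlier annotation:
//
//   binary_op<float>(
//       a,
//       b,
//       out,
//       [](float x, float y) { return x + y; },
//       UseDefaultBinaryOp{}, // scalar-vector: use the default
//       UseDefaultBinaryOp{}, // vector-scalar: use the default
//       AddVV{}); // vector-vector: custom fast path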

template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
  DefaultScalarVector<T, T, Op> opsv(op);
  DefaultVectorScalar<T, T, Op> opvs(op);
  DefaultVectorVector<T, T, Op> opvv(op);
  binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
}

template <typename... Ops>
void binary(const array& a, const array& b, array& out, Ops... ops) {
  switch (out.dtype()) {
    case bool_:
      binary_op<bool>(a, b, out, ops...);
      break;
    case uint8:
      binary_op<uint8_t>(a, b, out, ops...);
      break;
    case uint16:
      binary_op<uint16_t>(a, b, out, ops...);
      break;
    case uint32:
      binary_op<uint32_t>(a, b, out, ops...);
      break;
    case uint64:
      binary_op<uint64_t>(a, b, out, ops...);
      break;
    case int8:
      binary_op<int8_t>(a, b, out, ops...);
      break;
    case int16:
      binary_op<int16_t>(a, b, out, ops...);
      break;
    case int32:
      binary_op<int32_t>(a, b, out, ops...);
      break;
    case int64:
      binary_op<int64_t>(a, b, out, ops...);
      break;
    case float16:
      binary_op<float16_t>(a, b, out, ops...);
      break;
    case float32:
      binary_op<float>(a, b, out, ops...);
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, out, ops...);
      break;
    case complex64:
      binary_op<complex64_t>(a, b, out, ops...);
      break;
  }
}
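
// [Annotation, not in the original header.] binary() switches on the output
// dtype and instantiates binary_op for the matching C++ type, so an op passed
// here must be callable for every dtype, including bool and complex64_t. A
// minimal sketch, assuming out already carries the promoted dtype and
// broadcast shape:
//
//   struct Add {
//     template <typename T>
//     T operator()(T x, T y) {
//       return x + y;
//     }
//   };
//
//   binary(a, b, out, Add{});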

} // namespace

} // namespace mlx::core