MLX
Loading...
Searching...
No Matches
binary.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
#include <cassert>
#include <type_traits>

#include "mlx/allocator.h"
#include "mlx/array.h"
9
10namespace mlx::core {
11
12namespace {
13
// Classification of how a binary operation's two inputs are laid out in
// memory; used to select the fastest kernel.
enum class BinaryOpType {
  ScalarScalar, // both inputs have a single element
  ScalarVector, // a has a single element, b is contiguous
  VectorScalar, // a is contiguous, b has a single element
  VectorVector, // both row-contiguous or both col-contiguous
  General, // arbitrary strides; use the general strided path
};
21
22BinaryOpType get_binary_op_type(const array& a, const array& b) {
23 BinaryOpType bopt;
24 if (a.data_size() == 1 && b.data_size() == 1) {
25 bopt = BinaryOpType::ScalarScalar;
26 } else if (a.data_size() == 1 && b.flags().contiguous) {
27 bopt = BinaryOpType::ScalarVector;
28 } else if (b.data_size() == 1 && a.flags().contiguous) {
29 bopt = BinaryOpType::VectorScalar;
30 } else if (
31 a.flags().row_contiguous && b.flags().row_contiguous ||
32 a.flags().col_contiguous && b.flags().col_contiguous) {
33 bopt = BinaryOpType::VectorVector;
34 } else {
35 bopt = BinaryOpType::General;
36 }
37 return bopt;
38}
39
40void set_binary_op_output_data(
41 const array& a,
42 const array& b,
43 array& out,
44 BinaryOpType bopt,
45 bool donate_with_move = false) {
46 switch (bopt) {
47 case BinaryOpType::ScalarScalar:
48 out.set_data(
49 allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
50 break;
51 case BinaryOpType::ScalarVector:
52 if (b.is_donatable() && b.itemsize() == out.itemsize()) {
53 if (donate_with_move) {
54 out.move_shared_buffer(b);
55 } else {
56 out.copy_shared_buffer(b);
57 }
58 } else {
59 out.set_data(
60 allocator::malloc_or_wait(b.data_size() * out.itemsize()),
61 b.data_size(),
62 b.strides(),
63 b.flags());
64 }
65 break;
66 case BinaryOpType::VectorScalar:
67 if (a.is_donatable() && a.itemsize() == out.itemsize()) {
68 if (donate_with_move) {
69 out.move_shared_buffer(a);
70 } else {
71 out.copy_shared_buffer(a);
72 }
73 } else {
74 out.set_data(
75 allocator::malloc_or_wait(a.data_size() * out.itemsize()),
76 a.data_size(),
77 a.strides(),
78 a.flags());
79 }
80 break;
81 case BinaryOpType::VectorVector:
82 if (a.is_donatable() && a.itemsize() == out.itemsize()) {
83 if (donate_with_move) {
84 out.move_shared_buffer(a);
85 } else {
86 out.copy_shared_buffer(a);
87 }
88 } else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
89 if (donate_with_move) {
90 out.move_shared_buffer(b);
91 } else {
92 out.copy_shared_buffer(b);
93 }
94 } else {
95 out.set_data(
96 allocator::malloc_or_wait(a.data_size() * out.itemsize()),
97 a.data_size(),
98 a.strides(),
99 a.flags());
100 }
101 break;
102 case BinaryOpType::General:
103 if (a.is_donatable() && a.flags().row_contiguous &&
104 a.itemsize() == out.itemsize() && a.size() == out.size()) {
105 if (donate_with_move) {
106 out.move_shared_buffer(a);
107 } else {
108 out.copy_shared_buffer(a);
109 }
110 } else if (
111 b.is_donatable() && b.flags().row_contiguous &&
112 b.itemsize() == out.itemsize() && b.size() == out.size()) {
113 if (donate_with_move) {
114 out.move_shared_buffer(b);
115 } else {
116 out.copy_shared_buffer(b);
117 }
118 } else {
119 out.set_data(allocator::malloc_or_wait(out.nbytes()));
120 }
121 break;
122 }
123}
124
// Sentinel functor type: passing this in place of a specialized vectorized
// kernel tells binary_op to substitute the corresponding Default* functor.
// Its call operators are never meant to execute.
struct UseDefaultBinaryOp {
  template <typename T, typename U>
  void operator()(const T* a, const T* b, U* dst, int size) {
    // Should we throw? This should normally never be called.
    assert(false);
  }

  // Two-output variant for ops whose result is a pair.
  template <typename T, typename U>
  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
    // Should we throw? This should normally never be called.
    assert(false);
  }
};
138
// Fallback vector-scalar kernel: applies `op` to each element of `a` against
// the single value `*b`.
template <typename T, typename U, typename Op>
struct DefaultVectorScalar {
  Op op;

  DefaultVectorScalar(Op op_) : op(op_) {}

  // Single-output form: dst[i] = op(a[i], *b).
  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *b;
    for (int i = 0; i < size; ++i) {
      dst[i] = op(a[i], scalar);
    }
  }

  // Two-output form for ops returning a pair of results.
  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
    T scalar = *b;
    for (int i = 0; i < size; ++i) {
      auto res = op(a[i], scalar);
      dst_a[i] = res.first;
      dst_b[i] = res.second;
    }
  }
};
166
// Fallback scalar-vector kernel: applies `op` with the single value `*a` on
// the left against each element of `b`.
template <typename T, typename U, typename Op>
struct DefaultScalarVector {
  Op op;

  DefaultScalarVector(Op op_) : op(op_) {}

  // Single-output form: dst[i] = op(*a, b[i]).
  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *a;
    for (int i = 0; i < size; ++i) {
      dst[i] = op(scalar, b[i]);
    }
  }

  // Two-output form for ops returning a pair of results.
  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
    T scalar = *a;
    for (int i = 0; i < size; ++i) {
      auto res = op(scalar, b[i]);
      dst_a[i] = res.first;
      dst_b[i] = res.second;
    }
  }
};
194
// Fallback vector-vector kernel: applies `op` elementwise over two equally
// sized contiguous inputs.
template <typename T, typename U, typename Op>
struct DefaultVectorVector {
  Op op;

  DefaultVectorVector(Op op_) : op(op_) {}

  // Single-output form: dst[i] = op(a[i], b[i]).
  void operator()(const T* a, const T* b, U* dst, int size) {
    for (int i = 0; i < size; ++i) {
      dst[i] = op(a[i], b[i]);
    }
  }

  // Two-output form for ops returning a pair of results.
  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
    for (int i = 0; i < size; ++i) {
      auto res = op(a[i], b[i]);
      dst_a[i] = res.first;
      dst_b[i] = res.second;
    }
  }
};
222
223template <typename T, typename U, typename Op>
224void binary_op_dims1(const array& a, const array& b, array& out, Op op) {
225 const T* a_ptr = a.data<T>();
226 const T* b_ptr = b.data<T>();
227 U* dst = out.data<U>();
228 size_t a_idx = 0;
229 size_t b_idx = 0;
230 for (size_t i = 0; i < out.size(); ++i) {
231 dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
232 a_idx += a.strides()[0];
233 b_idx += b.strides()[0];
234 }
235}
236
237template <typename T, typename U, typename Op>
238void binary_op_dims1(
239 const array& a,
240 const array& b,
241 array& out,
242 Op op,
243 int stride) {
244 const T* a_ptr = a.data<T>();
245 const T* b_ptr = b.data<T>();
246 U* dst = out.data<U>();
247 size_t a_idx = 0;
248 size_t b_idx = 0;
249 for (size_t i = 0; i < a.shape()[0]; i++) {
250 op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
251 a_idx += a.strides()[0];
252 b_idx += b.strides()[0];
253 dst += stride;
254 }
255}
256
257template <typename T, typename U, typename Op>
258void binary_op_dims2(const array& a, const array& b, array& out, Op op) {
259 const T* a_ptr = a.data<T>();
260 const T* b_ptr = b.data<T>();
261 U* dst = out.data<U>();
262 size_t a_idx = 0;
263 size_t b_idx = 0;
264 size_t out_idx = 0;
265 for (size_t i = 0; i < a.shape()[0]; ++i) {
266 for (size_t j = 0; j < a.shape()[1]; ++j) {
267 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
268 a_idx += a.strides()[1];
269 b_idx += b.strides()[1];
270 }
271 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
272 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
273 }
274}
275
276template <typename T, typename U, typename Op>
277void binary_op_dims2(
278 const array& a,
279 const array& b,
280 array& out,
281 Op op,
282 int stride) {
283 const T* a_ptr = a.data<T>();
284 const T* b_ptr = b.data<T>();
285 U* dst = out.data<U>();
286 size_t a_idx = 0;
287 size_t b_idx = 0;
288 for (size_t i = 0; i < a.shape()[0]; ++i) {
289 for (size_t j = 0; j < a.shape()[1]; ++j) {
290 op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
291 a_idx += a.strides()[1];
292 b_idx += b.strides()[1];
293 dst += stride;
294 }
295 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
296 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
297 }
298}
299
// Elementwise binary op over 3-D strided inputs writing a dense output.
// Input offsets are maintained incrementally: the innermost loop advances by
// the last-axis stride, and after each inner loop the offset is adjusted by
// strides[d-1] - strides[d] * shape[d] to land on the start of the next
// row/plane.
template <typename T, typename U, typename Op>
void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      for (size_t k = 0; k < a.shape()[2]; ++k) {
        dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
        a_idx += a.strides()[2];
        b_idx += b.strides()[2];
      }
      // Rewind the axis-2 walk and step to the next axis-1 row.
      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
    }
    // Rewind the axis-1 walk and step to the next axis-0 plane.
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}
322
// Elementwise binary op over 4-D strided inputs writing a dense output.
// Same incremental-offset scheme as binary_op_dims3, one level deeper: after
// each inner loop, strides[d-1] - strides[d] * shape[d] rewinds the finished
// axis and steps to the next position along the enclosing axis.
template <typename T, typename U, typename Op>
void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      for (size_t k = 0; k < a.shape()[2]; ++k) {
        for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
          dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
          a_idx += a.strides()[3];
          b_idx += b.strides()[3];
        }
        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
      }
      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}
349
350template <typename T, typename U, typename Op>
351void binary_op_dispatch_dims(
352 const array& a,
353 const array& b,
354 array& out,
355 Op op) {
356 switch (out.ndim()) {
357 case 1:
358 binary_op_dims1<T, U, Op>(a, b, out, op);
359 return;
360 case 2:
361 binary_op_dims2<T, U, Op>(a, b, out, op);
362 return;
363 case 3:
364 binary_op_dims3<T, U, Op>(a, b, out, op);
365 return;
366 case 4:
367 binary_op_dims4<T, U, Op>(a, b, out, op);
368 return;
369 }
370
371 const T* a_ptr = a.data<T>();
372 const T* b_ptr = b.data<T>();
373 U* dst = out.data<U>();
374 for (size_t i = 0; i < out.size(); i++) {
375 int a_idx = elem_to_loc(i, a.shape(), a.strides());
376 int b_idx = elem_to_loc(i, b.shape(), b.strides());
377 dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
378 }
379}
380
381template <typename T, typename U, typename Op>
382void binary_op_dispatch_dims(
383 const array& a,
384 const array& b,
385 array& out,
386 Op op,
387 int dim,
388 int stride) {
389 // Number of dimensions to loop over for vectorized ops
390 switch (dim) {
391 case 1:
392 binary_op_dims1<T, U, Op>(a, b, out, op, stride);
393 return;
394 case 2:
395 binary_op_dims2<T, U, Op>(a, b, out, op, stride);
396 return;
397 }
398
399 const T* a_ptr = a.data<T>();
400 const T* b_ptr = b.data<T>();
401 U* dst = out.data<U>();
402 for (size_t i = 0; i < out.size(); i += stride) {
403 int a_idx = elem_to_loc(i, a.shape(), a.strides());
404 int b_idx = elem_to_loc(i, b.shape(), b.strides());
405 op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
406 dst += stride;
407 }
408}
409
// Core binary op driver. Classifies the inputs, sets up the output buffer,
// and dispatches to the appropriate kernel:
//   op    - scalar op applied to a single pair of elements
//   opsv  - scalar-vector kernel
//   opvs  - vector-scalar kernel
//   opvv  - vector-vector kernel
// For the General case it analyzes trailing dimensions to recover partially
// vectorizable structure before falling back to fully strided iteration.
template <
    typename T,
    typename U,
    typename Op,
    typename OpSV,
    typename OpVS,
    typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    array& out,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);

  // The full computation is scalar scalar so call the base op once
  if (bopt == BinaryOpType::ScalarScalar) {
    *(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
    return;
  }

  // The full computation is scalar vector so delegate to the op
  if (bopt == BinaryOpType::ScalarVector) {
    opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
    return;
  }

  // The full computation is vector scalar so delegate to the op
  if (bopt == BinaryOpType::VectorScalar) {
    opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
    return;
  }

  // The full computation is vector vector so delegate to the op
  if (bopt == BinaryOpType::VectorVector) {
    opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
    return;
  }

  // General computation so let's try to optimize

  // Get the left-most dim such that the array is row contiguous after
  // (i.e. its strides match the output's from dimension d onward).
  auto& strides = out.strides();
  auto leftmost_rc_dim = [&strides](const array& arr) {
    int d = arr.ndim() - 1;
    for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
    }
    return d + 1;
  };
  auto a_rc_dim = leftmost_rc_dim(a);
  auto b_rc_dim = leftmost_rc_dim(b);

  // Get the left-most dim such that the array is a broadcasted "scalar" after
  // (zero strides mean the same element repeats along those trailing dims).
  auto leftmost_s_dim = [](const array& arr) {
    int d = arr.ndim() - 1;
    for (; d >= 0 && arr.strides()[d] == 0; d--) {
    }
    return d + 1;
  };
  auto a_s_dim = leftmost_s_dim(a);
  auto b_s_dim = leftmost_s_dim(b);

  auto ndim = out.ndim();

  // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
  int dim = ndim;
  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::VectorVector;
    dim = d;
    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
    bopt = BinaryOpType::VectorScalar;
    dim = d;
    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::ScalarVector;
    dim = d;
  }

  // Can be sure dim > 0 since otherwise we would have used one of the fully
  // contiguous methods above. Except for the case that the flags do not
  // correspond to the underlying contiguity.
  // NOTE(review): runs shorter than 16 elements fall back to the general
  // path — presumably a tuned threshold where vectorized dispatch stops
  // paying off; confirm before changing.
  size_t stride;
  if (dim == 0 || strides[dim - 1] < 16) {
    stride = 1;
    bopt = BinaryOpType::General;
    dim = ndim;
  } else {
    stride = strides[dim - 1];
  }

  switch (bopt) {
    case BinaryOpType::VectorVector:
      binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
      break;
    case BinaryOpType::VectorScalar:
      binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
      break;
    case BinaryOpType::ScalarVector:
      binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
      break;
    default:
      binary_op_dispatch_dims<T, U>(a, b, out, op);
      break;
  }
}
521
522template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
523void binary_op(
524 const array& a,
525 const array& b,
526 array& out,
527 Op op,
528 OpSV opsv,
529 OpVS opvs,
530 OpVV opvv) {
531 // TODO: The following mess of constexpr evaluations can probably be achieved
532 // with template specializations and overloading. Would it be simpler?
533
534 if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
535 if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
536 if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
537 // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
538 binary_op<T, T>(
539 a,
540 b,
541 out,
542 op,
543 DefaultScalarVector<T, T, Op>(op),
544 DefaultVectorScalar<T, T, Op>(op),
545 DefaultVectorVector<T, T, Op>(op));
546 } else {
547 // opsv and opvs were UseDefaultBinaryOp
548 binary_op<T, T>(
549 a,
550 b,
551 out,
552 op,
553 DefaultScalarVector<T, T, Op>(op),
554 DefaultVectorScalar<T, T, Op>(op),
555 opvv);
556 }
557 } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
558 // opsv and opvv were UseDefaultBinaryOp
559 binary_op<T, T>(
560 a,
561 b,
562 out,
563 op,
564 DefaultScalarVector<T, T, Op>(op),
565 opvs,
566 DefaultVectorVector<T, T, Op>(op));
567 } else {
568 // opsv was UseDefaultBinaryOp
569 binary_op<T, T>(
570 a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
571 }
572 } else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
573 if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
574 // opvs and opvv were UseDefaultBinaryOp
575 binary_op<T, T>(
576 a,
577 b,
578 out,
579 op,
580 opsv,
581 DefaultVectorScalar<T, T, Op>(op),
582 DefaultVectorVector<T, T, Op>(op));
583 } else {
584 // opvs was UseDefaultBinaryOp
585 binary_op<T, T>(
586 a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
587 }
588 } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
589 // opvv was UseDefaultBinaryOp
590 binary_op<T, T>(
591 a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
592 } else {
593 // All ops provided
594 binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
595 }
596}
597
598template <typename T, typename Op>
599void binary_op(const array& a, const array& b, array& out, Op op) {
600 DefaultScalarVector<T, T, Op> opsv(op);
601 DefaultVectorScalar<T, T, Op> opvs(op);
602 DefaultVectorVector<T, T, Op> opvv(op);
603 binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
604}
605
// Type-dispatch entry point: switch on the output dtype and instantiate
// binary_op with the matching C++ element type. `ops` forwards the base op
// plus any specialized vectorized kernels.
template <typename... Ops>
void binary(const array& a, const array& b, array& out, Ops... ops) {
  switch (out.dtype()) {
    case bool_:
      binary_op<bool>(a, b, out, ops...);
      break;
    case uint8:
      binary_op<uint8_t>(a, b, out, ops...);
      break;
    case uint16:
      binary_op<uint16_t>(a, b, out, ops...);
      break;
    case uint32:
      binary_op<uint32_t>(a, b, out, ops...);
      break;
    case uint64:
      binary_op<uint64_t>(a, b, out, ops...);
      break;
    case int8:
      binary_op<int8_t>(a, b, out, ops...);
      break;
    case int16:
      binary_op<int16_t>(a, b, out, ops...);
      break;
    case int32:
      binary_op<int32_t>(a, b, out, ops...);
      break;
    case int64:
      binary_op<int64_t>(a, b, out, ops...);
      break;
    case float16:
      binary_op<float16_t>(a, b, out, ops...);
      break;
    case float32:
      binary_op<float>(a, b, out, ops...);
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, out, ops...);
      break;
    case complex64:
      binary_op<complex64_t>(a, b, out, ops...);
      break;
  }
}
650
651} // namespace
652
653} // namespace mlx::core
Op op
Definition binary.h:141
Buffer malloc_or_wait(size_t size)
const char * binary()
Definition allocator.h:7
constexpr Dtype bool_
Definition dtype.h:60
constexpr Dtype uint64
Definition dtype.h:65
constexpr Dtype uint16
Definition dtype.h:63
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
constexpr Dtype bfloat16
Definition dtype.h:74
constexpr Dtype int32
Definition dtype.h:69
constexpr Dtype float32
Definition dtype.h:73
constexpr Dtype int16
Definition dtype.h:68
constexpr Dtype int8
Definition dtype.h:67
constexpr Dtype int64
Definition dtype.h:70
constexpr Dtype uint8
Definition dtype.h:62
constexpr Dtype float16
Definition dtype.h:72
constexpr Dtype uint32
Definition dtype.h:64
constexpr Dtype complex64
Definition dtype.h:75