MLX
binary.h
// Copyright © 2023 Apple Inc.

#pragma once

#include <algorithm>
#include <cassert>
#include <type_traits>

#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

namespace {

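// Classification of a binary op by the memory layout of its inputs. Each
// case selects a progressively more general (and slower) iteration scheme.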
enum class BinaryOpType {
  ScalarScalar,
  ScalarVector,
  VectorScalar,
  VectorVector,
  General,
};

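// Inspect the inputs' sizes and contiguity flags to pick the cheapest
// iteration scheme. For example, adding a scalar to a contiguous array
// classifies as ScalarVector, while two arbitrarily strided (broadcast)
// inputs fall back to General.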
BinaryOpType get_binary_op_type(const array& a, const array& b) {
  BinaryOpType bopt;
  if (a.data_size() == 1 && b.data_size() == 1) {
    bopt = BinaryOpType::ScalarScalar;
  } else if (a.data_size() == 1 && b.flags().contiguous) {
    bopt = BinaryOpType::ScalarVector;
  } else if (b.data_size() == 1 && a.flags().contiguous) {
    bopt = BinaryOpType::VectorScalar;
  } else if (
      (a.flags().row_contiguous && b.flags().row_contiguous) ||
      (a.flags().col_contiguous && b.flags().col_contiguous)) {
    bopt = BinaryOpType::VectorVector;
  } else {
    bopt = BinaryOpType::General;
  }
  return bopt;
}

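// Allocate (or donate) the output buffer for a binary op. When an input is
// donatable and its element size matches the output's, its buffer is reused
// for the output instead of allocating fresh memory; `donate_with_move`
// selects move semantics over a shared copy for the donated buffer.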
void set_binary_op_output_data(
    const array& a,
    const array& b,
    array& out,
    BinaryOpType bopt,
    bool donate_with_move = false) {
  switch (bopt) {
    case BinaryOpType::ScalarScalar:
      out.set_data(
          allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
      break;
    case BinaryOpType::ScalarVector:
      if (b.is_donatable() && b.itemsize() == out.itemsize()) {
        if (donate_with_move) {
          out.move_shared_buffer(b);
        } else {
          out.copy_shared_buffer(b);
        }
      } else {
        out.set_data(
            allocator::malloc_or_wait(b.data_size() * out.itemsize()),
            b.data_size(),
            b.strides(),
            b.flags());
      }
      break;
    case BinaryOpType::VectorScalar:
      if (a.is_donatable() && a.itemsize() == out.itemsize()) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
        } else {
          out.copy_shared_buffer(a);
        }
      } else {
        out.set_data(
            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
            a.data_size(),
            a.strides(),
            a.flags());
      }
      break;
    case BinaryOpType::VectorVector:
      if (a.is_donatable() && a.itemsize() == out.itemsize()) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
        } else {
          out.copy_shared_buffer(a);
        }
      } else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
        if (donate_with_move) {
          out.move_shared_buffer(b);
        } else {
          out.copy_shared_buffer(b);
        }
      } else {
        out.set_data(
            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
            a.data_size(),
            a.strides(),
            a.flags());
      }
      break;
    case BinaryOpType::General:
      if (a.is_donatable() && a.flags().row_contiguous &&
          a.itemsize() == out.itemsize() && a.size() == out.size()) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
        } else {
          out.copy_shared_buffer(a);
        }
      } else if (
          b.is_donatable() && b.flags().row_contiguous &&
          b.itemsize() == out.itemsize() && b.size() == out.size()) {
        if (donate_with_move) {
          out.move_shared_buffer(b);
        } else {
          out.copy_shared_buffer(b);
        }
      } else {
        out.set_data(allocator::malloc_or_wait(out.nbytes()));
      }
      break;
  }
}

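// Sentinel functor: passing this for one of the specialized ops below asks
// binary_op to substitute the matching Default* implementation. Its call
// operators should never actually run, hence the assert(false).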
struct UseDefaultBinaryOp {
  template <typename T, typename U>
  void operator()(const T* a, const T* b, U* dst, int size) {
    // Should we throw? This should normally never be called.
    assert(false);
  }

  template <typename T, typename U>
  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
    // Should we throw? This should normally never be called.
    assert(false);
  }
};

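// Fallback implementations of the three vectorized kernels. Each provides
// two call operators: one for ops with a single output, and one for ops
// that return a pair (written to dst_a/dst_b), e.g. a divmod-style op.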
template <typename T, typename U, typename Op>
struct DefaultVectorScalar {
  Op op;

  DefaultVectorScalar(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *b;
    while (size-- > 0) {
      *dst = op(*a, scalar);
      dst++;
      a++;
    }
  }

  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
    T scalar = *b;
    while (size-- > 0) {
      auto dst = op(*a, scalar);
      *dst_a = dst.first;
      *dst_b = dst.second;
      dst_a++;
      dst_b++;
      a++;
    }
  }
};

template <typename T, typename U, typename Op>
struct DefaultScalarVector {
  Op op;

  DefaultScalarVector(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *a;
    while (size-- > 0) {
      *dst = op(scalar, *b);
      dst++;
      b++;
    }
  }

  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
    T scalar = *a;
    while (size-- > 0) {
      auto dst = op(scalar, *b);
      *dst_a = dst.first;
      *dst_b = dst.second;
      dst_a++;
      dst_b++;
      b++;
    }
  }
};

template <typename T, typename U, typename Op>
struct DefaultVectorVector {
  Op op;

  DefaultVectorVector(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    while (size-- > 0) {
      *dst = op(*a, *b);
      dst++;
      a++;
      b++;
    }
  }

  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
    while (size-- > 0) {
      auto dst = op(*a, *b);
      *dst_a = dst.first;
      *dst_b = dst.second;
      dst_a++;
      dst_b++;
      a++;
      b++;
    }
  }
};

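// Strided loops specialized by output rank. Each rank has two variants: an
// element-wise one that applies `op` per element, and a vectorized one that
// hands `op` a contiguous run of `stride` elements at a time.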
template <typename T, typename U, typename Op>
void binary_op_dims1(const array& a, const array& b, array& out, Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  for (size_t i = 0; i < out.size(); ++i) {
    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
    a_idx += a.strides()[0];
    b_idx += b.strides()[0];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims1(
    const array& a,
    const array& b,
    array& out,
    Op op,
    int stride) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; i++) {
    op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
    a_idx += a.strides()[0];
    b_idx += b.strides()[0];
    dst += stride;
  }
}

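// In the rank >= 2 loops below, after an inner dimension finishes, the index
// is advanced by strides[d-1] - strides[d] * shape[d]: this undoes the
// accumulated inner-dimension steps and then takes one step along the outer
// dimension. E.g. for a row-contiguous 3x4 array (strides {4, 1}), after 4
// inner steps the adjustment is 4 - 1 * 4 = 0, landing on the next row.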
template <typename T, typename U, typename Op>
void binary_op_dims2(const array& a, const array& b, array& out, Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
      a_idx += a.strides()[1];
      b_idx += b.strides()[1];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims2(
    const array& a,
    const array& b,
    array& out,
    Op op,
    int stride) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
      a_idx += a.strides()[1];
      b_idx += b.strides()[1];
      dst += stride;
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      for (size_t k = 0; k < a.shape()[2]; ++k) {
        dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
        a_idx += a.strides()[2];
        b_idx += b.strides()[2];
      }
      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

template <typename T, typename U, typename Op>
void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  size_t a_idx = 0;
  size_t b_idx = 0;
  size_t out_idx = 0;
  for (size_t i = 0; i < a.shape()[0]; ++i) {
    for (size_t j = 0; j < a.shape()[1]; ++j) {
      for (size_t k = 0; k < a.shape()[2]; ++k) {
        for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
          dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
          a_idx += a.strides()[3];
          b_idx += b.strides()[3];
        }
        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
      }
      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
    }
    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
  }
}

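// Dispatch to the rank-specialized loop when possible; otherwise fall back
// to computing each element's location from its flat index via elem_to_loc.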
template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
    const array& a,
    const array& b,
    array& out,
    Op op) {
  switch (out.ndim()) {
    case 1:
      binary_op_dims1<T, U, Op>(a, b, out, op);
      return;
    case 2:
      binary_op_dims2<T, U, Op>(a, b, out, op);
      return;
    case 3:
      binary_op_dims3<T, U, Op>(a, b, out, op);
      return;
    case 4:
      binary_op_dims4<T, U, Op>(a, b, out, op);
      return;
  }

  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  for (size_t i = 0; i < out.size(); i++) {
    int a_idx = elem_to_loc(i, a.shape(), a.strides());
    int b_idx = elem_to_loc(i, b.shape(), b.strides());
    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
  }
}

template <typename T, typename U, typename Op>
void binary_op_dispatch_dims(
    const array& a,
    const array& b,
    array& out,
    Op op,
    int dim,
    int stride) {
  // Number of dimensions to loop over for vectorized ops
  switch (dim) {
    case 1:
      binary_op_dims1<T, U, Op>(a, b, out, op, stride);
      return;
    case 2:
      binary_op_dims2<T, U, Op>(a, b, out, op, stride);
      return;
  }

  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* dst = out.data<U>();
  for (size_t i = 0; i < out.size(); i += stride) {
    int a_idx = elem_to_loc(i, a.shape(), a.strides());
    int b_idx = elem_to_loc(i, b.shape(), b.strides());
    op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
    dst += stride;
  }
}

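// Core entry point. Fully scalar/contiguous cases delegate directly to the
// specialized kernels. For the general case, the code searches for a suffix
// of dimensions over which both inputs are row contiguous (or broadcast
// "scalars") so that the vectorized kernels can still be used over that
// suffix; only if no worthwhile suffix exists (contiguous runs shorter than
// 16 elements) does it fall back to fully general element-wise iteration.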
template <
    typename T,
    typename U,
    typename Op,
    typename OpSV,
    typename OpVS,
    typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    array& out,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);

  // The full computation is scalar scalar so call the base op once
  if (bopt == BinaryOpType::ScalarScalar) {
    *(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
    return;
  }

  // The full computation is scalar vector so delegate to the op
  if (bopt == BinaryOpType::ScalarVector) {
    opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
    return;
  }

  // The full computation is vector scalar so delegate to the op
  if (bopt == BinaryOpType::VectorScalar) {
    opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
    return;
  }

  // The full computation is vector vector so delegate to the op
  if (bopt == BinaryOpType::VectorVector) {
    opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
    return;
  }

  // General computation so let's try to optimize

  // Get the left-most dim such that the array is row contiguous after
  auto& strides = out.strides();
  auto leftmost_rc_dim = [&strides](const array& arr) {
    int d = arr.ndim() - 1;
    for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
    }
    return d + 1;
  };
  auto a_rc_dim = leftmost_rc_dim(a);
  auto b_rc_dim = leftmost_rc_dim(b);

  // Get the left-most dim such that the array is a broadcasted "scalar" after
  auto leftmost_s_dim = [](const array& arr) {
    int d = arr.ndim() - 1;
    for (; d >= 0 && arr.strides()[d] == 0; d--) {
    }
    return d + 1;
  };
  auto a_s_dim = leftmost_s_dim(a);
  auto b_s_dim = leftmost_s_dim(b);

  auto ndim = out.ndim();

  // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
  int dim = ndim;
  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::VectorVector;
    dim = d;
    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
    bopt = BinaryOpType::VectorScalar;
    dim = d;
    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::ScalarVector;
    dim = d;
  }

  // We can be sure dim > 0 since otherwise we would have used one of the
  // fully contiguous methods above, except when the flags do not correspond
  // to the underlying contiguity.
  size_t stride;
  if (dim == 0 || strides[dim - 1] < 16) {
    stride = 1;
    bopt = BinaryOpType::General;
    dim = ndim;
  } else {
    stride = strides[dim - 1];
  }

  switch (bopt) {
    case BinaryOpType::VectorVector:
      binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
      break;
    case BinaryOpType::VectorScalar:
      binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
      break;
    case BinaryOpType::ScalarVector:
      binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
      break;
    default:
      binary_op_dispatch_dims<T, U>(a, b, out, op);
      break;
  }
}

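// Convenience wrapper: substitute the Default* implementation for any
// specialized op passed as UseDefaultBinaryOp, then forward to the core
// binary_op above.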
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    array& out,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  // TODO: The following mess of constexpr evaluations can probably be achieved
  // with template specializations and overloading. Would it be simpler?

  if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
    if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
      if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
        // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
        binary_op<T, T>(
            a,
            b,
            out,
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            DefaultVectorVector<T, T, Op>(op));
      } else {
        // opsv and opvs were UseDefaultBinaryOp
        binary_op<T, T>(
            a,
            b,
            out,
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            opvv);
      }
    } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
      // opsv and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          out,
          op,
          DefaultScalarVector<T, T, Op>(op),
          opvs,
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opsv was UseDefaultBinaryOp
      binary_op<T, T>(
          a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
    }
  } else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
    if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
      // opvs and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          out,
          op,
          opsv,
          DefaultVectorScalar<T, T, Op>(op),
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opvs was UseDefaultBinaryOp
      binary_op<T, T>(
          a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
    }
  } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
    // opvv was UseDefaultBinaryOp
    binary_op<T, T>(
        a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
  } else {
    // All ops provided
    binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
  }
}

template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
  DefaultScalarVector<T, T, Op> opsv(op);
  DefaultVectorScalar<T, T, Op> opvs(op);
  DefaultVectorVector<T, T, Op> opvv(op);
  binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
}

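// Type dispatch over the output dtype. A typical call might look like the
// following (hypothetical element-wise op shown purely for illustration):
//
//   binary(a, b, out, [](auto x, auto y) { return x + y; });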
template <typename... Ops>
void binary(const array& a, const array& b, array& out, Ops... ops) {
  switch (out.dtype()) {
    case bool_:
      binary_op<bool>(a, b, out, ops...);
      break;
    case uint8:
      binary_op<uint8_t>(a, b, out, ops...);
      break;
    case uint16:
      binary_op<uint16_t>(a, b, out, ops...);
      break;
    case uint32:
      binary_op<uint32_t>(a, b, out, ops...);
      break;
    case uint64:
      binary_op<uint64_t>(a, b, out, ops...);
      break;
    case int8:
      binary_op<int8_t>(a, b, out, ops...);
      break;
    case int16:
      binary_op<int16_t>(a, b, out, ops...);
      break;
    case int32:
      binary_op<int32_t>(a, b, out, ops...);
      break;
    case int64:
      binary_op<int64_t>(a, b, out, ops...);
      break;
    case float16:
      binary_op<float16_t>(a, b, out, ops...);
      break;
    case float32:
      binary_op<float>(a, b, out, ops...);
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, out, ops...);
      break;
    case complex64:
      binary_op<complex64_t>(a, b, out, ops...);
      break;
  }
}

} // namespace

} // namespace mlx::core