MLX
Loading...
Searching...
No Matches
binary_two.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
7
8namespace mlx::core {
9
10namespace {
11
12template <typename T, typename U, typename Op>
13void binary_op_dims1(
14 const array& a,
15 const array& b,
16 array& out_a,
17 array& out_b,
18 Op op) {
19 const T* a_ptr = a.data<T>();
20 const T* b_ptr = b.data<T>();
21 U* dst_a = out_a.data<U>();
22 U* dst_b = out_b.data<U>();
23 size_t a_idx = 0;
24 size_t b_idx = 0;
25 for (size_t i = 0; i < out_a.size(); ++i) {
26 auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
27 dst_a[i] = dst.first;
28 dst_b[i] = dst.second;
29 a_idx += a.strides()[0];
30 b_idx += b.strides()[0];
31 }
32}
33
34template <typename T, typename U, typename Op>
35void binary_op_dims1(
36 const array& a,
37 const array& b,
38 array& out_a,
39 array& out_b,
40 Op op,
41 int stride) {
42 const T* a_ptr = a.data<T>();
43 const T* b_ptr = b.data<T>();
44 U* dst_a = out_a.data<U>();
45 U* dst_b = out_b.data<U>();
46 size_t a_idx = 0;
47 size_t b_idx = 0;
48 for (size_t i = 0; i < a.shape()[0]; i++) {
49 op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
50 a_idx += a.strides()[0];
51 b_idx += b.strides()[0];
52 dst_a += stride;
53 dst_b += stride;
54 }
55}
56
57template <typename T, typename U, typename Op>
58void binary_op_dims2(
59 const array& a,
60 const array& b,
61 array& out_a,
62 array& out_b,
63 Op op) {
64 const T* a_ptr = a.data<T>();
65 const T* b_ptr = b.data<T>();
66 U* dst_a = out_a.data<U>();
67 U* dst_b = out_b.data<U>();
68 size_t a_idx = 0;
69 size_t b_idx = 0;
70 size_t out_idx = 0;
71 for (size_t i = 0; i < a.shape()[0]; ++i) {
72 for (size_t j = 0; j < a.shape()[1]; ++j) {
73 auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
74 dst_a[out_idx] = dst.first;
75 dst_b[out_idx++] = dst.second;
76 a_idx += a.strides()[1];
77 b_idx += b.strides()[1];
78 }
79 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
80 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
81 }
82}
83
84template <typename T, typename U, typename Op>
85void binary_op_dims2(
86 const array& a,
87 const array& b,
88 array& out_a,
89 array& out_b,
90 Op op,
91 int stride) {
92 const T* a_ptr = a.data<T>();
93 const T* b_ptr = b.data<T>();
94 U* dst_a = out_a.data<U>();
95 U* dst_b = out_b.data<U>();
96 size_t a_idx = 0;
97 size_t b_idx = 0;
98 for (size_t i = 0; i < a.shape()[0]; ++i) {
99 for (size_t j = 0; j < a.shape()[1]; ++j) {
100 op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
101 a_idx += a.strides()[1];
102 b_idx += b.strides()[1];
103 dst_a += stride;
104 dst_b += stride;
105 }
106 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
107 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
108 }
109}
110
111template <typename T, typename U, typename Op>
112void binary_op_dims3(
113 const array& a,
114 const array& b,
115 array& out_a,
116 array& out_b,
117 Op op) {
118 const T* a_ptr = a.data<T>();
119 const T* b_ptr = b.data<T>();
120 U* dst_a = out_a.data<U>();
121 U* dst_b = out_b.data<U>();
122 size_t a_idx = 0;
123 size_t b_idx = 0;
124 size_t out_idx = 0;
125 for (size_t i = 0; i < a.shape()[0]; ++i) {
126 for (size_t j = 0; j < a.shape()[1]; ++j) {
127 for (size_t k = 0; k < a.shape()[2]; ++k) {
128 auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
129 dst_a[out_idx] = dst.first;
130 dst_b[out_idx++] = dst.second;
131 a_idx += a.strides()[2];
132 b_idx += b.strides()[2];
133 }
134 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
135 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
136 }
137 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
138 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
139 }
140}
141
142template <typename T, typename U, typename Op>
143void binary_op_dims4(
144 const array& a,
145 const array& b,
146 array& out_a,
147 array& out_b,
148 Op op) {
149 const T* a_ptr = a.data<T>();
150 const T* b_ptr = b.data<T>();
151 U* dst_a = out_a.data<U>();
152 U* dst_b = out_b.data<U>();
153 size_t a_idx = 0;
154 size_t b_idx = 0;
155 size_t out_idx = 0;
156 for (size_t i = 0; i < a.shape()[0]; ++i) {
157 for (size_t j = 0; j < a.shape()[1]; ++j) {
158 for (size_t k = 0; k < a.shape()[2]; ++k) {
159 for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
160 auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
161 dst_a[out_idx] = dst.first;
162 dst_b[out_idx++] = dst.second;
163 a_idx += a.strides()[3];
164 b_idx += b.strides()[3];
165 }
166 a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
167 b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
168 }
169 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
170 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
171 }
172 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
173 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
174 }
175}
176
177template <typename T, typename U, typename Op>
178void binary_op_dispatch_dims(
179 const array& a,
180 const array& b,
181 array& out_a,
182 array& out_b,
183 Op op) {
184 switch (out_a.ndim()) {
185 case 1:
186 binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op);
187 return;
188 case 2:
189 binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op);
190 return;
191 case 3:
192 binary_op_dims3<T, U, Op>(a, b, out_a, out_b, op);
193 return;
194 case 4:
195 binary_op_dims4<T, U, Op>(a, b, out_a, out_b, op);
196 return;
197 }
198
199 const T* a_ptr = a.data<T>();
200 const T* b_ptr = b.data<T>();
201 U* dst_a = out_a.data<U>();
202 U* dst_b = out_b.data<U>();
203 for (size_t i = 0; i < out_a.size(); i++) {
204 int a_idx = elem_to_loc(i, a.shape(), a.strides());
205 int b_idx = elem_to_loc(i, b.shape(), b.strides());
206 std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]);
207 }
208}
209
210template <typename T, typename U, typename Op>
211void binary_op_dispatch_dims(
212 const array& a,
213 const array& b,
214 array& out_a,
215 array& out_b,
216 Op op,
217 int dim,
218 int stride) {
219 // Number of dimensions to loop over for vectorized ops
220 switch (dim) {
221 case 1:
222 binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op, stride);
223 return;
224 case 2:
225 binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op, stride);
226 return;
227 }
228
229 const T* a_ptr = a.data<T>();
230 const T* b_ptr = b.data<T>();
231 U* dst_a = out_a.data<U>();
232 U* dst_b = out_b.data<U>();
233 for (size_t i = 0; i < out_a.size(); i += stride) {
234 int a_idx = elem_to_loc(i, a.shape(), a.strides());
235 int b_idx = elem_to_loc(i, b.shape(), b.strides());
236 op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
237 dst_a += stride;
238 dst_b += stride;
239 }
240}
241
242template <
243 typename T,
244 typename U,
245 typename Op,
246 typename OpSV,
247 typename OpVS,
248 typename OpVV>
249void binary_op(
250 const array& a,
251 const array& b,
252 array& out_a,
253 array& out_b,
254 Op op,
255 OpSV opsv,
256 OpVS opvs,
257 OpVV opvv) {
258 auto bopt = get_binary_op_type(a, b);
259 set_binary_op_output_data(a, b, out_a, bopt);
260 set_binary_op_output_data(a, b, out_b, bopt);
261
262 // The full computation is scalar scalar so call the base op once
263 if (bopt == BinaryOpType::ScalarScalar) {
264 std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
265 op(*a.data<T>(), *b.data<T>());
266 return;
267 }
268
269 // The full computation is scalar vector so delegate to the op
270 if (bopt == BinaryOpType::ScalarVector) {
271 opsv(
272 a.data<T>(),
273 b.data<T>(),
274 out_a.data<U>(),
275 out_b.data<U>(),
276 b.data_size());
277 return;
278 }
279
280 // The full computation is vector scalar so delegate to the op
281 if (bopt == BinaryOpType::VectorScalar) {
282 opvs(
283 a.data<T>(),
284 b.data<T>(),
285 out_a.data<U>(),
286 out_b.data<U>(),
287 a.data_size());
288 return;
289 }
290
291 // The full computation is vector vector so delegate to the op
292 if (bopt == BinaryOpType::VectorVector) {
293 opvv(
294 a.data<T>(),
295 b.data<T>(),
296 out_a.data<U>(),
297 out_b.data<U>(),
298 out_a.size());
299 return;
300 }
301
302 // General computation so let's try to optimize
303
304 // Get the left-most dim such that the array is row contiguous after
305 auto& strides = out_a.strides();
306 auto leftmost_rc_dim = [&strides](const array& arr) {
307 int d = arr.ndim() - 1;
308 for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
309 }
310 return d + 1;
311 };
312 auto a_rc_dim = leftmost_rc_dim(a);
313 auto b_rc_dim = leftmost_rc_dim(b);
314
315 // Get the left-most dim such that the array is a broadcasted "scalar" after
316 auto leftmost_s_dim = [](const array& arr) {
317 int d = arr.ndim() - 1;
318 for (; d >= 0 && arr.strides()[d] == 0; d--) {
319 }
320 return d + 1;
321 };
322 auto a_s_dim = leftmost_s_dim(a);
323 auto b_s_dim = leftmost_s_dim(b);
324
325 auto ndim = out_a.ndim();
326
327 // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
328 int dim = ndim;
329 if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
330 bopt = BinaryOpType::VectorVector;
331 dim = d;
332 // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
333 // contiguous
334 } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
335 bopt = BinaryOpType::VectorScalar;
336 dim = d;
337 // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
338 // contiguous
339 } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
340 bopt = BinaryOpType::ScalarVector;
341 dim = d;
342 }
343
344 // Can be sure dim > 0 since otherwise we would have used one of the fully
345 // contiguous methods above. Except for the case that the flags do not
346 // correspond to the underlying contiguity.
347 size_t stride;
348 if (dim == 0 || strides[dim - 1] < 16) {
349 stride = 1;
350 bopt = BinaryOpType::General;
351 dim = ndim;
352 } else {
353 stride = strides[dim - 1];
354 }
355
356 switch (bopt) {
357 case BinaryOpType::VectorVector:
358 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
359 break;
360 case BinaryOpType::VectorScalar:
361 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
362 break;
363 case BinaryOpType::ScalarVector:
364 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
365 break;
366 default:
367 binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
368 break;
369 }
370}
371
372template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
373void binary_op(
374 const array& a,
375 const array& b,
376 std::vector<array>& outputs,
377 Op op,
378 OpSV opsv,
379 OpVS opvs,
380 OpVV opvv) {
381 // TODO: The following mess of constexpr evaluations can probably be achieved
382 // with template specializations and overloading. Would it be simpler?
383
384 if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
385 if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
386 if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
387 // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
388 binary_op<T, T>(
389 a,
390 b,
391 outputs[0],
392 outputs[1],
393 op,
394 DefaultScalarVector<T, T, Op>(op),
395 DefaultVectorScalar<T, T, Op>(op),
396 DefaultVectorVector<T, T, Op>(op));
397 } else {
398 // opsv and opvs were UseDefaultBinaryOp
399 binary_op<T, T>(
400 a,
401 b,
402 outputs[0],
403 outputs[1],
404 op,
405 DefaultScalarVector<T, T, Op>(op),
406 DefaultVectorScalar<T, T, Op>(op),
407 opvv);
408 }
409 } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
410 // opsv and opvv were UseDefaultBinaryOp
411 binary_op<T, T>(
412 a,
413 b,
414 outputs[0],
415 outputs[1],
416 op,
417 DefaultScalarVector<T, T, Op>(op),
418 opvs,
419 DefaultVectorVector<T, T, Op>(op));
420 } else {
421 // opsv was UseDefaultBinaryOp
422 binary_op<T, T>(
423 a,
424 b,
425 outputs[0],
426 outputs[1],
427 op,
428 DefaultScalarVector<T, T, Op>(op),
429 opvs,
430 opvv);
431 }
432 } else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
433 if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
434 // opvs and opvv were UseDefaultBinaryOp
435 binary_op<T, T>(
436 a,
437 b,
438 outputs[0],
439 outputs[1],
440 op,
441 opsv,
442 DefaultVectorScalar<T, T, Op>(op),
443 DefaultVectorVector<T, T, Op>(op));
444 } else {
445 // opvs was UseDefaultBinaryOp
446 binary_op<T, T>(
447 a,
448 b,
449 outputs[0],
450 outputs[1],
451 op,
452 opsv,
453 DefaultVectorScalar<T, T, Op>(op),
454 opvv);
455 }
456 } else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
457 // opvv was UseDefaultBinaryOp
458 binary_op<T, T>(
459 a,
460 b,
461 outputs[0],
462 outputs[1],
463 op,
464 opsv,
465 opvs,
466 DefaultVectorVector<T, T, Op>(op));
467 } else {
468 // All ops provided
469 binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
470 }
471}
472
473template <typename T, typename Op>
474void binary_op(
475 const array& a,
476 const array& b,
477 std::vector<array>& outputs,
478 Op op) {
479 DefaultScalarVector<T, T, Op> opsv(op);
480 DefaultVectorScalar<T, T, Op> opvs(op);
481 DefaultVectorVector<T, T, Op> opvv(op);
482 binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
483}
484
485template <typename... Ops>
486void binary(
487 const array& a,
488 const array& b,
489 std::vector<array>& outputs,
490 Ops... ops) {
491 switch (outputs[0].dtype()) {
492 case bool_:
493 binary_op<bool>(a, b, outputs, ops...);
494 break;
495 case uint8:
496 binary_op<uint8_t>(a, b, outputs, ops...);
497 break;
498 case uint16:
499 binary_op<uint16_t>(a, b, outputs, ops...);
500 break;
501 case uint32:
502 binary_op<uint32_t>(a, b, outputs, ops...);
503 break;
504 case uint64:
505 binary_op<uint64_t>(a, b, outputs, ops...);
506 break;
507 case int8:
508 binary_op<int8_t>(a, b, outputs, ops...);
509 break;
510 case int16:
511 binary_op<int16_t>(a, b, outputs, ops...);
512 break;
513 case int32:
514 binary_op<int32_t>(a, b, outputs, ops...);
515 break;
516 case int64:
517 binary_op<int64_t>(a, b, outputs, ops...);
518 break;
519 case float16:
520 binary_op<float16_t>(a, b, outputs, ops...);
521 break;
522 case float32:
523 binary_op<float>(a, b, outputs, ops...);
524 break;
525 case bfloat16:
526 binary_op<bfloat16_t>(a, b, outputs, ops...);
527 break;
528 case complex64:
529 binary_op<complex64_t>(a, b, outputs, ops...);
530 break;
531 }
532}
533
534} // namespace
535
536} // namespace mlx::core
Op op
Definition binary.h:141
const char * binary()
Definition allocator.h:7
constexpr Dtype bool_
Definition dtype.h:60
constexpr Dtype uint64
Definition dtype.h:65
constexpr Dtype uint16
Definition dtype.h:63
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
constexpr Dtype bfloat16
Definition dtype.h:74
constexpr Dtype int32
Definition dtype.h:69
constexpr Dtype float32
Definition dtype.h:73
constexpr Dtype int16
Definition dtype.h:68
constexpr Dtype int8
Definition dtype.h:67
constexpr Dtype int64
Definition dtype.h:70
constexpr Dtype uint8
Definition dtype.h:62
constexpr Dtype float16
Definition dtype.h:72
constexpr Dtype uint32
Definition dtype.h:64
constexpr Dtype complex64
Definition dtype.h:75