MLX
 
Loading...
Searching...
No Matches
binary_two.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
7
8namespace mlx::core {
9
10namespace {
11
12template <typename T, typename U, typename Op, int D>
14 const T* a,
15 const T* b,
16 U* out_a,
17 U* out_b,
18 Op op,
19 const Shape& shape,
20 const Strides& a_strides,
21 const Strides& b_strides,
22 const Strides& out_strides,
23 int axis) {
24 auto stride_a = a_strides[axis];
25 auto stride_b = b_strides[axis];
26 auto stride_out = out_strides[axis];
27 auto N = shape[axis];
28
29 for (int i = 0; i < N; i++) {
30 if constexpr (D > 1) {
31 binary_op_dims<T, U, Op, D - 1>(
32 a,
33 b,
34 out_a,
35 out_b,
36 op,
37 shape,
38 a_strides,
39 b_strides,
40 out_strides,
41 axis + 1);
42 } else {
43 std::tie(*out_a, *out_b) = op(*a, *b);
44 }
45 a += stride_a;
46 b += stride_b;
47 out_a += stride_out;
48 out_b += stride_out;
49 }
50}
51
52template <typename T, typename U, typename Op>
54 const array& a,
55 const array& b,
56 array& out_a,
57 array& out_b,
58 Op op) {
59 auto [shape, strides] = collapse_contiguous_dims(
60 a.shape(), {a.strides(), b.strides(), out_a.strides()});
61 const T* a_ptr = a.data<T>();
62 const T* b_ptr = b.data<T>();
63 U* out_a_ptr = out_a.data<U>();
64 U* out_b_ptr = out_b.data<U>();
65
66 const auto& a_strides = strides[0];
67 const auto& b_strides = strides[1];
68 const auto& out_strides = strides[2];
69 int ndim = shape.size();
70 switch (ndim) {
71 case 1:
73 a_ptr,
74 b_ptr,
75 out_a_ptr,
76 out_b_ptr,
77 op,
78 shape,
79 a_strides,
80 b_strides,
81 out_strides,
82 0);
83 return;
84 case 2:
86 a_ptr,
87 b_ptr,
88 out_a_ptr,
89 out_b_ptr,
90 op,
91 shape,
92 a_strides,
93 b_strides,
94 out_strides,
95 0);
96 return;
97 }
98
99 ContiguousIterator a_it(shape, a_strides, ndim - 2);
100 ContiguousIterator b_it(shape, b_strides, ndim - 2);
101 auto stride = out_strides[ndim - 3];
102 for (size_t elem = 0; elem < a.size(); elem += stride) {
104 a_ptr + a_it.loc,
105 b_ptr + b_it.loc,
106 out_a_ptr + elem,
107 out_b_ptr + elem,
108 op,
109 shape,
110 a_strides,
111 b_strides,
112 out_strides,
113 ndim - 2);
114 a_it.step();
115 b_it.step();
116 }
117}
118
119template <typename T, typename U = T, typename Op>
120void binary_op(
121 const array& a,
122 const array& b,
123 array& out_a,
124 array& out_b,
125 Op op,
126 BinaryOpType bopt) {
127 // The full computation is scalar scalar so call the base op once
128 if (bopt == BinaryOpType::General) {
129 binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
130 return;
131 }
132
133 auto a_ptr = a.data<T>();
134 auto b_ptr = b.data<T>();
135 auto out_a_ptr = out_a.data<U>();
136 auto out_b_ptr = out_b.data<U>();
137 if (bopt == BinaryOpType::ScalarScalar) {
138 std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
139 } else if (bopt == BinaryOpType::ScalarVector) {
140 for (size_t i = 0; i < b.data_size(); ++i) {
141 std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
142 out_a_ptr++;
143 out_b_ptr++;
144 b_ptr++;
145 }
146 } else if (bopt == BinaryOpType::VectorScalar) {
147 for (size_t i = 0; i < a.data_size(); ++i) {
148 std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
149 out_a_ptr++;
150 out_b_ptr++;
151 a_ptr++;
152 }
153 } else { // VectorVector
154 for (size_t i = 0; i < a.size(); ++i) {
155 std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
156 out_a_ptr++;
157 out_b_ptr++;
158 a_ptr++;
159 b_ptr++;
160 }
161 }
162}
163
164} // namespace
165
166} // namespace mlx::core
Definition array.h:24
constexpr int N
Definition neon_fp16_simd.h:9
Definition allocator.h:7
std::tuple< Shape, std::vector< Strides > > collapse_contiguous_dims(const Shape &shape, const std::vector< Strides > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
BinaryOpType
Definition binary.h:11
@ General
Definition binary.h:16
@ ScalarScalar
Definition binary.h:12
@ VectorScalar
Definition binary.h:14
@ ScalarVector
Definition binary.h:13
std::vector< ShapeElem > Shape
Definition array.h:21
std::vector< int64_t > Strides
Definition array.h:22
void binary_op_dispatch_dims(const T *a, const T *b, U *out, int dim, int size, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &out_strides)
Definition binary.h:108
void binary_op(const array &a, const array &b, array &out, BinaryOpType bopt)
Definition binary.h:152
void binary_op_dims(const T *a, const T *b, U *out, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &out_strides, int axis)
Definition binary.h:76
Definition utils.h:73