MLX
Loading...
Searching...
No Matches
binary_two.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
7
8namespace mlx::core {
9
10namespace {
11
// Applies a two-output binary `op` across D nested strided dimensions,
// beginning at dimension `axis` of `shape`.
//
// For every index along `axis` the function either recurses into the next
// dimension (D > 1) or, at the innermost level (D == 1), evaluates
// `op(*a, *b)` and unpacks the two results into `*out_a` and `*out_b`.
//
// a, b          : input element pointers, advanced by a_strides[axis] and
//                 b_strides[axis] respectively after each index.
// out_a, out_b  : output element pointers; both are advanced by the same
//                 out_strides[axis].
// op            : callable whose result can be assigned to
//                 std::tie(U&, U&).
// shape/strides : indexed at `axis`, `axis + 1`, ... `axis + D - 1`.
template <typename T, typename U, typename Op, int D>
void binary_op_dims(
    const T* a,
    const T* b,
    U* out_a,
    U* out_b,
    Op op,
    const std::vector<int>& shape,
    const std::vector<size_t>& a_strides,
    const std::vector<size_t>& b_strides,
    const std::vector<size_t>& out_strides,
    int axis) {
  const auto a_step = a_strides[axis];
  const auto b_step = b_strides[axis];
  const auto out_step = out_strides[axis];
  const auto extent = shape[axis];

  if constexpr (D > 1) {
    // Outer dimension: hand each slice to the (D - 1)-dimensional kernel.
    for (int idx = 0; idx < extent; ++idx) {
      binary_op_dims<T, U, Op, D - 1>(
          a,
          b,
          out_a,
          out_b,
          op,
          shape,
          a_strides,
          b_strides,
          out_strides,
          axis + 1);
      a += a_step;
      b += b_step;
      out_a += out_step;
      out_b += out_step;
    }
  } else {
    // Innermost dimension: evaluate the op element by element.
    for (int idx = 0; idx < extent; ++idx) {
      std::tie(*out_a, *out_b) = op(*a, *b);
      a += a_step;
      b += b_step;
      out_a += out_step;
      out_b += out_step;
    }
  }
}
51
52template <typename T, typename U, typename Op>
53void binary_op_dispatch_dims(
54 const array& a,
55 const array& b,
56 array& out_a,
57 array& out_b,
58 Op op) {
59 auto [shape, strides] = collapse_contiguous_dims(
60 a.shape(), {a.strides(), b.strides(), out_a.strides()});
61 const auto& a_strides = strides[0];
62 const auto& b_strides = strides[1];
63 const auto& out_strides = strides[2];
64 const T* a_ptr = a.data<T>();
65 const T* b_ptr = b.data<T>();
66 U* out_a_ptr = out_a.data<U>();
67 U* out_b_ptr = out_b.data<U>();
68
69 int ndim = shape.size();
70 switch (ndim) {
71 case 1:
72 binary_op_dims<T, U, Op, 1>(
73 a_ptr,
74 b_ptr,
75 out_a_ptr,
76 out_b_ptr,
77 op,
78 shape,
79 a_strides,
80 b_strides,
81 out_strides,
82 0);
83 return;
84 case 2:
85 binary_op_dims<T, U, Op, 2>(
86 a_ptr,
87 b_ptr,
88 out_a_ptr,
89 out_b_ptr,
90 op,
91 shape,
92 a_strides,
93 b_strides,
94 out_strides,
95 0);
96 return;
97 }
98
99 ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
100 ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
101 size_t stride = out_strides[ndim - 3];
102 for (size_t elem = 0; elem < a.size(); elem += stride) {
103 binary_op_dims<T, U, Op, 2>(
104 a_ptr + a_it.loc,
105 b_ptr + b_it.loc,
106 out_a_ptr + elem,
107 out_b_ptr + elem,
108 op,
109 shape,
110 a_strides,
111 b_strides,
112 out_strides,
113 ndim - 2);
114 a_it.step();
115 b_it.step();
116 }
117}
118
119template <typename T, typename U = T, typename Op>
120void binary_op(
121 const array& a,
122 const array& b,
123 std::vector<array>& outputs,
124 Op op) {
125 auto bopt = get_binary_op_type(a, b);
126 auto& out_a = outputs[0];
127 auto& out_b = outputs[1];
128 set_binary_op_output_data(a, b, out_a, bopt);
129 set_binary_op_output_data(a, b, out_b, bopt);
130
131 // The full computation is scalar scalar so call the base op once
132 if (bopt == BinaryOpType::General) {
133 binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
134 return;
135 }
136
137 auto a_ptr = a.data<T>();
138 auto b_ptr = b.data<T>();
139 auto out_a_ptr = out_a.data<U>();
140 auto out_b_ptr = out_b.data<U>();
141 if (bopt == BinaryOpType::ScalarScalar) {
142 std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
143 } else if (bopt == BinaryOpType::ScalarVector) {
144 for (size_t i = 0; i < b.size(); ++i) {
145 std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
146 out_a_ptr++;
147 out_b_ptr++;
148 b_ptr++;
149 }
150 } else if (bopt == BinaryOpType::VectorScalar) {
151 for (size_t i = 0; i < a.size(); ++i) {
152 std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
153 out_a_ptr++;
154 out_b_ptr++;
155 a_ptr++;
156 }
157 } else { // VectorVector
158 for (size_t i = 0; i < a.size(); ++i) {
159 std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
160 out_a_ptr++;
161 out_b_ptr++;
162 a_ptr++;
163 b_ptr++;
164 }
165 }
166}
167
168template <typename Op>
169void binary(
170 const array& a,
171 const array& b,
172 std::vector<array>& outputs,
173 Op op) {
174 switch (outputs[0].dtype()) {
175 case bool_:
176 binary_op<bool>(a, b, outputs, op);
177 break;
178 case uint8:
179 binary_op<uint8_t>(a, b, outputs, op);
180 break;
181 case uint16:
182 binary_op<uint16_t>(a, b, outputs, op);
183 break;
184 case uint32:
185 binary_op<uint32_t>(a, b, outputs, op);
186 break;
187 case uint64:
188 binary_op<uint64_t>(a, b, outputs, op);
189 break;
190 case int8:
191 binary_op<int8_t>(a, b, outputs, op);
192 break;
193 case int16:
194 binary_op<int16_t>(a, b, outputs, op);
195 break;
196 case int32:
197 binary_op<int32_t>(a, b, outputs, op);
198 break;
199 case int64:
200 binary_op<int64_t>(a, b, outputs, op);
201 break;
202 case float16:
203 binary_op<float16_t>(a, b, outputs, op);
204 break;
205 case float32:
206 binary_op<float>(a, b, outputs, op);
207 break;
208 case bfloat16:
209 binary_op<bfloat16_t>(a, b, outputs, op);
210 break;
211 case complex64:
212 binary_op<complex64_t>(a, b, outputs, op);
213 break;
214 }
215}
216
217} // namespace
218
219} // namespace mlx::core
Op op
Definition binary.h:129
const char * binary()
Definition allocator.h:7
constexpr Dtype bool_
Definition dtype.h:67
constexpr Dtype uint64
Definition dtype.h:72
constexpr Dtype uint16
Definition dtype.h:70
std::tuple< std::vector< int >, std::vector< std::vector< int64_t > > > collapse_contiguous_dims(const std::vector< int > &shape, const std::vector< std::vector< int64_t > > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
constexpr Dtype bfloat16
Definition dtype.h:81
constexpr Dtype int32
Definition dtype.h:76
constexpr Dtype float32
Definition dtype.h:80
constexpr Dtype int16
Definition dtype.h:75
constexpr Dtype int8
Definition dtype.h:74
constexpr Dtype int64
Definition dtype.h:77
constexpr Dtype uint8
Definition dtype.h:69
constexpr Dtype float16
Definition dtype.h:79
constexpr Dtype uint32
Definition dtype.h:71
constexpr Dtype complex64
Definition dtype.h:82