MLX
binary.h
// Copyright © 2023 Apple Inc.

#pragma once
#include <cassert>

#include "mlx/array.h"
// NOTE: several includes were dropped from this listing; the exact paths below
// are assumptions. They provide BinaryOpType, ContiguousIterator,
// collapse_contiguous_dims, and the simd helpers used below.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/simd/simd.h"

namespace mlx::core {

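// Elementwise kernel for a contiguous vector `a` combined with a single
// broadcast scalar `*b`. The main loop consumes simd-width chunks; the tail
// loop handles the remaining elements one at a time.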
template <typename Op>
struct VectorScalar {
  template <typename T, typename U>
  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *b;
    constexpr int N = simd::max_size<T>;
    while (size >= N) {
      simd::store(dst, Op{}(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
      dst += N;
      a += N;
      size -= N;
    }
    while (size-- > 0) {
      *dst = Op{}(*a, scalar);
      dst++;
      a++;
    }
  }
};

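// Elementwise kernel for a single broadcast scalar `*a` combined with a
// contiguous vector `b`, again with a simd main loop and a scalar tail loop.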
template <typename Op>
struct ScalarVector {
  template <typename T, typename U>
  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *a;
    constexpr int N = simd::max_size<T>;
    while (size >= N) {
      simd::store(dst, Op{}(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
      dst += N;
      b += N;
      size -= N;
    }
    while (size-- > 0) {
      *dst = Op{}(scalar, *b);
      dst++;
      b++;
    }
  }
};

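// Elementwise kernel for two contiguous vectors of equal length, with a simd
// main loop and a scalar tail loop.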
template <typename Op>
struct VectorVector {
  template <typename T, typename U>
  void operator()(const T* a, const T* b, U* dst, int size) {
    constexpr int N = simd::max_size<T>;
    while (size >= N) {
      simd::store(dst, Op{}(simd::load<T, N>(a), simd::load<T, N>(b)));
      dst += N;
      a += N;
      b += N;
      size -= N;
    }
    while (size-- > 0) {
      *dst = Op{}(*a, *b);
      dst++;
      a++;
      b++;
    }
  }
};

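// Recursively walks D dimensions of the broadcast iteration space starting at
// `axis`, advancing each pointer by its own stride. At the innermost level it
// either invokes a strided kernel (Strided == true, Op is one of the structs
// above) or applies a scalar Op to a single element.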
template <typename T, typename U, typename Op, int D, bool Strided>
void binary_op_dims(
    const T* a,
    const T* b,
    U* out,
    const Shape& shape,
    const Strides& a_strides,
    const Strides& b_strides,
    const Strides& out_strides,
    int axis) {
  auto stride_a = a_strides[axis];
  auto stride_b = b_strides[axis];
  auto stride_out = out_strides[axis];
  auto N = shape[axis];

  for (int i = 0; i < N; i++) {
    if constexpr (D > 1) {
      binary_op_dims<T, U, Op, D - 1, Strided>(
          a, b, out, shape, a_strides, b_strides, out_strides, axis + 1);
    } else {
      if constexpr (Strided) {
        Op{}(a, b, out, stride_out);
      } else {
        *out = Op{}(*a, *b);
      }
    }
    out += stride_out;
    a += stride_a;
    b += stride_b;
  }
}

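// Dispatches on the number of dimensions: ranks one to three are handled
// directly by binary_op_dims; higher ranks iterate the outer dimensions with
// ContiguousIterator and apply the three-dimensional kernel to each block.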
template <typename T, typename U, bool Strided, typename Op>
void binary_op_dispatch_dims(
    const T* a,
    const T* b,
    U* out,
    int dim,
    int size,
    const Shape& shape,
    const Strides& a_strides,
    const Strides& b_strides,
    const Strides& out_strides) {
  switch (dim) {
    case 1:
      binary_op_dims<T, U, Op, 1, Strided>(
          a, b, out, shape, a_strides, b_strides, out_strides, 0);
      return;
    case 2:
      binary_op_dims<T, U, Op, 2, Strided>(
          a, b, out, shape, a_strides, b_strides, out_strides, 0);
      return;
    case 3:
      binary_op_dims<T, U, Op, 3, Strided>(
          a, b, out, shape, a_strides, b_strides, out_strides, 0);
      return;
  }

  ContiguousIterator a_it(shape, a_strides, dim - 3);
  ContiguousIterator b_it(shape, b_strides, dim - 3);
  auto stride = out_strides[dim - 4];
  for (int64_t elem = 0; elem < size; elem += stride) {
    binary_op_dims<T, U, Op, 3, Strided>(
        a + a_it.loc,
        b + b_it.loc,
        out + elem,
        shape,
        a_strides,
        b_strides,
        out_strides,
        dim - 3);
    a_it.step();
    b_it.step();
  }
}

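// Top-level CPU binary kernel: fast paths for the fully contiguous cases
// (scalar/scalar, scalar/vector, vector/scalar, vector/vector); otherwise a
// general path that collapses contiguous dimensions and picks the widest
// row-contiguous inner kernel available.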
template <typename T, typename U, typename Op>
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
  // The full computation is scalar scalar so call the base op once
  auto a_ptr = a.data<T>();
  auto b_ptr = b.data<T>();

  auto out_ptr = out.data<U>();
  if (bopt == BinaryOpType::ScalarScalar) {
    *out_ptr = Op{}(*a_ptr, *b_ptr);
    return;
  }

  // The full computation is scalar vector so delegate to the op
  if (bopt == BinaryOpType::ScalarVector) {
    ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b.data_size());
    return;
  }

  // The full computation is vector scalar so delegate to the op
  if (bopt == BinaryOpType::VectorScalar) {
    VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a.data_size());
    return;
  }

  // The full computation is vector vector so delegate to the op
  if (bopt == BinaryOpType::VectorVector) {
    VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, a.size());
    return;
  }

  // General computation so let's try to optimize
  auto [new_shape, new_strides] = collapse_contiguous_dims(
      a.shape(), {a.strides(), b.strides(), out.strides()});
  auto& a_strides = new_strides[0];
  auto& b_strides = new_strides[1];
  auto& strides = new_strides[2];

  // Get the left-most dim such that the array is row contiguous after
  auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
    int d = arr_strides.size() - 1;
    for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
    }
    return d + 1;
  };
  auto a_rc_dim = leftmost_rc_dim(a_strides);
  auto b_rc_dim = leftmost_rc_dim(b_strides);

  // Get the left-most dim such that the array is a broadcasted "scalar" after
  auto leftmost_s_dim = [](const auto& arr_strides) {
    int d = arr_strides.size() - 1;
    for (; d >= 0 && arr_strides[d] == 0; d--) {
    }
    return d + 1;
  };
  auto a_s_dim = leftmost_s_dim(a_strides);
  auto b_s_dim = leftmost_s_dim(b_strides);

  auto ndim = new_shape.size();

  // Case 1: LxM and FxM where L and F are broadcastable and M is row
  // contiguous
  int dim = ndim;
  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::VectorVector;
    dim = d;
    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
    bopt = BinaryOpType::VectorScalar;
    dim = d;
    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::ScalarVector;
    dim = d;
  }

  // Can be sure dim > 0 since otherwise we would have used one of the fully
  // contiguous methods above. Except for the case that the flags do not
  // correspond to the underlying contiguity.
  if (dim == 0 || strides[dim - 1] < 16) {
    bopt = BinaryOpType::General;
    dim = ndim;
  }

  switch (bopt) {
    case BinaryOpType::VectorVector:
      binary_op_dispatch_dims<T, U, true, VectorVector<Op>>(
          a_ptr,
          b_ptr,
          out_ptr,
          dim,
          a.size(),
          new_shape,
          a_strides,
          b_strides,
          strides);
      break;
    case BinaryOpType::VectorScalar:
      binary_op_dispatch_dims<T, U, true, VectorScalar<Op>>(
          a_ptr,
          b_ptr,
          out_ptr,
          dim,
          a.size(),
          new_shape,
          a_strides,
          b_strides,
          strides);
      break;
    case BinaryOpType::ScalarVector:
      binary_op_dispatch_dims<T, U, true, ScalarVector<Op>>(
          a_ptr,
          b_ptr,
          out_ptr,
          dim,
          a.size(),
          new_shape,
          a_strides,
          b_strides,
          strides);
      break;
    default:
      binary_op_dispatch_dims<T, U, false, Op>(
          a_ptr,
          b_ptr,
          out_ptr,
          dim,
          a.size(),
          new_shape,
          a_strides,
          b_strides,
          strides);
      break;
  }
}

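// Convenience overload for the common case where the input and output types
// are the same.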
template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
  binary_op<T, T, Op>(a, b, out, bopt);
}

} // namespace mlx::core
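A minimal usage sketch (not part of the header): the functor `AddExample` and the setup of `a`, `b`, `out`, and `bopt` are illustrative assumptions; the ops actually used by MLX are defined elsewhere in the CPU backend.

// Hypothetical op functor; templated so it also applies to simd::Simd values.
struct AddExample {
  template <typename T>
  T operator()(T x, T y) {
    return x + y;
  }
};

// Assuming a, b, out are mlx::core::array objects with out already allocated
// and bopt derived from their contiguity:
//   mlx::core::binary_op<float, AddExample>(a, b, out, bopt);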