MLX
Loading...
Searching...
No Matches
ternary.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4#include "mlx/allocator.h"
5#include "mlx/array.h"
8namespace mlx::core {
9
10namespace {
11
// TODO: Add support for more combinations of input types.
/// Layout category of the three inputs of a ternary op. It selects how the
/// output buffer is provisioned and how the op is evaluated.
enum class TernaryOpType {
  /// Every input stores exactly one element (data_size() == 1).
  ScalarScalarScalar = 0,
  /// All inputs are row-contiguous, or all inputs are col-contiguous.
  VectorVectorVector = 1,
  /// Any other combination; requires strided indexing.
  General = 2,
};
18
19TernaryOpType
20get_ternary_op_type(const array& a, const array& b, const array& c) {
21 TernaryOpType topt;
22 if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
23 topt = TernaryOpType::ScalarScalarScalar;
24 } else if (
25 (a.flags().row_contiguous && b.flags().row_contiguous &&
26 c.flags().row_contiguous) ||
27 (a.flags().col_contiguous && b.flags().col_contiguous &&
28 c.flags().col_contiguous)) {
29 topt = TernaryOpType::VectorVectorVector;
30 } else {
31 topt = TernaryOpType::General;
32 }
33 return topt;
34}
35
36void set_ternary_op_output_data(
37 const array& a,
38 const array& b,
39 const array& c,
40 array& out,
41 TernaryOpType topt,
42 bool donate_with_move = false) {
43 auto maybe_donate = [&out, donate_with_move](const array& x) {
44 if (is_donatable(x, out)) {
45 if (donate_with_move) {
46 out.move_shared_buffer(x);
47 } else {
48 out.copy_shared_buffer(x);
49 }
50 return true;
51 }
52 return false;
53 };
54
55 switch (topt) {
56 case TernaryOpType::ScalarScalarScalar:
57 out.set_data(
58 allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
59 break;
60 case TernaryOpType::VectorVectorVector:
61 if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
62 out.set_data(
63 allocator::malloc_or_wait(out.itemsize() * b.data_size()),
64 b.data_size(),
65 b.strides(),
66 b.flags());
67 }
68 break;
69 case TernaryOpType::General:
70 out.set_data(allocator::malloc_or_wait(out.nbytes()));
71 break;
72 }
73}
74
75template <typename T1, typename T2, typename T3, typename U, typename Op>
76void ternary_op_dims1(
77 const array& a,
78 const array& b,
79 const array& c,
80 array& out,
81 Op op) {
82 const T1* a_ptr = a.data<T1>();
83 const T2* b_ptr = b.data<T2>();
84 const T3* c_ptr = c.data<T3>();
85
86 U* dst = out.data<U>();
87 size_t a_idx = 0;
88 size_t b_idx = 0;
89 size_t c_idx = 0;
90 for (size_t i = 0; i < out.size(); ++i) {
91 dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
92 a_idx += a.strides()[0];
93 b_idx += b.strides()[0];
94 c_idx += c.strides()[0];
95 }
96}
97
98template <typename T1, typename T2, typename T3, typename U, typename Op>
99void ternary_op_dims2(
100 const array& a,
101 const array& b,
102 const array& c,
103 array& out,
104 Op op) {
105 const T1* a_ptr = a.data<T1>();
106 const T2* b_ptr = b.data<T2>();
107 const T3* c_ptr = c.data<T3>();
108
109 U* dst = out.data<U>();
110 size_t a_idx = 0;
111 size_t b_idx = 0;
112 size_t c_idx = 0;
113 size_t out_idx = 0;
114 for (size_t i = 0; i < a.shape()[0]; ++i) {
115 for (size_t j = 0; j < a.shape()[1]; ++j) {
116 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
117 a_idx += a.strides()[1];
118 b_idx += b.strides()[1];
119 c_idx += c.strides()[1];
120 }
121 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
122 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
123 c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
124 }
125}
126
127template <typename T1, typename T2, typename T3, typename U, typename Op>
128void ternary_op_dims3(
129 const array& a,
130 const array& b,
131 const array& c,
132 array& out,
133 Op op) {
134 const T1* a_ptr = a.data<T1>();
135 const T2* b_ptr = b.data<T2>();
136 const T3* c_ptr = c.data<T3>();
137 U* dst = out.data<U>();
138 size_t a_idx = 0;
139 size_t b_idx = 0;
140 size_t c_idx = 0;
141 size_t out_idx = 0;
142 for (size_t i = 0; i < a.shape()[0]; ++i) {
143 for (size_t j = 0; j < a.shape()[1]; ++j) {
144 for (size_t k = 0; k < a.shape()[2]; ++k) {
145 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
146 a_idx += a.strides()[2];
147 b_idx += b.strides()[2];
148 c_idx += c.strides()[2];
149 }
150 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
151 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
152 c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
153 }
154 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
155 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
156 c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
157 }
158}
159
160template <typename T1, typename T2, typename T3, typename U, typename Op>
161void ternary_op_dims4(
162 const array& a,
163 const array& b,
164 const array& c,
165 array& out,
166 Op op) {
167 const T1* a_ptr = a.data<T1>();
168 const T2* b_ptr = b.data<T2>();
169 const T3* c_ptr = c.data<T3>();
170
171 U* dst = out.data<U>();
172 size_t a_idx = 0;
173 size_t b_idx = 0;
174 size_t c_idx = 0;
175 size_t out_idx = 0;
176 for (size_t i = 0; i < a.shape()[0]; ++i) {
177 for (size_t j = 0; j < a.shape()[1]; ++j) {
178 for (size_t k = 0; k < a.shape()[2]; ++k) {
179 for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
180 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
181 a_idx += a.strides()[3];
182 b_idx += b.strides()[3];
183 c_idx += c.strides()[3];
184 }
185 a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
186 b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
187 c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
188 }
189 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
190 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
191 c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
192 }
193 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
194 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
195 c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
196 }
197}
198
199template <typename T1, typename T2, typename T3, typename U, typename Op>
200void ternary_op_dispatch_dims(
201 const array& a,
202 const array& b,
203 const array& c,
204 array& out,
205 Op op) {
206 switch (out.ndim()) {
207 case 1:
208 ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out, op);
209 return;
210 case 2:
211 ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out, op);
212 return;
213 case 3:
214 ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out, op);
215 return;
216 case 4:
217 ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out, op);
218 return;
219 }
220
221 const T1* a_ptr = a.data<T1>();
222 const T2* b_ptr = b.data<T2>();
223 const T3* c_ptr = c.data<T3>();
224 U* dst = out.data<U>();
225 for (size_t i = 0; i < out.size(); i++) {
226 int a_idx = elem_to_loc(i, a.shape(), a.strides());
227 int b_idx = elem_to_loc(i, b.shape(), b.strides());
228 int c_idx = elem_to_loc(i, c.shape(), c.strides());
229 dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
230 }
231}
232
233template <typename T1, typename T2, typename T3, typename U, typename Op>
234void ternary_op(
235 const array& a,
236 const array& b,
237 const array& c,
238 array& out,
239 Op op) {
240 TernaryOpType topt = get_ternary_op_type(a, b, c);
241 set_ternary_op_output_data(a, b, c, out, topt);
242
243 // The full computation is scalar-scalar-scalar so we call the base op once.
244 if (topt == TernaryOpType::ScalarScalarScalar) {
245 *(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
246 return;
247 }
248
249 ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
250}
251
252} // namespace
253
254} // namespace mlx::core
Op op
Definition binary.h:141
Buffer malloc_or_wait(size_t size)
Definition allocator.h:7
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12
bool is_donatable(const array &in, const array &out)
Definition utils.h:158