MLX
Loading...
Searching...
No Matches
ternary.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4#include "mlx/allocator.h"
5#include "mlx/array.h"
8namespace mlx::core {
9
10namespace {
11
// Classification of a ternary operation by the layout of its three inputs.
// It selects both the output allocation strategy and the evaluation path.
//
// TODO: Add support for more combinations of input types.
enum class TernaryOpType {
  // Every input holds exactly one data element.
  ScalarScalarScalar,
  // Any other combination of inputs.
  General,
};
17
18TernaryOpType
19get_ternary_op_type(const array& a, const array& b, const array& c) {
20 TernaryOpType topt;
21 if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
22 topt = TernaryOpType::ScalarScalarScalar;
23 } else {
24 topt = TernaryOpType::General;
25 }
26 return topt;
27}
28
29void set_ternary_op_output_data(
30 const array& a,
31 const array& b,
32 const array& c,
33 array& out,
34 TernaryOpType topt,
35 bool donate_with_move = false) {
36 switch (topt) {
37 case TernaryOpType::ScalarScalarScalar:
38 out.set_data(
39 allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
40 break;
41 case TernaryOpType::General:
42 out.set_data(allocator::malloc_or_wait(out.nbytes()));
43 break;
44 }
45}
46
47template <typename T1, typename T2, typename T3, typename U, typename Op>
48void ternary_op_dims1(
49 const array& a,
50 const array& b,
51 const array& c,
52 array& out,
53 Op op) {
54 const T1* a_ptr = a.data<T1>();
55 const T2* b_ptr = b.data<T2>();
56 const T3* c_ptr = c.data<T3>();
57
58 U* dst = out.data<U>();
59 size_t a_idx = 0;
60 size_t b_idx = 0;
61 size_t c_idx = 0;
62 for (size_t i = 0; i < out.size(); ++i) {
63 dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
64 a_idx += a.strides()[0];
65 b_idx += b.strides()[0];
66 c_idx += c.strides()[0];
67 }
68}
69
70template <typename T1, typename T2, typename T3, typename U, typename Op>
71void ternary_op_dims2(
72 const array& a,
73 const array& b,
74 const array& c,
75 array& out,
76 Op op) {
77 const T1* a_ptr = a.data<T1>();
78 const T2* b_ptr = b.data<T2>();
79 const T3* c_ptr = c.data<T3>();
80
81 U* dst = out.data<U>();
82 size_t a_idx = 0;
83 size_t b_idx = 0;
84 size_t c_idx = 0;
85 size_t out_idx = 0;
86 for (size_t i = 0; i < a.shape()[0]; ++i) {
87 for (size_t j = 0; j < a.shape()[1]; ++j) {
88 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
89 a_idx += a.strides()[1];
90 b_idx += b.strides()[1];
91 c_idx += c.strides()[1];
92 }
93 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
94 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
95 c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
96 }
97}
98
99template <typename T1, typename T2, typename T3, typename U, typename Op>
100void ternary_op_dims3(
101 const array& a,
102 const array& b,
103 const array& c,
104 array& out,
105 Op op) {
106 const T1* a_ptr = a.data<T1>();
107 const T2* b_ptr = b.data<T2>();
108 const T3* c_ptr = c.data<T3>();
109 U* dst = out.data<U>();
110 size_t a_idx = 0;
111 size_t b_idx = 0;
112 size_t c_idx = 0;
113 size_t out_idx = 0;
114 for (size_t i = 0; i < a.shape()[0]; ++i) {
115 for (size_t j = 0; j < a.shape()[1]; ++j) {
116 for (size_t k = 0; k < a.shape()[2]; ++k) {
117 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
118 a_idx += a.strides()[2];
119 b_idx += b.strides()[2];
120 c_idx += c.strides()[2];
121 }
122 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
123 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
124 c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
125 }
126 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
127 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
128 c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
129 }
130}
131
132template <typename T1, typename T2, typename T3, typename U, typename Op>
133void ternary_op_dims4(
134 const array& a,
135 const array& b,
136 const array& c,
137 array& out,
138 Op op) {
139 const T1* a_ptr = a.data<T1>();
140 const T2* b_ptr = b.data<T2>();
141 const T3* c_ptr = c.data<T3>();
142
143 U* dst = out.data<U>();
144 size_t a_idx = 0;
145 size_t b_idx = 0;
146 size_t c_idx = 0;
147 size_t out_idx = 0;
148 for (size_t i = 0; i < a.shape()[0]; ++i) {
149 for (size_t j = 0; j < a.shape()[1]; ++j) {
150 for (size_t k = 0; k < a.shape()[2]; ++k) {
151 for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
152 dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
153 a_idx += a.strides()[3];
154 b_idx += b.strides()[3];
155 c_idx += c.strides()[3];
156 }
157 a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
158 b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
159 c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
160 }
161 a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
162 b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
163 c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
164 }
165 a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
166 b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
167 c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
168 }
169}
170
171template <typename T1, typename T2, typename T3, typename U, typename Op>
172void ternary_op_dispatch_dims(
173 const array& a,
174 const array& b,
175 const array& c,
176 array& out,
177 Op op) {
178 switch (out.ndim()) {
179 case 1:
180 ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out, op);
181 return;
182 case 2:
183 ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out, op);
184 return;
185 case 3:
186 ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out, op);
187 return;
188 case 4:
189 ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out, op);
190 return;
191 }
192
193 const T1* a_ptr = a.data<T1>();
194 const T2* b_ptr = b.data<T2>();
195 const T3* c_ptr = c.data<T3>();
196 U* dst = out.data<U>();
197 for (size_t i = 0; i < out.size(); i++) {
198 int a_idx = elem_to_loc(i, a.shape(), a.strides());
199 int b_idx = elem_to_loc(i, b.shape(), b.strides());
200 int c_idx = elem_to_loc(i, c.shape(), c.strides());
201 dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
202 }
203}
204
205template <typename T1, typename T2, typename T3, typename U, typename Op>
206void ternary_op(
207 const array& a,
208 const array& b,
209 const array& c,
210 array& out,
211 Op op) {
212 TernaryOpType topt = get_ternary_op_type(a, b, c);
213 set_ternary_op_output_data(a, b, c, out, topt);
214
215 // The full computation is scalar-scalar-scalar so we call the base op once.
216 if (topt == TernaryOpType::ScalarScalarScalar) {
217 *(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
218 return;
219 }
220
221 ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
222}
223
224} // namespace
225
226} // namespace mlx::core
Op op
Definition binary.h:141
Buffer malloc_or_wait(size_t size)
Definition allocator.h:7
stride_t elem_to_loc(int elem, const std::vector< int > &shape, const std::vector< stride_t > &strides)
Definition utils.h:12