MLX
 
Loading...
Searching...
No Matches
ternary.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4#include "mlx/allocator.h"
5#include "mlx/array.h"
8namespace mlx::core {
9
10namespace {
11
// TODO: Add support for more combinations of input types.
// Kernel-selection category for a ternary elementwise op, chosen from the
// layouts of the three inputs (see get_ternary_op_type below).
enum class TernaryOpType {
  ScalarScalarScalar, // all three inputs hold a single element
  VectorVectorVector, // all three share one contiguity class (row or col)
  General, // mixed layouts; requires strided traversal
};
18
19TernaryOpType
20get_ternary_op_type(const array& a, const array& b, const array& c) {
21 TernaryOpType topt;
22 if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
23 topt = TernaryOpType::ScalarScalarScalar;
24 } else if (
25 (a.flags().row_contiguous && b.flags().row_contiguous &&
26 c.flags().row_contiguous) ||
27 (a.flags().col_contiguous && b.flags().col_contiguous &&
28 c.flags().col_contiguous)) {
29 topt = TernaryOpType::VectorVectorVector;
30 } else {
31 topt = TernaryOpType::General;
32 }
33 return topt;
34}
35
36void set_ternary_op_output_data(
37 const array& a,
38 const array& b,
39 const array& c,
40 array& out,
41 TernaryOpType topt,
42 bool donate_with_move = false) {
43 auto maybe_donate = [&out, donate_with_move](const array& x) {
44 if (is_donatable(x, out)) {
45 if (donate_with_move) {
46 out.move_shared_buffer(x);
47 } else {
48 out.copy_shared_buffer(x);
49 }
50 return true;
51 }
52 return false;
53 };
54
55 switch (topt) {
56 case TernaryOpType::ScalarScalarScalar:
57 out.set_data(
58 allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
59 break;
60 case TernaryOpType::VectorVectorVector:
61 if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
62 out.set_data(
63 allocator::malloc_or_wait(out.itemsize() * b.data_size()),
64 b.data_size(),
65 b.strides(),
66 b.flags());
67 }
68 break;
69 case TernaryOpType::General:
70 // Try to donate an input which is row_contiguous
71 if (!((a.flags().row_contiguous && maybe_donate(a)) ||
72 (b.flags().row_contiguous && maybe_donate(b)) ||
73 (c.flags().row_contiguous && maybe_donate(c)))) {
74 out.set_data(allocator::malloc_or_wait(out.nbytes()));
75 }
76 break;
77 }
78}
79template <typename T1, typename T2, typename T3, typename U, typename Op, int D>
80void ternary_op_dims(
81 const T1* a,
82 const T2* b,
83 const T3* c,
84 U* out,
85 Op op,
86 const Shape& shape,
87 const Strides& a_strides,
88 const Strides& b_strides,
89 const Strides& c_strides,
90 const Strides& out_strides,
91 int axis) {
92 auto stride_a = a_strides[axis];
93 auto stride_b = b_strides[axis];
94 auto stride_c = c_strides[axis];
95 auto stride_out = out_strides[axis];
96 auto N = shape[axis];
97
98 for (int i = 0; i < N; i++) {
99 if constexpr (D > 1) {
100 ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
101 a,
102 b,
103 c,
104 out,
105 op,
106 shape,
107 a_strides,
108 b_strides,
109 c_strides,
110 out_strides,
111 axis + 1);
112 } else {
113 *out = op(*a, *b, *c);
114 }
115 a += stride_a;
116 b += stride_b;
117 c += stride_c;
118 out += stride_out;
119 }
120}
121
122template <typename T1, typename T2, typename T3, typename U, typename Op>
123void ternary_op_dispatch_dims(
124 const array& a,
125 const array& b,
126 const array& c,
127 array& out,
128 Op op) {
129 auto [shape, strides] = collapse_contiguous_dims(
130 a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
131 const auto& a_strides = strides[0];
132 const auto& b_strides = strides[1];
133 const auto& c_strides = strides[2];
134 const auto& out_strides = strides[3];
135
136 const T1* a_ptr = a.data<T1>();
137 const T2* b_ptr = b.data<T2>();
138 const T3* c_ptr = c.data<T3>();
139 U* out_ptr = out.data<T3>();
140 int ndim = shape.size();
141 switch (ndim) {
142 case 1:
143 ternary_op_dims<T1, T2, T3, U, Op, 1>(
144 a_ptr,
145 b_ptr,
146 c_ptr,
147 out_ptr,
148 op,
149 shape,
150 a_strides,
151 b_strides,
152 c_strides,
153 out_strides,
154 0);
155 return;
156 case 2:
157 ternary_op_dims<T1, T2, T3, U, Op, 2>(
158 a_ptr,
159 b_ptr,
160 c_ptr,
161 out_ptr,
162 op,
163 shape,
164 a_strides,
165 b_strides,
166 c_strides,
167 out_strides,
168 0);
169 return;
170 }
171
172 ContiguousIterator a_it(shape, a_strides, ndim - 2);
173 ContiguousIterator b_it(shape, b_strides, ndim - 2);
174 ContiguousIterator c_it(shape, c_strides, ndim - 2);
175 auto stride = out_strides[ndim - 3];
176 for (size_t elem = 0; elem < a.size(); elem += stride) {
177 ternary_op_dims<T1, T2, T3, U, Op, 2>(
178 a_ptr + a_it.loc,
179 b_ptr + b_it.loc,
180 c_ptr + c_it.loc,
181 out_ptr + elem,
182 op,
183 shape,
184 a_strides,
185 b_strides,
186 c_strides,
187 out_strides,
188 ndim - 2);
189 a_it.step();
190 b_it.step();
191 c_it.step();
192 }
193}
194
195template <typename T1, typename T2, typename T3, typename U, typename Op>
196void ternary_op(
197 const array& a,
198 const array& b,
199 const array& c,
200 array& out,
201 Op op) {
202 TernaryOpType topt = get_ternary_op_type(a, b, c);
203 set_ternary_op_output_data(a, b, c, out, topt);
204
205 // The full computation is scalar-scalar-scalar so we call the base op once.
206 if (topt == TernaryOpType::ScalarScalarScalar) {
207 *(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
208 } else if (topt == TernaryOpType::VectorVectorVector) {
209 const T1* a_ptr = a.data<T1>();
210 const T2* b_ptr = b.data<T2>();
211 const T3* c_ptr = c.data<T3>();
212 U* out_ptr = out.data<U>();
213 for (size_t i = 0; i < out.size(); ++i) {
214 *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
215 a_ptr++;
216 b_ptr++;
217 c_ptr++;
218 out_ptr++;
219 }
220 } else {
221 ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
222 }
223}
224
225} // namespace
226
227} // namespace mlx::core
Definition array.h:24
Buffer malloc_or_wait(size_t size)
Definition allocator.h:7
std::tuple< Shape, std::vector< Strides > > collapse_contiguous_dims(const Shape &shape, const std::vector< Strides > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
std::vector< ShapeElem > Shape
Definition array.h:21
std::vector< int64_t > Strides
Definition array.h:22
bool is_donatable(const array &in, const array &out)
Definition utils.h:155
Definition utils.h:73