ternary.h
// Copyright © 2023 Apple Inc.

#pragma once
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

namespace {

// TODO: Add support for more combinations of input types.
enum class TernaryOpType {
  ScalarScalarScalar,
  VectorVectorVector,
  General,
};

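// Classify a ternary op by its input layouts: all scalars, all three inputs
// row-contiguous (or all three column-contiguous), or the general strided case.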
TernaryOpType
get_ternary_op_type(const array& a, const array& b, const array& c) {
  TernaryOpType topt;
  if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
    topt = TernaryOpType::ScalarScalarScalar;
  } else if (
      (a.flags().row_contiguous && b.flags().row_contiguous &&
       c.flags().row_contiguous) ||
      (a.flags().col_contiguous && b.flags().col_contiguous &&
       c.flags().col_contiguous)) {
    topt = TernaryOpType::VectorVectorVector;
  } else {
    topt = TernaryOpType::General;
  }
  return topt;
}

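// Allocate or donate the output buffer. In the scalar and vector cases the
// output mirrors b's strides and flags; a donatable input is reused instead
// of allocating. The general case gets a dense allocation of out.nbytes().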
void set_ternary_op_output_data(
    const array& a,
    const array& b,
    const array& c,
    array& out,
    TernaryOpType topt,
    bool donate_with_move = false) {
  auto maybe_donate = [&out, donate_with_move](const array& x) {
    if (is_donatable(x, out)) {
      if (donate_with_move) {
        out.move_shared_buffer(x);
      } else {
        out.copy_shared_buffer(x);
      }
      return true;
    }
    return false;
  };

  switch (topt) {
    case TernaryOpType::ScalarScalarScalar:
      out.set_data(
          allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
      break;
    case TernaryOpType::VectorVectorVector:
      if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
        out.set_data(
            allocator::malloc_or_wait(out.itemsize() * b.data_size()),
            b.data_size(),
            b.strides(),
            b.flags());
      }
      break;
    case TernaryOpType::General:
      out.set_data(allocator::malloc_or_wait(out.nbytes()));
      break;
  }
}
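
// Apply `op` element-wise over the innermost D dimensions, recursing one
// axis at a time and advancing each pointer by its stride along `axis`.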
template <typename T1, typename T2, typename T3, typename U, typename Op, int D>
void ternary_op_dims(
    const T1* a,
    const T2* b,
    const T3* c,
    U* out,
    Op op,
    const std::vector<int>& shape,
    const std::vector<size_t>& a_strides,
    const std::vector<size_t>& b_strides,
    const std::vector<size_t>& c_strides,
    const std::vector<size_t>& out_strides,
    int axis) {
  auto stride_a = a_strides[axis];
  auto stride_b = b_strides[axis];
  auto stride_c = c_strides[axis];
  auto stride_out = out_strides[axis];
  auto N = shape[axis];

  for (int i = 0; i < N; i++) {
    if constexpr (D > 1) {
      ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
          a,
          b,
          c,
          out,
          op,
          shape,
          a_strides,
          b_strides,
          c_strides,
          out_strides,
          axis + 1);
    } else {
      *out = op(*a, *b, *c);
    }
    a += stride_a;
    b += stride_b;
    c += stride_c;
    out += stride_out;
  }
}

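// Collapse contiguous dimensions, then dispatch on the remaining rank:
// ranks 1 and 2 go straight to ternary_op_dims; higher ranks iterate over
// the leading dimensions and process the last two dimensions per block.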
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op_dispatch_dims(
    const array& a,
    const array& b,
    const array& c,
    array& out,
    Op op) {
  auto [shape, strides] = collapse_contiguous_dims(
      a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
  const auto& a_strides = strides[0];
  const auto& b_strides = strides[1];
  const auto& c_strides = strides[2];
  const auto& out_strides = strides[3];

  const T1* a_ptr = a.data<T1>();
  const T2* b_ptr = b.data<T2>();
  const T3* c_ptr = c.data<T3>();
  U* out_ptr = out.data<U>();
  int ndim = shape.size();
  switch (ndim) {
    case 1:
      ternary_op_dims<T1, T2, T3, U, Op, 1>(
          a_ptr,
          b_ptr,
          c_ptr,
          out_ptr,
          op,
          shape,
          a_strides,
          b_strides,
          c_strides,
          out_strides,
          0);
      return;
    case 2:
      ternary_op_dims<T1, T2, T3, U, Op, 2>(
          a_ptr,
          b_ptr,
          c_ptr,
          out_ptr,
          op,
          shape,
          a_strides,
          b_strides,
          c_strides,
          out_strides,
          0);
      return;
  }

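  // ndim > 2: step the input iterators over the leading dimensions, handling
  // each (last-two-dims) block with ternary_op_dims. The output is dense, so
  // it advances by a fixed stride per block.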
  ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
  ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
  ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2);
  size_t stride = out_strides[ndim - 3];
  for (size_t elem = 0; elem < a.size(); elem += stride) {
    ternary_op_dims<T1, T2, T3, U, Op, 2>(
        a_ptr + a_it.loc,
        b_ptr + b_it.loc,
        c_ptr + c_it.loc,
        out_ptr + elem,
        op,
        shape,
        a_strides,
        b_strides,
        c_strides,
        out_strides,
        ndim - 2);
    a_it.step();
    b_it.step();
    c_it.step();
  }
}

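// Entry point: classify the op, set up the output, then run the scalar
// fast path, the contiguous element-wise loop, or the general strided path.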
template <typename T1, typename T2, typename T3, typename U, typename Op>
void ternary_op(
    const array& a,
    const array& b,
    const array& c,
    array& out,
    Op op) {
  TernaryOpType topt = get_ternary_op_type(a, b, c);
  set_ternary_op_output_data(a, b, c, out, topt);

  // The full computation is scalar-scalar-scalar so we call the base op once.
  if (topt == TernaryOpType::ScalarScalarScalar) {
    *(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
  } else if (topt == TernaryOpType::VectorVectorVector) {
    const T1* a_ptr = a.data<T1>();
    const T2* b_ptr = b.data<T2>();
    const T3* c_ptr = c.data<T3>();
    U* out_ptr = out.data<U>();
    for (size_t i = 0; i < out.size(); ++i) {
      *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
      a_ptr++;
      b_ptr++;
      c_ptr++;
      out_ptr++;
    }
  } else {
    ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
  }
}

} // namespace

} // namespace mlx::core
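
// Example (illustrative sketch, not part of the original header): within a
// translation unit that includes this file, a select/where-style kernel can
// be built on ternary_op, assuming `cond`, `x`, `y`, and `out` are arrays
// with matching shapes and `out` has a float dtype:
//
//   auto select = [](bool p, float a, float b) { return p ? a : b; };
//   ternary_op<bool, float, float, float>(cond, x, y, out, select);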