MLX
 
Loading...
Searching...
No Matches
ternary.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4#include "mlx/array.h"
8
9namespace mlx::core {
10
11template <typename T1, typename T2, typename T3, typename U, typename Op, int D>
13 const T1* a,
14 const T2* b,
15 const T3* c,
16 U* out,
17 Op op,
18 const Shape& shape,
19 const Strides& a_strides,
20 const Strides& b_strides,
21 const Strides& c_strides,
22 const Strides& out_strides,
23 int axis) {
24 auto stride_a = a_strides[axis];
25 auto stride_b = b_strides[axis];
26 auto stride_c = c_strides[axis];
27 auto stride_out = out_strides[axis];
28 auto N = shape[axis];
29
30 for (int i = 0; i < N; i++) {
31 if constexpr (D > 1) {
32 ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
33 a,
34 b,
35 c,
36 out,
37 op,
38 shape,
39 a_strides,
40 b_strides,
41 c_strides,
42 out_strides,
43 axis + 1);
44 } else {
45 *out = op(*a, *b, *c);
46 }
47 a += stride_a;
48 b += stride_b;
49 c += stride_c;
50 out += stride_out;
51 }
52}
53
54template <typename T1, typename T2, typename T3, typename U, typename Op>
56 const T1* a_ptr,
57 const T2* b_ptr,
58 const T3* c_ptr,
59 U* out_ptr,
60 Op op,
61 size_t size,
62 Shape& shape,
63 std::vector<Strides>& strides) {
64 const auto& a_strides = strides[0];
65 const auto& b_strides = strides[1];
66 const auto& c_strides = strides[2];
67 const auto& out_strides = strides[3];
68 int ndim = shape.size();
69 switch (ndim) {
70 case 1:
72 a_ptr,
73 b_ptr,
74 c_ptr,
75 out_ptr,
76 op,
77 shape,
78 a_strides,
79 b_strides,
80 c_strides,
81 out_strides,
82 0);
83 return;
84 case 2:
86 a_ptr,
87 b_ptr,
88 c_ptr,
89 out_ptr,
90 op,
91 shape,
92 a_strides,
93 b_strides,
94 c_strides,
95 out_strides,
96 0);
97 return;
98 }
99
100 ContiguousIterator a_it(shape, a_strides, ndim - 2);
101 ContiguousIterator b_it(shape, b_strides, ndim - 2);
102 ContiguousIterator c_it(shape, c_strides, ndim - 2);
103 auto stride = out_strides[ndim - 3];
104 for (size_t elem = 0; elem < size; elem += stride) {
106 a_ptr + a_it.loc,
107 b_ptr + b_it.loc,
108 c_ptr + c_it.loc,
109 out_ptr + elem,
110 op,
111 shape,
112 a_strides,
113 b_strides,
114 c_strides,
115 out_strides,
116 ndim - 2);
117 a_it.step();
118 b_it.step();
119 c_it.step();
120 }
121}
122
123template <typename T1, typename T2, typename T3, typename U, typename Op>
125 const array& a,
126 const array& b,
127 const array& c,
128 array& out,
129 Op op,
130 TernaryOpType topt) {
131 const T1* a_ptr = a.data<T1>();
132 const T2* b_ptr = b.data<T2>();
133 const T3* c_ptr = c.data<T3>();
134 U* out_ptr = out.data<U>();
135
137 *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
138 } else if (topt == TernaryOpType::VectorVectorVector) {
139 for (size_t i = 0; i < out.size(); ++i) {
140 *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
141 a_ptr++;
142 b_ptr++;
143 c_ptr++;
144 out_ptr++;
145 }
146 } else {
147 auto [shape, strides] = collapse_contiguous_dims(
148 a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
150 a_ptr, b_ptr, c_ptr, out_ptr, op, out.size(), shape, strides);
151 }
152}
153
154} // namespace mlx::core
Definition array.h:24
const Shape & shape() const
The shape of the array as a vector of integers.
Definition array.h:103
size_t size() const
The number of elements in the array.
Definition array.h:88
T * data()
Definition array.h:349
Definition allocator.h:7
void ternary_op(const array &a, const array &b, const array &c, array &out, Op op, TernaryOpType topt)
Definition ternary.h:124
std::tuple< Shape, std::vector< Strides > > collapse_contiguous_dims(const Shape &shape, const std::vector< Strides > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
std::vector< ShapeElem > Shape
Definition array.h:21
std::vector< int64_t > Strides
Definition array.h:22
void ternary_op_dims(const T1 *a, const T2 *b, const T3 *c, U *out, Op op, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &c_strides, const Strides &out_strides, int axis)
Definition ternary.h:12
void ternary_op_dispatch_dims(const T1 *a_ptr, const T2 *b_ptr, const T3 *c_ptr, U *out_ptr, Op op, size_t size, Shape &shape, std::vector< Strides > &strides)
Definition ternary.h:55
TernaryOpType
Definition ternary.h:11
@ ScalarScalarScalar
Definition ternary.h:12
@ VectorVectorVector
Definition ternary.h:13
Definition utils.h:73
int64_t loc
Definition utils.h:126
void step()
Definition utils.h:74