MLX
 
Loading...
Searching...
No Matches
ternary.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4#include "mlx/allocator.h"
5#include "mlx/array.h"
8
9namespace mlx::core {
10
11template <typename T1, typename T2, typename T3, typename U, typename Op, int D>
13 const T1* a,
14 const T2* b,
15 const T3* c,
16 U* out,
17 Op op,
18 const Shape& shape,
19 const Strides& a_strides,
20 const Strides& b_strides,
21 const Strides& c_strides,
22 const Strides& out_strides,
23 int axis) {
24 auto stride_a = a_strides[axis];
25 auto stride_b = b_strides[axis];
26 auto stride_c = c_strides[axis];
27 auto stride_out = out_strides[axis];
28 auto N = shape[axis];
29
30 for (int i = 0; i < N; i++) {
31 if constexpr (D > 1) {
32 ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
33 a,
34 b,
35 c,
36 out,
37 op,
38 shape,
39 a_strides,
40 b_strides,
41 c_strides,
42 out_strides,
43 axis + 1);
44 } else {
45 *out = op(*a, *b, *c);
46 }
47 a += stride_a;
48 b += stride_b;
49 c += stride_c;
50 out += stride_out;
51 }
52}
53
54template <typename T1, typename T2, typename T3, typename U, typename Op>
56 const array& a,
57 const array& b,
58 const array& c,
59 array& out,
60 Op op) {
61 auto [shape, strides] = collapse_contiguous_dims(
62 a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
63 const auto& a_strides = strides[0];
64 const auto& b_strides = strides[1];
65 const auto& c_strides = strides[2];
66 const auto& out_strides = strides[3];
67
68 const T1* a_ptr = a.data<T1>();
69 const T2* b_ptr = b.data<T2>();
70 const T3* c_ptr = c.data<T3>();
71 U* out_ptr = out.data<T3>();
72 int ndim = shape.size();
73 switch (ndim) {
74 case 1:
76 a_ptr,
77 b_ptr,
78 c_ptr,
79 out_ptr,
80 op,
81 shape,
82 a_strides,
83 b_strides,
84 c_strides,
85 out_strides,
86 0);
87 return;
88 case 2:
90 a_ptr,
91 b_ptr,
92 c_ptr,
93 out_ptr,
94 op,
95 shape,
96 a_strides,
97 b_strides,
98 c_strides,
99 out_strides,
100 0);
101 return;
102 }
103
104 ContiguousIterator a_it(shape, a_strides, ndim - 2);
105 ContiguousIterator b_it(shape, b_strides, ndim - 2);
106 ContiguousIterator c_it(shape, c_strides, ndim - 2);
107 auto stride = out_strides[ndim - 3];
108 for (size_t elem = 0; elem < a.size(); elem += stride) {
110 a_ptr + a_it.loc,
111 b_ptr + b_it.loc,
112 c_ptr + c_it.loc,
113 out_ptr + elem,
114 op,
115 shape,
116 a_strides,
117 b_strides,
118 c_strides,
119 out_strides,
120 ndim - 2);
121 a_it.step();
122 b_it.step();
123 c_it.step();
124 }
125}
126
127template <typename T1, typename T2, typename T3, typename U, typename Op>
129 const array& a,
130 const array& b,
131 const array& c,
132 array& out,
133 Op op) {
134 TernaryOpType topt = get_ternary_op_type(a, b, c);
135 set_ternary_op_output_data(a, b, c, out, topt);
136
137 // The full computation is scalar-scalar-scalar so we call the base op once.
139 *(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
140 } else if (topt == TernaryOpType::VectorVectorVector) {
141 const T1* a_ptr = a.data<T1>();
142 const T2* b_ptr = b.data<T2>();
143 const T3* c_ptr = c.data<T3>();
144 U* out_ptr = out.data<U>();
145 for (size_t i = 0; i < out.size(); ++i) {
146 *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
147 a_ptr++;
148 b_ptr++;
149 c_ptr++;
150 out_ptr++;
151 }
152 } else {
154 }
155}
156
157} // namespace mlx::core
Definition array.h:24
const Shape & shape() const
The shape of the array as a vector of integers.
Definition array.h:103
size_t size() const
The number of elements in the array.
Definition array.h:88
T * data()
Definition array.h:342
Definition allocator.h:7
std::tuple< Shape, std::vector< Strides > > collapse_contiguous_dims(const Shape &shape, const std::vector< Strides > &strides, int64_t size_cap=std::numeric_limits< int32_t >::max())
TernaryOpType get_ternary_op_type(const array &a, const array &b, const array &c)
Definition ternary.h:18
std::vector< ShapeElem > Shape
Definition array.h:21
void set_ternary_op_output_data(const array &a, const array &b, const array &c, array &out, TernaryOpType topt, bool donate_with_move=false)
Definition ternary.h:34
std::vector< int64_t > Strides
Definition array.h:22
void ternary_op_dims(const T1 *a, const T2 *b, const T3 *c, U *out, Op op, const Shape &shape, const Strides &a_strides, const Strides &b_strides, const Strides &c_strides, const Strides &out_strides, int axis)
Definition ternary.h:12
void ternary_op(const array &a, const array &b, const array &c, array &out, Op op)
Definition ternary.h:128
void ternary_op_dispatch_dims(const array &a, const array &b, const array &c, array &out, Op op)
Definition ternary.h:55
TernaryOpType
Definition ternary.h:11
@ ScalarScalarScalar
Definition ternary.h:12
@ VectorVectorVector
Definition ternary.h:13
Definition utils.h:73
int64_t loc
Definition utils.h:126
void step()
Definition utils.h:74