MLX
Loading...
Searching...
No Matches
binary_two.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
// Two-output binary kernel where both operands are read from element 0 of
// their buffers (presumably the "scalar, scalar" case -- confirm with the
// host-side dispatch). Each thread applies Op once and stores component 0 of
// the result in `c` and component 1 in `d` at its own grid position.
template <typename T, typename U, typename Op>
[[kernel]] void binary_ss(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  auto result = Op{}(*a, *b);
  c[index] = result[0];
  d[index] = result[1];
}
14
// Two-output binary kernel: the left operand is always element 0 of `a`,
// the right operand is `b` at this thread's grid position. Component 0 of
// the Op result goes to `c`, component 1 to `d`, both at the same position.
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  auto result = Op{}(*a, b[index]);
  c[index] = result[0];
  d[index] = result[1];
}
26
// Two-output binary kernel: the left operand is `a` at this thread's grid
// position, the right operand is always element 0 of `b`. Component 0 of
// the Op result goes to `c`, component 1 to `d`, both at the same position.
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  auto result = Op{}(a[index], *b);
  c[index] = result[0];
  d[index] = result[1];
}
38
// Two-output binary kernel: both operands are read at this thread's grid
// position. Component 0 of the Op result goes to `c`, component 1 to `d`,
// both at that same position.
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  auto result = Op{}(a[index], b[index]);
  c[index] = result[0];
  d[index] = result[1];
}
50
// 2D-grid form of the "element 0 of a, indexed b" kernel. The thread's (x, y)
// grid position is flattened to a linear offset with x varying fastest; that
// offset indexes `b` and both outputs.
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Widen the row number first so the multiply is carried out in size_t.
  const size_t row = size_t(index.y);
  const size_t pos = grid_dim.x * row + index.x;
  auto result = Op{}(*a, b[pos]);
  c[pos] = result[0];
  d[pos] = result[1];
}
64
// 2D-grid form of the "indexed a, element 0 of b" kernel. The thread's (x, y)
// grid position is flattened to a linear offset with x varying fastest; that
// offset indexes `a` and both outputs.
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Widen the row number first so the multiply is carried out in size_t.
  const size_t row = size_t(index.y);
  const size_t pos = grid_dim.x * row + index.x;
  auto result = Op{}(a[pos], *b);
  c[pos] = result[0];
  d[pos] = result[1];
}
78
// 2D-grid form of the elementwise kernel. The thread's (x, y) grid position
// is flattened to a linear offset with x varying fastest; that one offset
// indexes both inputs and both outputs.
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Widen the row number first so the multiply is carried out in size_t.
  const size_t row = size_t(index.y);
  const size_t pos = grid_dim.x * row + index.x;
  auto result = Op{}(a[pos], b[pos]);
  c[pos] = result[0];
  d[pos] = result[1];
}
92
// Strided kernel for a 1D grid. Each input location is obtained from
// elem_to_loc_1 using the thread's grid position and that input's stride;
// the outputs are written at the unmodified grid position.
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t& a_stride,
    constant const size_t& b_stride,
    uint index [[thread_position_in_grid]]) {
  auto lhs_loc = elem_to_loc_1(index, a_stride);
  auto rhs_loc = elem_to_loc_1(index, b_stride);
  auto result = Op{}(a[lhs_loc], b[rhs_loc]);
  c[index] = result[0];
  d[index] = result[1];
}
108
// Strided kernel for a 2D grid. Input locations come from elem_to_loc_2 with
// each input's own pair of strides; the output location is the (x, y) grid
// position flattened with x varying fastest.
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto lhs_loc = elem_to_loc_2(index, a_strides);
  auto rhs_loc = elem_to_loc_2(index, b_strides);
  // The grid width is widened before multiplying so the product is size_t.
  const size_t row_start = size_t(grid_dim.x) * index.y;
  const size_t out_pos = index.x + row_start;
  auto result = Op{}(a[lhs_loc], b[rhs_loc]);
  c[out_pos] = result[0];
  d[out_pos] = result[1];
}
126
// Strided kernel for a 3D grid. Input locations come from elem_to_loc_3 with
// each input's own three strides; the output location is the (x, y, z) grid
// position flattened with x varying fastest, then y, then z.
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto lhs_loc = elem_to_loc_3(index, a_strides);
  auto rhs_loc = elem_to_loc_3(index, b_strides);
  // Grid extents are widened before each multiply so every step is size_t.
  const size_t row = index.y + size_t(grid_dim.y) * index.z;
  const size_t out_pos = index.x + size_t(grid_dim.x) * row;
  auto result = Op{}(a[lhs_loc], b[rhs_loc]);
  c[out_pos] = result[0];
  d[out_pos] = result[1];
}
145
// Strided kernel with a compile-time dimension count DIM. A single call to
// elem_to_loc_2_nd<DIM> yields both input locations (`.x` for `a`, `.y` for
// `b`); the output location is the 3D grid position flattened with x varying
// fastest, then y, then z.
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int shape[DIM],
    constant const size_t a_strides[DIM],
    constant const size_t b_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto locs = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
  // Grid extents are widened before each multiply so every step is size_t.
  const size_t row = index.y + size_t(grid_dim.y) * index.z;
  const size_t out_pos = index.x + size_t(grid_dim.x) * row;
  auto result = Op{}(a[locs.x], b[locs.y]);
  c[out_pos] = result[0];
  d[out_pos] = result[1];
}
164
// Strided kernel with a runtime dimension count `ndim`. A single call to
// elem_to_loc_2_nd yields both input locations (`.x` for `a`, `.y` for `b`);
// the output location is the 3D grid position flattened with x varying
// fastest, then y, then z. Component 0 of the Op result is stored in `c` and
// component 1 in `d`.
template <typename T, typename U, typename Op>
[[kernel]] void binary_g(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
  // Widen the grid extents to size_t *before* multiplying. With plain uint
  // operands the whole expression is evaluated in 32 bits and wraps for grids
  // larger than 2^32 elements, sending the results to the wrong output slot.
  // This matches the casts used by binary_g_nd3 and binary_g_nd.
  size_t out_idx =
      index.x + size_t(grid_dim.x) * (index.y + size_t(grid_dim.y) * index.z);
  auto out = Op()(a[idx.x], b[idx.y]);
  c[out_idx] = out[0];
  d[out_idx] = out[1];
}
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t &stride)
Definition utils.h:123
METAL_FUNC stride_t elem_to_loc_3(uint3 elem, constant const stride_t strides[3])
Definition utils.h:135
METAL_FUNC uint2 elem_to_loc_2_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)
Definition utils.h:200
METAL_FUNC stride_t elem_to_loc_2(uint2 elem, constant const stride_t strides[2])
Definition utils.h:129
void binary_g_nd2(device const T *a, device const T *b, device U *c, device U *d, constant const size_t a_strides[2], constant const size_t b_strides[2], uint2 index, uint2 grid_dim)
Definition binary_two.h:110
void binary_sv2(device const T *a, device const T *b, device U *c, device U *d, uint2 index, uint2 grid_dim)
Definition binary_two.h:52
void binary_vs(device const T *a, device const T *b, device U *c, device U *d, uint index)
Definition binary_two.h:28
void binary_vv2(device const T *a, device const T *b, device U *c, device U *d, uint2 index, uint2 grid_dim)
Definition binary_two.h:80
void binary_vs2(device const T *a, device const T *b, device U *c, device U *d, uint2 index, uint2 grid_dim)
Definition binary_two.h:66
void binary_g_nd3(device const T *a, device const T *b, device U *c, device U *d, constant const size_t a_strides[3], constant const size_t b_strides[3], uint3 index, uint3 grid_dim)
Definition binary_two.h:128
void binary_g_nd(device const T *a, device const T *b, device U *c, device U *d, constant const int shape[DIM], constant const size_t a_strides[DIM], constant const size_t b_strides[DIM], uint3 index, uint3 grid_dim)
Definition binary_two.h:147
void binary_sv(device const T *a, device const T *b, device U *c, device U *d, uint index)
Definition binary_two.h:16
void binary_vv(device const T *a, device const T *b, device U *c, device U *d, uint index)
Definition binary_two.h:40
void binary_g(device const T *a, device const T *b, device U *c, device U *d, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const int &ndim, uint3 index, uint3 grid_dim)
Definition binary_two.h:166
void binary_g_nd1(device const T *a, device const T *b, device U *c, device U *d, constant const size_t &a_stride, constant const size_t &b_stride, uint index)
Definition binary_two.h:94
void binary_ss(device const T *a, device const T *b, device U *c, device U *d, uint index)
Definition binary_two.h:4