3template <
typename T,
typename U,
typename Op>
9 uint index [[thread_position_in_grid]]) {
10 auto out = Op()(a[0], b[0]);
15template <
typename T,
typename U,
typename Op>
21 uint index [[thread_position_in_grid]]) {
22 auto out = Op()(a[0], b[index]);
27template <
typename T,
typename U,
typename Op>
33 uint index [[thread_position_in_grid]]) {
34 auto out = Op()(a[index], b[0]);
39template <
typename T,
typename U,
typename Op>
45 uint index [[thread_position_in_grid]]) {
46 auto out = Op()(a[index], b[index]);
51template <
typename T,
typename U,
typename Op>
57 uint2 index [[thread_position_in_grid]],
58 uint2 grid_dim [[threads_per_grid]]) {
59 size_t offset = index.x + grid_dim.x * size_t(index.y);
60 auto out = Op()(a[0], b[offset]);
65template <
typename T,
typename U,
typename Op>
71 uint2 index [[thread_position_in_grid]],
72 uint2 grid_dim [[threads_per_grid]]) {
73 size_t offset = index.x + grid_dim.x * size_t(index.y);
74 auto out = Op()(a[offset], b[0]);
79template <
typename T,
typename U,
typename Op>
85 uint2 index [[thread_position_in_grid]],
86 uint2 grid_dim [[threads_per_grid]]) {
87 size_t offset = index.x + grid_dim.x * size_t(index.y);
88 auto out = Op()(a[offset], b[offset]);
93template <
typename T,
typename U,
typename Op>
99 constant
const size_t& a_stride,
100 constant
const size_t& b_stride,
101 uint index [[thread_position_in_grid]]) {
104 auto out = Op()(a[a_idx], b[b_idx]);
109template <
typename T,
typename U,
typename Op>
115 constant
const size_t a_strides[2],
116 constant
const size_t b_strides[2],
117 uint2 index [[thread_position_in_grid]],
118 uint2 grid_dim [[threads_per_grid]]) {
121 size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
122 auto out = Op()(a[a_idx], b[b_idx]);
127template <
typename T,
typename U,
typename Op>
133 constant
const size_t a_strides[3],
134 constant
const size_t b_strides[3],
135 uint3 index [[thread_position_in_grid]],
136 uint3 grid_dim [[threads_per_grid]]) {
140 index.x + (size_t)grid_dim.x * (index.y + (
size_t)grid_dim.y * index.z);
141 auto out = Op()(a[a_idx], b[b_idx]);
146template <
typename T,
typename U,
typename Op,
int DIM>
152 constant
const int shape[DIM],
153 constant
const size_t a_strides[DIM],
154 constant
const size_t b_strides[DIM],
155 uint3 index [[thread_position_in_grid]],
156 uint3 grid_dim [[threads_per_grid]]) {
157 auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
159 index.x + (size_t)grid_dim.x * (index.y + (
size_t)grid_dim.y * index.z);
160 auto out = Op()(a[idx.x], b[idx.y]);
165template <
typename T,
typename U,
typename Op>
171 constant
const int* shape,
172 constant
const size_t* a_strides,
173 constant
const size_t* b_strides,
174 constant
const int& ndim,
175 uint3 index [[thread_position_in_grid]],
176 uint3 grid_dim [[threads_per_grid]]) {
178 size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
179 auto out = Op()(a[idx.x], b[idx.y]);