// Binary-op kernel fragment, templated on T, U and the functor type Op.
// NOTE(review): the kernel name, the buffer parameters (`a`, `b`) and the
// store of `out` are not visible here — confirm against the full file.
3template <
typename T,
typename U,
typename Op>
9 uint index [[thread_position_in_grid]]) {
// Both operands are read from element 0; `index` is not used for the loads.
// Op is default-constructed and invoked on the pair.
10 auto out = Op()(a[0], b[0]);
// Binary-op kernel fragment, templated on T, U and the functor type Op.
// NOTE(review): the kernel name, the buffer parameters (`a`, `b`) and the
// store of `out` are not visible here — confirm against the full file.
15template <
typename T,
typename U,
typename Op>
21 uint index [[thread_position_in_grid]]) {
// Left operand is always a[0]; right operand is selected by this thread's
// 1-D grid position.
22 auto out = Op()(a[0], b[index]);
// Binary-op kernel fragment, templated on T, U and the functor type Op.
// NOTE(review): the kernel name, the buffer parameters (`a`, `b`) and the
// store of `out` are not visible here — confirm against the full file.
27template <
typename T,
typename U,
typename Op>
33 uint index [[thread_position_in_grid]]) {
// Left operand is selected by this thread's 1-D grid position; right operand
// is always b[0].
34 auto out = Op()(a[index], b[0]);
// Binary-op kernel fragment, templated on T, U and the functor type Op.
// NOTE(review): the kernel name, the buffer parameters (`a`, `b`) and the
// store of `out` are not visible here — confirm against the full file.
39template <
typename T,
typename U,
typename Op>
45 uint index [[thread_position_in_grid]]) {
// Both operands are read at this thread's 1-D grid position.
46 auto out = Op()(a[index], b[index]);
// Binary-op kernel fragment taking one stride per input (`a_stride`,
// `b_stride`) and a 1-D thread position.
// NOTE(review): the kernel name, the buffer parameters and the lines that
// derive `a_idx` / `b_idx` are not visible here; presumably they combine
// `index` with the matching stride — confirm against the full file.
51template <
typename T,
typename U,
typename Op>
57 constant
const size_t& a_stride,
58 constant
const size_t& b_stride,
59 uint index [[thread_position_in_grid]]) {
// Apply the default-constructed functor to the two gathered elements.
62 auto out = Op()(a[a_idx], b[b_idx]);
// Binary-op kernel fragment taking two strides per input and a 2-D thread
// position / grid size.
// NOTE(review): the kernel name, the buffer parameters and the lines that
// derive `a_idx` / `b_idx` are not visible here — confirm against the full
// file.
67template <
typename T,
typename U,
typename Op>
73 constant
const size_t a_strides[2],
74 constant
const size_t b_strides[2],
75 uint2 index [[thread_position_in_grid]],
76 uint2 grid_dim [[threads_per_grid]]) {
// Linear index with x varying fastest. grid_dim.x is widened to size_t before
// the multiply, so the product is evaluated in size_t rather than 32-bit uint.
79 size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
// Apply the default-constructed functor to the two gathered elements.
80 auto out = Op()(a[a_idx], b[b_idx]);
// Binary-op kernel fragment taking three strides per input and a 3-D thread
// position / grid size.
// NOTE(review): the kernel name, the buffer parameters, the lines that derive
// `a_idx` / `b_idx`, and the left-hand side of the index expression below are
// not visible here — confirm against the full file.
85template <
typename T,
typename U,
typename Op>
91 constant
const size_t a_strides[3],
92 constant
const size_t b_strides[3],
93 uint3 index [[thread_position_in_grid]],
94 uint3 grid_dim [[threads_per_grid]]) {
// Linear index over the 3-D grid (x fastest, then y, then z). Both grid
// extents are widened to size_t before multiplying, so the products are
// evaluated in size_t rather than 32-bit uint.
98 index.x + (size_t)grid_dim.x * (index.y + (
size_t)grid_dim.y * index.z);
// Apply the default-constructed functor to the two gathered elements.
99 auto out = Op()(a[a_idx], b[b_idx]);
// Binary-op kernel fragment whose dimension count DIM is a template
// parameter; `shape`, `a_strides` and `b_strides` each have DIM entries.
// NOTE(review): the kernel name, the buffer parameters and the left-hand side
// of the linear-index expression are not visible here — confirm against the
// full file.
104template <
typename T,
typename U,
typename Op,
int DIM>
110 constant
const int shape[DIM],
111 constant
const size_t a_strides[DIM],
112 constant
const size_t b_strides[DIM],
113 uint3 index [[thread_position_in_grid]],
114 uint3 grid_dim [[threads_per_grid]]) {
// elem_to_loc_2_nd is defined elsewhere; its result is used below as a pair:
// `.x` indexes `a` and `.y` indexes `b`.
115 auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
// Linear index over the 3-D grid (x fastest, then y, then z). Both grid
// extents are widened to size_t before multiplying, so the products are
// evaluated in size_t rather than 32-bit uint.
117 index.x + (size_t)grid_dim.x * (index.y + (
size_t)grid_dim.y * index.z);
// Apply the default-constructed functor to the two gathered elements.
118 auto out = Op()(a[idx.x], b[idx.y]);
// Binary-op kernel fragment whose dimension count is supplied at run time via
// `ndim`; `shape`, `a_strides` and `b_strides` are passed as pointers.
// NOTE(review): the kernel name, the buffer parameters, the line that computes
// `idx` and the store of `out` are not visible here — confirm against the
// full file.
123template <
typename T,
typename U,
typename Op>
129 constant
const int* shape,
130 constant
const size_t* a_strides,
131 constant
const size_t* b_strides,
132 constant
const int& ndim,
133 uint3 index [[thread_position_in_grid]],
134 uint3 grid_dim [[threads_per_grid]]) {
// Linear index over the 3-D grid (x fastest, then y, then z). The grid
// extents are widened to size_t BEFORE multiplying: without the casts the
// whole expression is evaluated in 32-bit uint and can wrap for large grids
// before being assigned to the size_t. This matches the other kernels in this
// file that compute the same index.
136 size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
// `idx.x` indexes `a` and `idx.y` indexes `b`; apply the default-constructed
// functor to the two gathered elements.
137 auto out = Op()(a[idx.x], b[idx.y]);