3template <
typename T,
typename U,
typename Op>
8 uint index [[thread_position_in_grid]]) {
9 c[index] = Op()(a[0], b[0]);
12template <
typename T,
typename U,
typename Op>
17 uint index [[thread_position_in_grid]]) {
18 c[index] = Op()(a[0], b[index]);
21template <
typename T,
typename U,
typename Op>
26 uint index [[thread_position_in_grid]]) {
27 c[index] = Op()(a[index], b[0]);
30template <
typename T,
typename U,
typename Op>
35 uint index [[thread_position_in_grid]]) {
36 c[index] = Op()(a[index], b[index]);
39template <
typename T,
typename U,
typename Op>
44 constant
const size_t& a_stride,
45 constant
const size_t& b_stride,
46 uint index [[thread_position_in_grid]]) {
49 c[index] = Op()(a[a_idx], b[b_idx]);
52template <
typename T,
typename U,
typename Op>
57 constant
const size_t a_strides[2],
58 constant
const size_t b_strides[2],
59 uint2 index [[thread_position_in_grid]],
60 uint2 grid_dim [[threads_per_grid]]) {
63 size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
64 c[out_idx] = Op()(a[a_idx], b[b_idx]);
67template <
typename T,
typename U,
typename Op>
72 constant
const size_t a_strides[3],
73 constant
const size_t b_strides[3],
74 uint3 index [[thread_position_in_grid]],
75 uint3 grid_dim [[threads_per_grid]]) {
79 index.x + (size_t)grid_dim.x * (index.y + (
size_t)grid_dim.y * index.z);
80 c[out_idx] = Op()(a[a_idx], b[b_idx]);
83template <
typename T,
typename U,
typename Op,
int DIM>
88 constant
const int shape[DIM],
89 constant
const size_t a_strides[DIM],
90 constant
const size_t b_strides[DIM],
91 uint3 index [[thread_position_in_grid]],
92 uint3 grid_dim [[threads_per_grid]]) {
93 auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
95 index.x + (size_t)grid_dim.x * (index.y + (
size_t)grid_dim.y * index.z);
96 c[out_idx] = Op()(a[idx.x], b[idx.y]);
99template <
typename T,
typename U,
typename Op>
104 constant
const int* shape,
105 constant
const size_t* a_strides,
106 constant
const size_t* b_strides,
107 constant
const int& ndim,
108 uint3 index [[thread_position_in_grid]],
109 uint3 grid_dim [[threads_per_grid]]) {
111 size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
112 c[out_idx] = Op()(a[idx.x], b[idx.y]);