// Element-wise ternary kernel over a 1-D grid: each thread applies Op to the
// elements of a, b and c at its own grid position and stores the result at
// the same position in d.
// NOTE(review): the kernel name and the declarations of a, b, c and d are not
// visible in this excerpt -- presumably device buffers of T; confirm.
3template <
typename T,
typename Op>
9 uint index [[thread_position_in_grid]]) {
// All three inputs and the output share the same linear position.
10 d[index] = Op()(a[index], b[index], c[index]);
// Element-wise ternary kernel over a 2-D grid: the thread's (x, y) position
// is flattened into a single offset that is used for a, b, c and d alike.
// NOTE(review): the kernel name and the declarations of a, b, c and d are not
// visible in this excerpt; confirm against the full file.
13template <
typename T,
typename Op>
19 uint2 index [[thread_position_in_grid]],
20 uint2 grid_dim [[threads_per_grid]]) {
// x varies fastest; index.y is converted to size_t first so the multiply is
// carried out in size_t rather than in 32-bit uint.
21 size_t offset = index.x + grid_dim.x * size_t(index.y);
22 d[offset] = Op()(a[offset], b[offset], c[offset]);
// Ternary kernel over a 1-D grid where each input has its own stride: the
// inputs are read at a_idx / b_idx / c_idx while the output is written at the
// thread's grid position.
// NOTE(review): the kernel name, the buffer parameters and the lines that
// compute a_idx, b_idx and c_idx are not visible in this excerpt -- presumably
// each is derived from index and the matching stride; confirm.
25template <
typename T,
typename Op>
// One stride value per input, passed by reference in the constant address space.
31 constant
const size_t& a_strides,
32 constant
const size_t& b_strides,
33 constant
const size_t& c_strides,
34 uint index [[thread_position_in_grid]]) {
// Output uses the raw grid position; inputs use their own computed locations.
38 d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
// Ternary kernel over a 2-D grid where each input has a two-element stride
// array: inputs are read at a_idx / b_idx / c_idx and the result is written at
// the flattened (x, y) grid position.
// NOTE(review): the kernel name, the buffer parameters and the lines that
// compute a_idx, b_idx and c_idx are not visible in this excerpt -- presumably
// they combine index with the matching strides; confirm.
41template <
typename T,
typename Op>
47 constant
const size_t a_strides[2],
48 constant
const size_t b_strides[2],
49 constant
const size_t c_strides[2],
50 uint2 index [[thread_position_in_grid]],
51 uint2 grid_dim [[threads_per_grid]]) {
// Output offset with x varying fastest; grid_dim.x is cast to size_t so the
// product is not evaluated in 32-bit uint.
55 size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
56 d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
// Ternary kernel over a 3-D grid where each input has a three-element stride
// array: inputs are read at a_idx / b_idx / c_idx and the result is written at
// out_idx.
// NOTE(review): the kernel name, the buffer parameters, the lines that compute
// a_idx, b_idx and c_idx, and the left-hand side of the out_idx statement are
// not visible in this excerpt; confirm against the full file.
59template <
typename T,
typename Op>
65 constant
const size_t a_strides[3],
66 constant
const size_t b_strides[3],
67 constant
const size_t c_strides[3],
68 uint3 index [[thread_position_in_grid]],
69 uint3 grid_dim [[threads_per_grid]]) {
// Tail of the output-offset expression: flattens (x, y, z) with x varying
// fastest, casting each grid dimension to size_t before it is multiplied.
74 index.x + (size_t)grid_dim.x * (index.y + (
size_t)grid_dim.y * index.z);
75 d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
// Ternary kernel for inputs whose dimension count DIM is a template
// parameter: shape and the per-input stride arrays each hold DIM entries, and
// elem_to_loc_3_nd<DIM> turns the grid position into the three input locations.
// NOTE(review): the kernel name, the buffer parameters, the declaration that
// receives the elem_to_loc_3_nd result (used below as idx) and the left-hand
// side of the out_idx statement are not visible in this excerpt; confirm.
78template <
typename T,
typename Op,
int DIM>
84 constant
const int shape[DIM],
85 constant
const size_t a_strides[DIM],
86 constant
const size_t b_strides[DIM],
87 constant
const size_t c_strides[DIM],
88 uint3 index [[thread_position_in_grid]],
89 uint3 grid_dim [[threads_per_grid]]) {
// The helper's result is consumed below as idx.x / idx.y / idx.z for a / b / c.
91 elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
// Tail of the output-offset expression: flattens (x, y, z) with x varying
// fastest, casting each grid dimension to size_t before it is multiplied.
93 index.x + (size_t)grid_dim.x * (index.y + (
size_t)grid_dim.y * index.z);
94 d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
// Ternary kernel for inputs whose dimension count is supplied at run time
// (ndim): shape and the per-input stride arrays are passed as pointers, the
// inputs are read at idx.x / idx.y / idx.z and the result is written at the
// flattened (x, y, z) grid position.
// NOTE(review): the kernel name, the buffer parameters and the line that
// defines idx are not visible in this excerpt -- presumably idx is produced
// from index, shape, the three stride arrays and ndim; confirm.
97template <
typename T,
typename Op>
103 constant
const int* shape,
104 constant
const size_t* a_strides,
105 constant
const size_t* b_strides,
106 constant
const size_t* c_strides,
107 constant
const int& ndim,
108 uint3 index [[thread_position_in_grid]],
109 uint3 grid_dim [[threads_per_grid]]) {
// Flatten (x, y, z) with x varying fastest. Each grid dimension is cast to
// size_t before it is multiplied: index and grid_dim are 32-bit uint vectors,
// so without the casts the whole expression would be evaluated in 32 bits and
// wrap for outputs larger than 2^32 elements before being stored in size_t.
112 size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
113 d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);