// Fragment of a function template parameterised on an element type T and a
// functor type Op. The function name, the declarations of a, b, c and d, and
// the closing brace are on lines that are not visible in this excerpt.
3template <
typename T,
typename Op>
// `index` is bound to Metal's [[thread_position_in_grid]] attribute.
9 uint index [[thread_position_in_grid]]) {
// Call a default-constructed Op on the elements of a, b and c found at
// `index` and store the result at the same position in d.
10 d[index] = Op()(a[index], b[index], c[index]);
// Fragment of a function template parameterised on an element type T and a
// functor type Op. The function name, the declarations of a, b, c and d, the
// lines that compute a_idx / b_idx / c_idx, and the closing brace are not
// visible in this excerpt.
13template <
typename T,
typename Op>
// Each input has exactly one stride, received as a reference to a size_t in
// the `constant` address space.
19 constant
const size_t& a_strides,
20 constant
const size_t& b_strides,
21 constant
const size_t& c_strides,
// `index` is bound to Metal's [[thread_position_in_grid]] attribute.
22 uint index [[thread_position_in_grid]]) {
// The inputs are read at per-input offsets while the output is written at the
// plain thread index.
// NOTE(review): a_idx / b_idx / c_idx are presumably derived from `index` and
// the matching stride on the hidden lines -- confirm against the full file.
26 d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]);
// Fragment of a function template parameterised on an element type T and a
// functor type Op. The function name, the declarations of a, b, c and d, the
// lines that compute a_idx / b_idx / c_idx, and the closing brace are not
// visible in this excerpt.
29template <
typename T,
typename Op>
// Two strides per input, held in the `constant` address space.
35 constant
const size_t a_strides[2],
36 constant
const size_t b_strides[2],
37 constant
const size_t c_strides[2],
// Two-component thread position and total grid extent, supplied through the
// [[thread_position_in_grid]] and [[threads_per_grid]] attributes.
38 uint2 index [[thread_position_in_grid]],
39 uint2 grid_dim [[threads_per_grid]]) {
// Flatten the 2-D thread position into one output offset with x varying
// fastest. grid_dim.x is cast to size_t before the multiplication so the
// product is evaluated as size_t rather than as uint.
43 size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
// NOTE(review): a_idx / b_idx / c_idx come from hidden lines; presumably they
// are built from `index` and the stride arrays above -- confirm.
44 d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
// Fragment of a function template parameterised on an element type T and a
// functor type Op. The function name, the declarations of a, b, c and d, the
// lines that compute a_idx / b_idx / c_idx, the start of the out_idx
// statement, and the closing brace are not visible in this excerpt.
47template <
typename T,
typename Op>
// Three strides per input, held in the `constant` address space.
53 constant
const size_t a_strides[3],
54 constant
const size_t b_strides[3],
55 constant
const size_t c_strides[3],
// Three-component thread position and total grid extent, supplied through the
// [[thread_position_in_grid]] and [[threads_per_grid]] attributes.
56 uint3 index [[thread_position_in_grid]],
57 uint3 grid_dim [[threads_per_grid]]) {
// Tail of a statement whose beginning is on a hidden line (presumably the
// initialiser of out_idx, which is used just below -- confirm). The
// expression flattens the 3-D thread position with x fastest and z slowest;
// both grid extents are cast to size_t before multiplying so the products are
// evaluated as size_t rather than as uint.
62 index.x + (size_t)grid_dim.x * (index.y + (
size_t)grid_dim.y * index.z);
// Inputs are read at their own offsets; the output is written at out_idx.
63 d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
// Fragment of a function template parameterised on an element type T, a
// functor type Op and a compile-time dimension count DIM. The function name,
// the declarations of a, b, c and d, the left-hand side of the
// elem_to_loc_3_nd call, the start of the out_idx statement, and the closing
// brace are not visible in this excerpt.
66template <
typename T,
typename Op,
int DIM>
// Shape and per-input strides, each with DIM entries, held in the `constant`
// address space.
72 constant
const int shape[DIM],
73 constant
const size_t a_strides[DIM],
74 constant
const size_t b_strides[DIM],
75 constant
const size_t c_strides[DIM],
// Three-component thread position and total grid extent, supplied through the
// [[thread_position_in_grid]] and [[threads_per_grid]] attributes.
76 uint3 index [[thread_position_in_grid]],
77 uint3 grid_dim [[threads_per_grid]]) {
// Call to the helper elem_to_loc_3_nd<DIM> (defined outside this excerpt)
// with the thread position, the shape and all three stride arrays.
// NOTE(review): the result is presumably bound to `idx`, whose .x/.y/.z
// members are used below as the offsets into a, b and c -- confirm.
79 elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
// Tail of a statement whose beginning is hidden (presumably the initialiser
// of out_idx -- confirm). Flattens the 3-D thread position with x fastest;
// both grid extents are cast to size_t before multiplying so the products are
// evaluated as size_t rather than as uint.
81 index.x + (size_t)grid_dim.x * (index.y + (
size_t)grid_dim.y * index.z);
// idx.x, idx.y and idx.z select the elements of a, b and c respectively.
82 d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
// Fragment of a function template parameterised on an element type T and a
// functor type Op. The function name, the declarations of a, b, c and d, the
// line(s) that compute `idx`, and the closing brace are not visible in this
// excerpt.
85template <
typename T,
typename Op>
// Shape and per-input strides are passed as pointers into the `constant`
// address space; the number of dimensions arrives separately in `ndim`.
91 constant
const int* shape,
92 constant
const size_t* a_strides,
93 constant
const size_t* b_strides,
94 constant
const size_t* c_strides,
95 constant
const int& ndim,
// Three-component thread position and total grid extent, supplied through the
// [[thread_position_in_grid]] and [[threads_per_grid]] attributes.
96 uint3 index [[thread_position_in_grid]],
97 uint3 grid_dim [[threads_per_grid]]) {
// Flatten the 3-D thread position into one output offset with x fastest and
// z slowest. The grid extents are cast to size_t BEFORE multiplying: every
// operand is otherwise a 32-bit uint, so the products would be evaluated in
// uint and wrap for large grids before being widened on assignment. This
// matches the casts used by the fixed-rank kernels in this file.
100 size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
// idx.x, idx.y and idx.z select the elements of a, b and c respectively.
// NOTE(review): `idx` is produced on a hidden line, presumably from `index`,
// `shape`, the stride pointers and `ndim` -- confirm against the full file.
101 d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);