// Element-wise unary kernel fragment for a 1-D launch grid: each thread
// applies Op to the input element at its grid position and stores the result
// at the same position in `out`.
// T and U are presumably the input and output element types; the buffer
// declarations that would confirm this are not visible here.
// NOTE(review): the function name, the declarations of `in` and `out`, and the
// closing brace are absent from this fragment, and the leading digits on some
// lines look like fused line numbers -- confirm against the original file.
3template <
typename T,
typename U,
typename Op>
// One-dimensional thread position; used directly as the element index.
7 uint index [[thread_position_in_grid]]) {
// Op is default-constructed and invoked as a functor on a single element.
8 out[index] = Op()(in[index]);
// Element-wise unary kernel fragment for a 2-D launch grid: the thread's
// (x, y) position is flattened into one linear offset, and Op is applied to
// the input element at that offset.
// NOTE(review): the function name, the declarations of `in` and `out`, and the
// closing brace are absent from this fragment, and the leading digits on some
// lines look like fused line numbers -- confirm against the original file.
11template <
typename T,
typename U,
typename Op>
// Thread position within the 2-D grid.
15 uint2 index [[thread_position_in_grid]],
// Total number of threads in the grid along each dimension.
16 uint2 grid_dim [[threads_per_grid]]) {
// Row-major flattening: x + width * y. index.y is converted to size_t before
// the multiplication, so the product is evaluated in size_t rather than uint.
17 size_t offset = index.x + grid_dim.x * size_t(index.y);
// The same offset addresses both the input and the output.
18 out[offset] = Op()(in[offset]);
// Unary kernel fragment for input described by an explicit shape and strides.
// Each thread handles up to N consecutive elements along the last axis
// (N is a compile-time constant, default 1) and writes the results to
// consecutive output positions.
// NOTE(review): the function name, the declarations of `in` and `out`, the
// left-hand sides of two assignments, and the end of the loop body are absent
// from this fragment, and the leading digits on some lines look like fused
// line numbers -- confirm against the original file.
21template <
typename T,
typename U,
typename Op,
int N = 1>
// Input shape array, read from the constant address space.
25 constant
const int* in_shape,
// Input strides array, read from the constant address space.
26 constant
const size_t* in_strides,
// Number of dimensions described by in_shape / in_strides.
27 device
const int& ndim,
// Thread position within the 3-D grid; the x coordinate is scaled by N below.
28 uint3 index [[thread_position_in_grid]],
// Total number of threads in the grid along each dimension.
29 uint3 grid_dim [[threads_per_grid]]) {
// elem_to_loc is defined elsewhere; presumably it maps the element position
// {N * x, y, z} to a strided location in the input. The variable receiving
// the result is not visible here (likely the `idx` used in the loop below).
31 elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
// Extent and stride of the last dimension.
32 auto xshape = in_shape[ndim - 1];
33 auto xstride = in_strides[ndim - 1];
// Linear output position of this thread's first element:
// N * x + xshape * (y + grid_dim.y * z), with grid_dim.y converted to size_t
// first. The assignment target is not visible here (likely the `out_idx`
// used in the loop below).
35 N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
// Run at most N iterations, stopping early once N * x + i reaches xshape, so
// a thread at the end of the last axis does not go past it when xshape is
// not a multiple of N.
36 for (
int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
// Store the result and advance the output position by one.
// NOTE(review): xstride is unused and `idx` is never advanced in the visible
// lines; the rest of the loop body appears to be cut off -- confirm.
37 out[out_idx++] = Op()(in[idx]);