// Elementwise unary kernel fragment: each thread applies a default-constructed
// `Op` functor to one input element and stores the result at the same index.
//
// Template parameters:
//   T  - not referenced in the visible lines; presumably the element type of
//        `in` / `out` (TODO confirm against the full declaration).
//   Op - functor type; it is default-constructed and invoked with one element.
//
// NOTE(review): the function name, the declarations of `in` and `out`, and the
// closing brace are not visible in this view, and the leading digits on some
// lines look like line-number residue from extraction - verify against the
// original file before relying on this fragment.
3template <
typename T,
typename Op>
// `index` receives this thread's position in the dispatch grid (1-D form).
7 uint index [[thread_position_in_grid]]) {
// Read and write use the same index, so input and output are addressed
// identically.
8 out[index] = Op()(in[index]);
// Elementwise unary kernel fragment for a two-dimensional dispatch: the 2-D
// thread position is flattened to a single `size_t` offset, and a
// default-constructed `Op` is applied to the element at that offset.
//
// Template parameters:
//   T  - not referenced in the visible lines; presumably the element type of
//        `in` / `out` (TODO confirm against the full declaration).
//   Op - functor type; it is default-constructed and invoked with one element.
//
// NOTE(review): the function name, the declarations of `in` and `out`, and the
// closing brace are not visible in this view, and the leading digits on some
// lines look like line-number residue from extraction - verify against the
// original file before relying on this fragment.
11template <
typename T,
typename Op>
// `index` is this thread's (x, y) position in the grid; `grid_dim` is the
// grid's total thread count per dimension.
15 uint2 index [[thread_position_in_grid]],
16 uint2 grid_dim [[threads_per_grid]]) {
// Row-major flattening: x + width * y. `index.y` is converted to size_t before
// the multiplication, so the product is computed in size_t rather than in the
// 32-bit component type (presumably to avoid overflow on large grids).
17 size_t offset = index.x + grid_dim.x * size_t(index.y);
// The same flattened offset is used for both the read and the write.
18 out[offset] = Op()(in[offset]);
// Elementwise unary kernel fragment where the read location is computed by
// `elem_to_loc` from the thread index plus shape/stride metadata, while the
// write location is the thread index itself.
//
// Template parameters:
//   T  - not referenced in the visible lines; presumably the element type of
//        `in` / `out` (TODO confirm against the full declaration).
//   Op - functor type; it is default-constructed and invoked with one element.
//
// Visible parameters:
//   in_shape   - pointer to const int in the `device` address space.
//   in_strides - pointer to const size_t in the `device` address space.
//   ndim       - const int reference in the `device` address space.
//   index      - this thread's position in the dispatch grid (1-D form).
//
// NOTE(review): the function name, the declarations of `in` and `out`, and the
// closing brace are not visible in this view, and the leading digits on some
// lines look like line-number residue from extraction - verify against the
// original file before relying on this fragment.
21template <
typename T,
typename Op>
25 device
const int* in_shape,
26 device
const size_t* in_strides,
27 device
const int& ndim,
28 uint index [[thread_position_in_grid]]) {
// `elem_to_loc` is defined outside this view; presumably it converts the
// linear element index into a strided location in `in` using `ndim` entries
// of `in_shape` / `in_strides` - confirm against its definition.
29 auto idx =
elem_to_loc(index, in_shape, in_strides, ndim);
// Asymmetric addressing: the read uses the computed location `idx`, the write
// uses the thread index directly.
30 out[index] = Op()(in[idx]);