// Element-wise unary kernel: every thread applies a default-constructed `Op`
// functor to one input element and stores the result at the same index.
//
// Template parameters:
//   T  - element type (its use in the parameter list is not visible here).
//   Op - functor type; it is default-constructed and invoked with one element.
//
// NOTE(review): the kernel name and the declarations of `in` and `out` are not
// visible in this excerpt - confirm their types and address spaces against the
// full signature.
3template <
typename T,
typename Op>
// `index` is this thread's position in the dispatch grid.
7 uint index [[thread_position_in_grid]]) {
// Input and output are addressed with the same index.
8 out[index] = Op()(in[index]);
// Element-wise unary kernel for an input that is addressed through a
// shape/strides description: every thread translates its grid index into an
// input location with `elem_to_loc`, applies a default-constructed `Op` to the
// element found there, and stores the result at `out[index]`.
//
// Template parameters:
//   T  - element type (its use in the parameter list is not visible here).
//   Op - functor type; it is default-constructed and invoked with one element.
//
// NOTE(review): the kernel name and the declarations of `in` and `out` are not
// visible in this excerpt - confirm them against the full signature.
11template <
typename T,
typename Op>
// Passed straight through to `elem_to_loc`; presumably the per-dimension
// extents of the input - TODO confirm.
15 device
const int* in_shape,
// Passed straight through to `elem_to_loc`; presumably the per-dimension
// strides of the input (units not shown here) - TODO confirm.
16 device
const size_t* in_strides,
// Passed straight through to `elem_to_loc`; presumably the number of
// dimensions described by `in_shape` / `in_strides` - TODO confirm.
17 device
const int& ndim,
// `index` is this thread's position in the dispatch grid.
18 uint index [[thread_position_in_grid]]) {
// `elem_to_loc` is defined elsewhere; it appears to map the linear element
// index to an offset into `in` - verify against its definition.
19 auto idx =
elem_to_loc(index, in_shape, in_strides, ndim);
// The output is written at the grid index, while the input is read at the
// translated location.
20 out[index] = Op()(in[idx]);