3template <
typename T,
typename U,
typename Op>
7 uint index [[thread_position_in_grid]]) {
8 out[index] = Op()(in[index]);
11template <
typename T,
typename U,
typename Op>
15 uint2 index [[thread_position_in_grid]],
16 uint2 grid_dim [[threads_per_grid]]) {
17 auto offset = index.x + grid_dim.x * int64_t(index.y);
18 out[offset] = Op()(in[offset]);
26 typename IdxT = int64_t>
30 constant
const int* in_shape,
31 constant
const int64_t* in_strides,
32 device
const int& ndim,
33 uint3 index [[thread_position_in_grid]],
34 uint3 grid_dim [[threads_per_grid]]) {
36 {N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
37 auto xshape = in_shape[ndim - 1];
38 IdxT xstride = in_strides[ndim - 1];
39 IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
40 for (
int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
41 out[out_idx++] = Op()(in[idx]);
constexpr int N
Definition neon_fp16_simd.h:9