3template <
typename T,
typename U,
typename Op>
8 uint index [[thread_position_in_grid]]) {
9 c[index] = Op()(a[0], b[0]);
12template <
typename T,
typename U,
typename Op>
17 uint index [[thread_position_in_grid]]) {
18 c[index] = Op()(a[0], b[index]);
21template <
typename T,
typename U,
typename Op>
26 uint index [[thread_position_in_grid]]) {
27 c[index] = Op()(a[index], b[0]);
30template <
typename T,
typename U,
typename Op>
35 uint index [[thread_position_in_grid]]) {
36 c[index] = Op()(a[index], b[index]);
39template <
typename T,
typename U,
typename Op>
44 uint2 index [[thread_position_in_grid]],
45 uint2 grid_dim [[threads_per_grid]]) {
46 int64_t offset = index.x + grid_dim.x * int64_t(index.y);
47 c[offset] = Op()(a[0], b[offset]);
50template <
typename T,
typename U,
typename Op>
55 uint2 index [[thread_position_in_grid]],
56 uint2 grid_dim [[threads_per_grid]]) {
57 int64_t offset = index.x + grid_dim.x * int64_t(index.y);
58 c[offset] = Op()(a[offset], b[0]);
61template <
typename T,
typename U,
typename Op>
66 uint2 index [[thread_position_in_grid]],
67 uint2 grid_dim [[threads_per_grid]]) {
68 int64_t offset = index.x + grid_dim.x * int64_t(index.y);
69 c[offset] = Op()(a[offset], b[offset]);
72template <
typename T,
typename U,
typename Op,
typename IdxT =
int64_t>
77 constant
const int64_t& a_stride,
78 constant
const int64_t& b_stride,
79 uint index [[thread_position_in_grid]]) {
82 c[index] = Op()(a[a_idx], b[b_idx]);
85template <
typename T,
typename U,
typename Op,
typename IdxT =
int64_t>
90 constant
const int64_t a_strides[2],
91 constant
const int64_t b_strides[2],
92 uint2 index [[thread_position_in_grid]],
93 uint2 grid_dim [[threads_per_grid]]) {
96 IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
97 c[out_idx] = Op()(a[a_idx], b[b_idx]);
100template <
typename T,
typename U,
typename Op,
typename IdxT =
int64_t>
105 constant
const int64_t a_strides[3],
106 constant
const int64_t b_strides[3],
107 uint3 index [[thread_position_in_grid]],
108 uint3 grid_dim [[threads_per_grid]]) {
111 IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
112 c[out_idx] = Op()(a[a_idx], b[b_idx]);
120 typename IdxT = int64_t>
125 constant
const int* shape,
126 constant
const int64_t* a_strides,
127 constant
const int64_t* b_strides,
128 constant
const int& ndim,
129 uint3 index [[thread_position_in_grid]],
130 uint3 grid_dim [[threads_per_grid]]) {
132 {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
133 auto xshape = shape[ndim - 1];
134 IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
135 IdxT a_xstride = a_strides[ndim - 1];
136 IdxT b_xstride = b_strides[ndim - 1];
137 for (
int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
138 c[out_idx++] = Op()(a[idx.x], b[idx.y]);
constexpr int N
Definition neon_fp16_simd.h:9