// Scalar broadcast on a 1-D grid: every thread reads src[0], converts it to U,
// and stores it in dst at the thread's own grid position.
// NOTE(review): the declaration line naming this kernel and its closing brace
// are absent from this listing, and the leading digits on several lines look
// like stray line numbers -- confirm against the original file.
3template <
typename T,
typename U>
5 device
const T* src [[buffer(0)]],
6 device U* dst [[buffer(1)]],
7 uint index [[thread_position_in_grid]]) {
// The source index is fixed at 0; only the destination index varies per thread.
8 dst[index] =
static_cast<U
>(src[0]);
// Elementwise copy on a 1-D grid: each thread converts src[index] to U and
// stores it at the same position in dst.
// NOTE(review): the declaration line naming this kernel and its closing brace
// are absent from this listing -- confirm against the original file.
11template <
typename T,
typename U>
13 device
const T* src [[buffer(0)]],
14 device U* dst [[buffer(1)]],
15 uint index [[thread_position_in_grid]]) {
// Source and destination use the same index; no bound check is visible here.
16 dst[index] =
static_cast<U
>(src[index]);
// Scalar broadcast on a 2-D grid: each thread flattens its (x, y) grid position
// into a single offset and stores src[0], converted to U, at dst[offset].
// NOTE(review): the declaration line naming this kernel and its closing brace
// are absent from this listing -- confirm against the original file.
19template <
typename T,
typename U>
21 device
const T* src [[buffer(0)]],
22 device U* dst [[buffer(1)]],
23 uint2 index [[thread_position_in_grid]],
24 uint2 grid_dim [[threads_per_grid]]) {
// index.y is widened to int64_t before the multiply, so the flattened offset
// (x + grid_width * y) is computed in 64-bit arithmetic.
25 auto offset = index.x + grid_dim.x * int64_t(index.y);
26 dst[offset] =
static_cast<U
>(src[0]);
// Elementwise copy on a 2-D grid: each thread flattens its (x, y) grid position
// into a single offset and copies src[offset], converted to U, to dst[offset].
// NOTE(review): the declaration line naming this kernel and its closing brace
// are absent from this listing -- confirm against the original file.
29template <
typename T,
typename U>
31 device
const T* src [[buffer(0)]],
32 device U* dst [[buffer(1)]],
33 uint2 index [[thread_position_in_grid]],
34 uint2 grid_dim [[threads_per_grid]]) {
// index.y is widened to int64_t before the multiply, so the flattened offset
// (x + grid_width * y) is computed in 64-bit arithmetic.
35 auto offset = index.x + grid_dim.x * int64_t(index.y);
// The same flattened offset addresses both buffers.
36 dst[offset] =
static_cast<U
>(src[offset]);
// Copy on a 1-D grid from a source that takes a single stride (buffer 3):
// dst is written at the thread's grid position, src is read at src_idx and
// converted to U.
// NOTE(review): src_idx is not defined in the visible lines -- presumably it is
// computed from index and src_stride in a statement missing from this listing.
// The declaration line naming this kernel and its closing brace are also absent.
39template <
typename T,
typename U,
typename IdxT =
int64_t>
41 device
const T* src [[buffer(0)]],
42 device U* dst [[buffer(1)]],
43 constant
const int64_t& src_stride [[buffer(3)]],
44 uint index [[thread_position_in_grid]]) {
// dst is indexed directly by the thread position (no destination stride here).
46 dst[index] =
static_cast<U
>(src[src_idx]);
// Copy on a 2-D grid from a source described by a strides array (buffer 3):
// the destination index is the flattened (x, y) grid position, and src is read
// at src_idx and converted to U.
// NOTE(review): src_idx is not defined in the visible lines -- presumably it is
// computed from index and src_strides in a statement missing from this listing.
// The declaration line naming this kernel and its closing brace are also absent.
49template <
typename T,
typename U,
typename IdxT =
int64_t>
51 device
const T* src [[buffer(0)]],
52 device U* dst [[buffer(1)]],
53 constant
const int64_t* src_strides [[buffer(3)]],
54 uint2 index [[thread_position_in_grid]],
55 uint2 grid_dim [[threads_per_grid]]) {
// grid_dim.x is converted to IdxT before the multiply; dst is addressed as
// x + grid_width * y, with no destination strides involved.
57 IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
58 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Copy on a 3-D grid from a source described by a strides array (buffer 3):
// src is read at src_idx, converted to U, and written to dst[dst_idx].
// NOTE(review): src_idx is not defined in the visible lines, and the statement
// that begins the dst_idx expression below is missing its first line. The
// declaration line naming this kernel and its closing brace are also absent.
61template <
typename T,
typename U,
typename IdxT =
int64_t>
63 device
const T* src [[buffer(0)]],
64 device U* dst [[buffer(1)]],
65 constant
const int64_t* src_strides [[buffer(3)]],
66 uint3 index [[thread_position_in_grid]],
67 uint3 grid_dim [[threads_per_grid]]) {
// Flattens the (x, y, z) grid position as x + grid_x * (y + grid_y * z), with
// both grid dimensions converted to IdxT first. Presumably the right-hand side
// of the dst_idx definition -- confirm.
70 index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
71 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Copy on a 3-D grid from a source described by shape (buffer 2), strides
// (buffer 3) and ndim (buffer 5). N is the number of elements along the last
// dimension that one thread handles (default 1).
// NOTE(review): several statements are missing from this listing: the call
// that receives the argument list below (presumably defining src_idx), the
// start of the first dst_idx expression, and whatever separates the
// single-element write from the N-element loop (presumably an N == 1 branch
// with an early return). The declaration line naming this kernel and the
// closing braces are also absent.
74template <
typename T,
typename U,
int N = 1,
typename IdxT =
int64_t>
76 device
const T* src [[buffer(0)]],
77 device U* dst [[buffer(1)]],
78 constant
const int* src_shape [[buffer(2)]],
79 constant
const int64_t* src_strides [[buffer(3)]],
80 constant
const int& ndim [[buffer(5)]],
81 uint3 index [[thread_position_in_grid]],
82 uint3 grid_dim [[threads_per_grid]]) {
// Tail of a call whose opening line is missing: the x coordinate is scaled by
// N, so each thread starts at the first of its N elements.
84 {N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
// Flattened (x, y, z) grid position used as the destination index for the
// single-element write that follows.
87 index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
88 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Multi-element path: dst is addressed without strides, as rows of length
// xshape (the size of the last source dimension).
91 auto xshape = src_shape[ndim - 1];
92 IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
93 auto src_xstride = src_strides[ndim - 1];
// Copy up to N elements along the last dimension; the second condition stops
// at xshape so a thread covering the end of a row does not run past it.
94 for (
int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
95 dst[dst_idx + i] =
static_cast<U
>(src[src_idx]);
// Advance the source location by the last-dimension stride.
96 src_idx += src_xstride;
// Copy on a 1-D grid where both sides take a single stride: src_stride
// (buffer 3) and dst_stride (buffer 4). The element at src_idx is converted to
// U and written to dst[dst_idx].
// NOTE(review): neither src_idx nor dst_idx is defined in the visible lines --
// presumably each is computed from index and the matching stride in statements
// missing from this listing. The declaration line naming this kernel and its
// closing brace are also absent.
100template <
typename T,
typename U,
typename IdxT =
int64_t>
102 device
const T* src [[buffer(0)]],
103 device U* dst [[buffer(1)]],
104 constant
const int64_t& src_stride [[buffer(3)]],
105 constant
const int64_t& dst_stride [[buffer(4)]],
106 uint index [[thread_position_in_grid]]) {
109 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Copy on a 2-D grid where both sides are described by stride arrays:
// src_strides (buffer 3) and dst_strides (buffer 4). The element at src_idx is
// converted to U and written to dst[dst_idx].
// NOTE(review): neither src_idx nor dst_idx is defined in the visible lines --
// presumably each is computed from index and the matching strides in statements
// missing from this listing. The declaration line naming this kernel and its
// closing brace are also absent.
112template <
typename T,
typename U,
typename IdxT =
int64_t>
114 device
const T* src [[buffer(0)]],
115 device U* dst [[buffer(1)]],
116 constant
const int64_t* src_strides [[buffer(3)]],
117 constant
const int64_t* dst_strides [[buffer(4)]],
118 uint2 index [[thread_position_in_grid]]) {
121 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Copy on a 3-D grid where both sides are described by stride arrays:
// src_strides (buffer 3) and dst_strides (buffer 4). The element at src_idx is
// converted to U and written to dst[dst_idx].
// NOTE(review): neither src_idx nor dst_idx is defined in the visible lines --
// presumably each is computed from index and the matching strides in statements
// missing from this listing. The declaration line naming this kernel and its
// closing brace are also absent.
124template <
typename T,
typename U,
typename IdxT =
int64_t>
126 device
const T* src [[buffer(0)]],
127 device U* dst [[buffer(1)]],
128 constant
const int64_t* src_strides [[buffer(3)]],
129 constant
const int64_t* dst_strides [[buffer(4)]],
130 uint3 index [[thread_position_in_grid]]) {
133 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Copy on a 3-D grid where both sides are strided: shape (buffer 2),
// src_strides (buffer 3), dst_strides (buffer 4) and ndim (buffer 5). N is the
// number of elements along the last dimension that one thread handles
// (default 1). idx.x is used as the source location and idx.y as the
// destination location.
// NOTE(review): the statement that defines idx is incomplete in this listing
// (only one line of its argument list survives), and whatever separates the
// single-element write from the N-element loop is missing (presumably an
// N == 1 branch with an early return). The declaration line naming this kernel
// and the closing braces are also absent.
136template <
typename T,
typename U,
int N = 1,
typename IdxT =
int64_t>
138 device
const T* src [[buffer(0)]],
139 device U* dst [[buffer(1)]],
140 constant
const int* src_shape [[buffer(2)]],
141 constant
const int64_t* src_strides [[buffer(3)]],
142 constant
const int64_t* dst_strides [[buffer(4)]],
143 constant
const int& ndim [[buffer(5)]],
144 uint3 index [[thread_position_in_grid]]) {
// Fragment of a call's argument list: the x coordinate is scaled by N, so each
// thread starts at the first of its N elements.
146 {N * index.x, index.y, index.z},
// Single-element write.
152 dst[idx.y] =
static_cast<U
>(src[idx.x]);
// Multi-element path: step both locations by their last-dimension strides.
155 IdxT src_xstride = src_strides[ndim - 1];
156 IdxT dst_xstride = dst_strides[ndim - 1];
157 auto xshape = src_shape[ndim - 1];
// Copy up to N elements along the last dimension; the second condition stops
// at xshape so a thread covering the end of a row does not run past it.
158 for (
int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
159 dst[idx.y] =
static_cast<U
>(src[idx.x]);
160 idx.x += src_xstride;
161 idx.y += dst_xstride;
// Copy on a 1-D grid with a single stride per side (buffers 3 and 4) plus base
// offsets for source and destination supplied at buffers 6 and 7.
// NOTE(review): neither src_idx nor dst_idx is defined in the visible lines --
// presumably each is computed from index and the matching stride in statements
// missing from this listing. The declaration line naming this kernel and its
// closing brace are also absent.
165template <
typename T,
typename U,
typename IdxT =
int64_t>
167 device
const T* src [[buffer(0)]],
168 device U* dst [[buffer(1)]],
169 constant
const int64_t& src_stride [[buffer(3)]],
170 constant
const int64_t& dst_stride [[buffer(4)]],
171 constant
const int64_t& src_offset [[buffer(6)]],
172 constant
const int64_t& dst_offset [[buffer(7)]],
173 uint index [[thread_position_in_grid]]) {
// Offsets are added to the per-thread locations. The T -> U conversion here is
// implicit: there is no static_cast<U>, unlike the other copies in this file.
176 dst[dst_idx + dst_offset] = src[src_idx + src_offset];
// Copy on a 2-D grid with stride arrays per side (buffers 3 and 4) plus base
// offsets for source and destination supplied at buffers 6 and 7.
// NOTE(review): neither src_idx nor dst_idx is defined in the visible lines --
// presumably each is computed from index and the matching strides in statements
// missing from this listing. The declaration line naming this kernel and its
// closing brace are also absent.
179template <
typename T,
typename U,
typename IdxT =
int64_t>
181 device
const T* src [[buffer(0)]],
182 device U* dst [[buffer(1)]],
183 constant
const int64_t* src_strides [[buffer(3)]],
184 constant
const int64_t* dst_strides [[buffer(4)]],
185 constant
const int64_t& src_offset [[buffer(6)]],
186 constant
const int64_t& dst_offset [[buffer(7)]],
187 uint2 index [[thread_position_in_grid]]) {
// Offsets are added to the per-thread locations. The T -> U conversion here is
// implicit: there is no static_cast<U>, unlike the other copies in this file.
190 dst[dst_idx + dst_offset] = src[src_idx + src_offset];
// Copy on a 3-D grid with stride arrays per side (buffers 3 and 4) plus base
// offsets for source and destination supplied at buffers 6 and 7.
// NOTE(review): neither src_idx nor dst_idx is defined in the visible lines --
// presumably each is computed from index and the matching strides in statements
// missing from this listing. The declaration line naming this kernel and its
// closing brace are also absent.
193template <
typename T,
typename U,
typename IdxT =
int64_t>
195 device
const T* src [[buffer(0)]],
196 device U* dst [[buffer(1)]],
197 constant
const int64_t* src_strides [[buffer(3)]],
198 constant
const int64_t* dst_strides [[buffer(4)]],
199 constant
const int64_t& src_offset [[buffer(6)]],
200 constant
const int64_t& dst_offset [[buffer(7)]],
201 uint3 index [[thread_position_in_grid]]) {
// Offsets are added to the per-thread locations. The T -> U conversion here is
// implicit: there is no static_cast<U>, unlike the other copies in this file.
204 dst[dst_idx + dst_offset] = src[src_idx + src_offset];
// Copy on a 3-D grid where both sides are strided: shape (buffer 2),
// src_strides (buffer 3), dst_strides (buffer 4), ndim (buffer 5), with source
// and destination offsets supplied at buffers 6 and 7. N is the number of
// elements along the last dimension that one thread handles (default 1).
// idx.x is used as the source location and idx.y as the destination location.
// NOTE(review): src_offset and dst_offset are not used in any visible line --
// presumably they are applied where idx is computed, in statements missing
// from this listing; confirm. Whatever separates the single-element write from
// the N-element loop is also missing (presumably an N == 1 branch with an
// early return), as are the declaration line naming this kernel and the
// closing braces.
207template <
typename T,
typename U,
int N = 1,
typename IdxT =
int64_t>
209 device
const T* src [[buffer(0)]],
210 device U* dst [[buffer(1)]],
211 constant
const int* src_shape [[buffer(2)]],
212 constant
const int64_t* src_strides [[buffer(3)]],
213 constant
const int64_t* dst_strides [[buffer(4)]],
214 constant
const int& ndim [[buffer(5)]],
215 constant
const int64_t& src_offset [[buffer(6)]],
216 constant
const int64_t& dst_offset [[buffer(7)]],
217 uint3 index [[thread_position_in_grid]]) {
// Fragment of a call's argument list: the x coordinate is scaled by N, so each
// thread starts at the first of its N elements.
221 {N * index.x, index.y, index.z},
// Single-element write; the T -> U conversion is implicit (no static_cast<U>).
227 dst[idx.y] = src[idx.x];
// Multi-element path: step both locations by their last-dimension strides.
230 IdxT src_xstride = src_strides[ndim - 1];
231 IdxT dst_xstride = dst_strides[ndim - 1];
232 auto xshape = src_shape[ndim - 1];
// Copy up to N elements along the last dimension; the second condition stops
// at xshape so a thread covering the end of a row does not run past it.
233 for (
int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
234 dst[idx.y] = src[idx.x];
235 idx.x += src_xstride;
236 idx.y += dst_xstride;