3template <
typename T,
typename U>
5 device
const T* src [[buffer(0)]],
6 device U* dst [[buffer(1)]],
7 uint index [[thread_position_in_grid]]) {
8 dst[index] =
static_cast<U
>(src[0]);
11template <
typename T,
typename U>
13 device
const T* src [[buffer(0)]],
14 device U* dst [[buffer(1)]],
15 uint index [[thread_position_in_grid]]) {
16 dst[index] =
static_cast<U
>(src[index]);
19template <
typename T,
typename U>
21 device
const T* src [[buffer(0)]],
22 device U* dst [[buffer(1)]],
23 uint2 index [[thread_position_in_grid]],
24 uint2 grid_dim [[threads_per_grid]]) {
25 size_t offset = index.x + grid_dim.x * size_t(index.y);
26 dst[offset] =
static_cast<U
>(src[0]);
29template <
typename T,
typename U>
31 device
const T* src [[buffer(0)]],
32 device U* dst [[buffer(1)]],
33 uint2 index [[thread_position_in_grid]],
34 uint2 grid_dim [[threads_per_grid]]) {
35 size_t offset = index.x + grid_dim.x * size_t(index.y);
36 dst[offset] =
static_cast<U
>(src[offset]);
39template <
typename T,
typename U>
41 device
const T* src [[buffer(0)]],
42 device U* dst [[buffer(1)]],
43 constant
const int64_t& src_stride [[buffer(3)]],
44 uint index [[thread_position_in_grid]]) {
46 dst[index] =
static_cast<U
>(src[src_idx]);
49template <
typename T,
typename U>
51 device
const T* src [[buffer(0)]],
52 device U* dst [[buffer(1)]],
53 constant
const int64_t* src_strides [[buffer(3)]],
54 uint2 index [[thread_position_in_grid]],
55 uint2 grid_dim [[threads_per_grid]]) {
57 int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
58 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
61template <
typename T,
typename U>
63 device
const T* src [[buffer(0)]],
64 device U* dst [[buffer(1)]],
65 constant
const int64_t* src_strides [[buffer(3)]],
66 uint3 index [[thread_position_in_grid]],
67 uint3 grid_dim [[threads_per_grid]]) {
70 index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
71 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
74template <
typename T,
typename U,
int N = 1>
76 device
const T* src [[buffer(0)]],
77 device U* dst [[buffer(1)]],
78 constant
const int* src_shape [[buffer(2)]],
79 constant
const int64_t* src_strides [[buffer(3)]],
80 constant
const int& ndim [[buffer(5)]],
81 uint3 index [[thread_position_in_grid]],
82 uint3 grid_dim [[threads_per_grid]]) {
84 {N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
87 index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z);
88 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
91 auto xshape = src_shape[ndim - 1];
93 N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
94 auto src_xstride = src_strides[ndim - 1];
95 for (
int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
96 dst[dst_idx + i] =
static_cast<U
>(src[src_idx]);
97 src_idx += src_xstride;
101template <
typename T,
typename U>
103 device
const T* src [[buffer(0)]],
104 device U* dst [[buffer(1)]],
105 constant
const int64_t& src_stride [[buffer(3)]],
106 constant
const int64_t& dst_stride [[buffer(4)]],
107 uint index [[thread_position_in_grid]]) {
110 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
113template <
typename T,
typename U>
115 device
const T* src [[buffer(0)]],
116 device U* dst [[buffer(1)]],
117 constant
const int64_t* src_strides [[buffer(3)]],
118 constant
const int64_t* dst_strides [[buffer(4)]],
119 uint2 index [[thread_position_in_grid]]) {
122 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
125template <
typename T,
typename U>
127 device
const T* src [[buffer(0)]],
128 device U* dst [[buffer(1)]],
129 constant
const int64_t* src_strides [[buffer(3)]],
130 constant
const int64_t* dst_strides [[buffer(4)]],
131 uint3 index [[thread_position_in_grid]]) {
134 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
137template <
typename T,
typename U,
int N = 1>
139 device
const T* src [[buffer(0)]],
140 device U* dst [[buffer(1)]],
141 constant
const int* src_shape [[buffer(2)]],
142 constant
const int64_t* src_strides [[buffer(3)]],
143 constant
const int64_t* dst_strides [[buffer(4)]],
144 constant
const int& ndim [[buffer(5)]],
145 uint3 index [[thread_position_in_grid]]) {
147 {N * index.x, index.y, index.z},
153 dst[idx.y] =
static_cast<U
>(src[idx.x]);
156 auto src_xstride = src_strides[ndim - 1];
157 auto dst_xstride = dst_strides[ndim - 1];
158 auto xshape = src_shape[ndim - 1];
159 for (
int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
160 dst[idx.y] =
static_cast<U
>(src[idx.x]);
161 idx.x += src_xstride;
162 idx.y += dst_xstride;