// Cast-and-store kernel: every thread reads the single element src[0],
// converts it to U, and writes it to dst[index].
//
//   src   - input buffer bound at buffer(0); only element 0 is read.
//   dst   - output buffer bound at buffer(1); written at `index`.
//   index - this thread's 1-D position in the dispatch grid.
//
// There is no bounds check on `index`; this assumes the dispatch grid is
// sized to dst's element count -- TODO confirm at the call site.
//
// NOTE(review): the leading digits fused onto some lines (`3template`,
// `5 device`, ...) look like line-number residue from extraction. The
// declarator line (kernel name / return type) and the closing brace are not
// present in this text -- restore them from the original file.
3template <
typename T,
typename U>
5 device
const T* src [[buffer(0)]],
6 device U* dst [[buffer(1)]],
7 uint index [[thread_position_in_grid]]) {
// Same source element for all threads; only the destination slot varies.
8 dst[index] =
static_cast<U
>(src[0]);
// Cast-and-store kernel: each thread converts src[index] to U and writes it
// to dst[index], using the same 1-D thread position for both buffers.
//
//   src   - input buffer bound at buffer(0).
//   dst   - output buffer bound at buffer(1).
//   index - this thread's 1-D position in the dispatch grid.
//
// There is no bounds check on `index`; this assumes the dispatch grid does
// not exceed the element count of either buffer -- TODO confirm.
//
// NOTE(review): leading digits on some lines (`11template`, `13 device`, ...)
// look like line-number residue from extraction; the declarator line and the
// closing brace are not present in this text -- restore before compiling.
11template <
typename T,
typename U>
13 device
const T* src [[buffer(0)]],
14 device U* dst [[buffer(1)]],
15 uint index [[thread_position_in_grid]]) {
16 dst[index] =
static_cast<U
>(src[index]);
// Cast-and-store kernel with a single source stride: each thread writes
// static_cast<U>(src[src_idx]) to dst[index].
//
//   src        - input buffer bound at buffer(0).
//   dst        - output buffer bound at buffer(1); written at `index`.
//   src_stride - 64-bit stride passed by reference at buffer(3).
//   index      - this thread's 1-D position in the dispatch grid.
//
// NOTE(review): `src_idx` is used below but never defined in the visible
// text. Presumably it is derived from `index` and `src_stride` on a line
// that was lost in extraction (the residue numbering jumps from 24 to 26)
// -- confirm against the original file. The declarator line and the closing
// brace are missing as well, and the leading digits on several lines look
// like fused line numbers.
19template <
typename T,
typename U>
21 device
const T* src [[buffer(0)]],
22 device U* dst [[buffer(1)]],
23 constant
const int64_t& src_stride [[buffer(3)]],
24 uint index [[thread_position_in_grid]]) {
26 dst[index] =
static_cast<U
>(src[src_idx]);
// Cast-and-store kernel dispatched on a 2-D grid: each thread writes
// static_cast<U>(src[src_idx]) to a linearized destination offset.
//
//   src         - input buffer bound at buffer(0).
//   dst         - output buffer bound at buffer(1).
//   src_strides - array of 64-bit source strides at buffer(3).
//   index       - this thread's 2-D position in the dispatch grid.
//   grid_dim    - total threads in the grid along each axis.
//
// NOTE(review): `src_idx` is used below but not defined in the visible text;
// presumably it is computed from `index` and `src_strides` on a line lost in
// extraction (residue numbering jumps from 35 to 37) -- confirm. The
// declarator line and the closing brace are missing too, and the leading
// digits on several lines look like fused line numbers.
29template <
typename T,
typename U>
31 device
const T* src [[buffer(0)]],
32 device U* dst [[buffer(1)]],
33 constant
const int64_t* src_strides [[buffer(3)]],
34 uint2 index [[thread_position_in_grid]],
35 uint2 grid_dim [[threads_per_grid]]) {
// Linear destination offset: x + grid_width * y. The int64_t cast widens the
// multiplication to 64 bits before it is performed.
37 int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
38 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Cast-and-store kernel dispatched on a 3-D grid: each thread writes
// static_cast<U>(src[src_idx]) to dst[dst_idx].
//
//   src         - input buffer bound at buffer(0).
//   dst         - output buffer bound at buffer(1).
//   src_strides - array of 64-bit source strides at buffer(3).
//   index       - this thread's 3-D position in the dispatch grid.
//   grid_dim    - total threads in the grid along each axis.
//
// NOTE(review): the visible text lacks (a) the definition of `src_idx` and
// (b) the left-hand side of the expression on the line marked `50`, which
// presumably initializes `dst_idx` (residue numbering jumps from 47 to 50)
// -- confirm against the original file. The declarator line and the closing
// brace are missing as well, and the leading digits look like fused line
// numbers.
41template <
typename T,
typename U>
43 device
const T* src [[buffer(0)]],
44 device U* dst [[buffer(1)]],
45 constant
const int64_t* src_strides [[buffer(3)]],
46 uint3 index [[thread_position_in_grid]],
47 uint3 grid_dim [[threads_per_grid]]) {
// Linearizes the 3-D thread position as x + grid_x * (y + grid_y * z), with
// both multiplications widened to 64 bits by the int64_t casts.
50 index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
51 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Cast-and-store kernel with a compile-time dimension count DIM: the source
// offset comes from elem_to_loc_nd<DIM>, and the result is written to
// dst[dst_idx].
//
//   src         - input buffer bound at buffer(0).
//   dst         - output buffer bound at buffer(1).
//   src_shape   - array of int extents at buffer(2).
//   src_strides - array of 64-bit source strides at buffer(3).
//   index       - this thread's 3-D position in the dispatch grid.
//   grid_dim    - total threads in the grid along each axis.
//
// NOTE(review): the left-hand side of the expression on the line marked `64`
// is not in the visible text; presumably it initializes `dst_idx` (residue
// numbering jumps from 62 to 64) -- confirm. The declarator line and the
// closing brace are missing as well, and the leading digits look like fused
// line numbers.
54template <
typename T,
typename U,
int DIM>
56 device
const T* src [[buffer(0)]],
57 device U* dst [[buffer(1)]],
58 constant
const int* src_shape [[buffer(2)]],
59 constant
const int64_t* src_strides [[buffer(3)]],
60 uint3 index [[thread_position_in_grid]],
61 uint3 grid_dim [[threads_per_grid]]) {
// elem_to_loc_nd is defined elsewhere; presumably it maps the thread position
// to a strided element offset using shape and strides -- verify its contract.
62 auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
// Linearizes the 3-D thread position as x + grid_x * (y + grid_y * z), with
// both multiplications widened to 64 bits by the int64_t casts.
64 index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
65 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Cast-and-store kernel with a runtime dimension count: the source offset
// comes from elem_to_loc(..., ndim), and the result is written to
// dst[dst_idx].
//
//   src         - input buffer bound at buffer(0).
//   dst         - output buffer bound at buffer(1).
//   src_shape   - array of int extents at buffer(2).
//   src_strides - array of 64-bit source strides at buffer(3).
//   ndim        - dimension count passed by reference at buffer(5).
//   index       - this thread's 3-D position in the dispatch grid.
//   grid_dim    - total threads in the grid along each axis.
//
// NOTE(review): the left-hand side of the expression on the line marked `79`
// is not in the visible text; presumably it initializes `dst_idx` (residue
// numbering jumps from 77 to 79) -- confirm. The declarator line and the
// closing brace are missing as well, and the leading digits look like fused
// line numbers.
68template <
typename T,
typename U>
70 device
const T* src [[buffer(0)]],
71 device U* dst [[buffer(1)]],
72 constant
const int* src_shape [[buffer(2)]],
73 constant
const int64_t* src_strides [[buffer(3)]],
74 constant
const int& ndim [[buffer(5)]],
75 uint3 index [[thread_position_in_grid]],
76 uint3 grid_dim [[threads_per_grid]]) {
// elem_to_loc is defined elsewhere; presumably it maps the thread position to
// a strided element offset over `ndim` dimensions -- verify its contract.
77 auto src_idx =
elem_to_loc(index, src_shape, src_strides, ndim);
// Linearizes the 3-D thread position as x + grid_x * (y + grid_y * z), with
// both multiplications widened to 64 bits by the int64_t casts.
79 index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
80 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Cast-and-store kernel taking one stride for the source and one for the
// destination: each thread writes static_cast<U>(src[src_idx]) to
// dst[dst_idx].
//
//   src        - input buffer bound at buffer(0).
//   dst        - output buffer bound at buffer(1).
//   src_stride - 64-bit source stride passed by reference at buffer(3).
//   dst_stride - 64-bit destination stride passed by reference at buffer(4).
//   index      - this thread's 1-D position in the dispatch grid.
//
// NOTE(review): neither `src_idx` nor `dst_idx` is defined in the visible
// text. Presumably they are derived from `index` and the respective strides
// on lines lost in extraction (residue numbering jumps from 89 to 92) --
// confirm against the original file. The declarator line and the closing
// brace are missing as well, and the leading digits look like fused line
// numbers.
83template <
typename T,
typename U>
85 device
const T* src [[buffer(0)]],
86 device U* dst [[buffer(1)]],
87 constant
const int64_t& src_stride [[buffer(3)]],
88 constant
const int64_t& dst_stride [[buffer(4)]],
89 uint index [[thread_position_in_grid]]) {
92 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Cast-and-store kernel dispatched on a 2-D grid with separate stride arrays
// for source and destination: each thread writes
// static_cast<U>(src[src_idx]) to dst[dst_idx].
//
//   src         - input buffer bound at buffer(0).
//   dst         - output buffer bound at buffer(1).
//   src_strides - array of 64-bit source strides at buffer(3).
//   dst_strides - array of 64-bit destination strides at buffer(4).
//   index       - this thread's 2-D position in the dispatch grid.
//
// NOTE(review): neither `src_idx` nor `dst_idx` is defined in the visible
// text. Presumably they are derived from `index` and the respective stride
// arrays on lines lost in extraction (residue numbering jumps from 101 to
// 104) -- confirm. The declarator line and the closing brace are missing as
// well, and the leading digits look like fused line numbers.
95template <
typename T,
typename U>
97 device
const T* src [[buffer(0)]],
98 device U* dst [[buffer(1)]],
99 constant
const int64_t* src_strides [[buffer(3)]],
100 constant
const int64_t* dst_strides [[buffer(4)]],
101 uint2 index [[thread_position_in_grid]]) {
104 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Cast-and-store kernel dispatched on a 3-D grid with separate stride arrays
// for source and destination: each thread writes
// static_cast<U>(src[src_idx]) to dst[dst_idx].
//
//   src         - input buffer bound at buffer(0).
//   dst         - output buffer bound at buffer(1).
//   src_strides - array of 64-bit source strides at buffer(3).
//   dst_strides - array of 64-bit destination strides at buffer(4).
//   index       - this thread's 3-D position in the dispatch grid.
//
// NOTE(review): neither `src_idx` nor `dst_idx` is defined in the visible
// text. Presumably they are derived from `index` and the respective stride
// arrays on lines lost in extraction (residue numbering jumps from 113 to
// 116) -- confirm. The declarator line and the closing brace are missing as
// well, and the leading digits look like fused line numbers.
107template <
typename T,
typename U>
109 device
const T* src [[buffer(0)]],
110 device U* dst [[buffer(1)]],
111 constant
const int64_t* src_strides [[buffer(3)]],
112 constant
const int64_t* dst_strides [[buffer(4)]],
113 uint3 index [[thread_position_in_grid]]) {
116 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Cast-and-store kernel with a compile-time dimension count DIM and separate
// strides for source and destination. Both offsets are produced by
// elem_to_loc_nd<DIM> from the same thread position and the same shape
// array; only the stride array differs.
//
//   src         - input buffer bound at buffer(0).
//   dst         - output buffer bound at buffer(1).
//   src_shape   - array of int extents at buffer(2); used for both offsets.
//   src_strides - array of 64-bit source strides at buffer(3).
//   dst_strides - array of 64-bit destination strides at buffer(4).
//   index       - this thread's 3-D position in the dispatch grid.
//
// NOTE(review): the declarator line (kernel name / return type) and the
// closing brace are not present in this text, and the leading digits on
// several lines look like fused line numbers -- restore from the original
// file before compiling.
119template <
typename T,
typename U,
int DIM>
121 device
const T* src [[buffer(0)]],
122 device U* dst [[buffer(1)]],
123 constant
const int* src_shape [[buffer(2)]],
124 constant
const int64_t* src_strides [[buffer(3)]],
125 constant
const int64_t* dst_strides [[buffer(4)]],
126 uint3 index [[thread_position_in_grid]]) {
// elem_to_loc_nd is defined elsewhere; presumably it maps the thread position
// to a strided element offset -- verify its contract.
127 auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
// The destination offset reuses src_shape, paired with the destination strides.
128 auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
129 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Cast-and-store kernel with a runtime dimension count and separate strides
// for source and destination. Both offsets are produced by elem_to_loc from
// the same thread position, shape array and ndim; only the stride array
// differs.
//
//   src         - input buffer bound at buffer(0).
//   dst         - output buffer bound at buffer(1).
//   src_shape   - array of int extents at buffer(2); used for both offsets.
//   src_strides - array of 64-bit source strides at buffer(3).
//   dst_strides - array of 64-bit destination strides at buffer(4).
//   ndim        - dimension count passed by reference at buffer(5).
//   index       - this thread's 3-D position in the dispatch grid.
//
// NOTE(review): the declarator line (kernel name / return type) and the
// closing brace are not present in this text, and the leading digits on
// several lines look like fused line numbers -- restore from the original
// file before compiling.
132template <
typename T,
typename U>
134 device
const T* src [[buffer(0)]],
135 device U* dst [[buffer(1)]],
136 constant
const int* src_shape [[buffer(2)]],
137 constant
const int64_t* src_strides [[buffer(3)]],
138 constant
const int64_t* dst_strides [[buffer(4)]],
139 constant
const int& ndim [[buffer(5)]],
140 uint3 index [[thread_position_in_grid]]) {
// elem_to_loc is defined elsewhere; presumably it maps the thread position to
// a strided element offset over `ndim` dimensions -- verify its contract.
141 auto src_idx =
elem_to_loc(index, src_shape, src_strides, ndim);
// The destination offset reuses src_shape, paired with the destination strides.
142 auto dst_idx =
elem_to_loc(index, src_shape, dst_strides, ndim);
143 dst[dst_idx] =
static_cast<U
>(src[src_idx]);