// Scalar-broadcast copy over a 1-D grid: every thread writes the single
// source element src[0], cast to U, into dst[index]. Here index is the
// thread's position in the grid. src (buffer 0) and dst (buffer 1) are in
// the device address space.
// NOTE(review): this listing looks like a lossy extraction. Source line
// numbers are fused into the text (the leading "3", "5", ...). The kernel
// declaration line (presumably `[[kernel]] void <name>(`) between the
// template header and the parameter list is missing, and so is the
// closing brace. It does not compile as shown.
// NOTE(review): there is no bounds check, so the host presumably
// dispatches exactly one thread per dst element. Confirm.
3template <
typename T,
typename U>
5 device
const T* src [[buffer(0)]],
6 device U* dst [[buffer(1)]],
7 uint index [[thread_position_in_grid]]) {
8 dst[index] =
static_cast<U
>(src[0]);
// Contiguous element-wise copy over a 1-D grid: each thread reads
// src[index], casts it from T to U, and stores it at the same position
// dst[index]. src (buffer 0) and dst (buffer 1) are in the device
// address space.
// NOTE(review): lossy extraction. Line numbers are fused into the text,
// and the kernel declaration line plus the closing brace are missing, so
// this does not compile as shown.
// NOTE(review): no bounds check; the grid is presumably sized to exactly
// the element count. Confirm against the host dispatch.
11template <
typename T,
typename U>
13 device
const T* src [[buffer(0)]],
14 device U* dst [[buffer(1)]],
15 uint index [[thread_position_in_grid]]) {
16 dst[index] =
static_cast<U
>(src[index]);
// Scalar-broadcast copy over a 2-D grid: every thread writes src[0], cast
// to U, into dst at its flattened grid position.
// NOTE(review): lossy extraction. Line numbers are fused into the text,
// and the kernel declaration line plus the closing brace are missing, so
// this does not compile as shown.
19template <
typename T,
typename U>
21 device
const T* src [[buffer(0)]],
22 device U* dst [[buffer(1)]],
23 uint2 index [[thread_position_in_grid]],
24 uint2 grid_dim [[threads_per_grid]]) {
// Flatten (x, y) row-major: x varies fastest and each row spans
// grid_dim.x elements. index.y is widened to size_t so the product is not
// computed in 32-bit uint, presumably to avoid overflow on large grids.
25 size_t offset = index.x + grid_dim.x * size_t(index.y);
26 dst[offset] =
static_cast<U
>(src[0]);
// Contiguous element-wise copy over a 2-D grid: each thread flattens its
// (x, y) position to a single offset. It then copies src[offset], cast to
// U, into dst[offset].
// NOTE(review): lossy extraction. Line numbers are fused into the text,
// and the kernel declaration line plus the closing brace are missing, so
// this does not compile as shown.
29template <
typename T,
typename U>
31 device
const T* src [[buffer(0)]],
32 device U* dst [[buffer(1)]],
33 uint2 index [[thread_position_in_grid]],
34 uint2 grid_dim [[threads_per_grid]]) {
// Row-major flatten. index.y is widened to size_t so the multiply is not
// done in 32-bit uint.
35 size_t offset = index.x + grid_dim.x * size_t(index.y);
36 dst[offset] =
static_cast<U
>(src[offset]);
// 1-D copy from a strided source into a contiguous destination: each
// thread writes dst[index] from src[src_idx], cast to U. The source stride
// is a single int64_t read from the constant address space (buffer 3).
// NOTE(review): lossy extraction. Line numbers are fused into the text,
// and the kernel declaration line plus the closing brace are missing.
39template <
typename T,
typename U>
41 device
const T* src [[buffer(0)]],
42 device U* dst [[buffer(1)]],
43 constant
const int64_t& src_stride [[buffer(3)]],
44 uint index [[thread_position_in_grid]]) {
// NOTE(review): src_idx is used below but never defined in this listing.
// The line computing it from index and src_stride appears to have been
// dropped. It is presumably index * src_stride; confirm against the
// original source.
46 dst[index] =
static_cast<U
>(src[src_idx]);
// 2-D copy from a strided source into a contiguous destination. The source
// strides are an int64_t array in the constant address space (buffer 3).
// The destination offset is the row-major flattening of the thread's
// (x, y) grid position.
// NOTE(review): lossy extraction. Line numbers are fused into the text,
// and the kernel declaration line plus the closing brace are missing.
49template <
typename T,
typename U>
51 device
const T* src [[buffer(0)]],
52 device U* dst [[buffer(1)]],
53 constant
const int64_t* src_strides [[buffer(3)]],
54 uint2 index [[thread_position_in_grid]],
55 uint2 grid_dim [[threads_per_grid]]) {
// NOTE(review): the line defining src_idx from index and src_strides is
// missing from this listing. Its exact form cannot be recovered here.
// Casting grid_dim.x to int64_t makes the multiply 64-bit.
57 int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
58 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// 3-D copy from a strided source (int64_t strides in buffer 3) into a
// contiguous destination. The destination offset flattens (x, y, z) with
// x fastest, then y, then z, using 64-bit arithmetic via the int64_t
// casts.
// NOTE(review): lossy extraction. Line numbers are fused into the text,
// and the kernel declaration line plus the closing brace are missing.
61template <
typename T,
typename U>
63 device
const T* src [[buffer(0)]],
64 device U* dst [[buffer(1)]],
65 constant
const int64_t* src_strides [[buffer(3)]],
66 uint3 index [[thread_position_in_grid]],
67 uint3 grid_dim [[threads_per_grid]]) {
// NOTE(review): two pieces are missing from this listing. The first is the
// line defining src_idx from index and src_strides. The second is the
// declaration head (presumably `int64_t dst_idx =`) that the next line
// continues.
70 index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
71 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Strided-source copy for a compile-time rank DIM (template parameter).
// The source offset comes from src_shape (buffer 2) and src_strides
// (buffer 3) through elem_to_loc_nd<DIM>. The destination is written
// contiguously at the flattened (x, y, z) grid position.
// NOTE(review): lossy extraction. Line numbers are fused into the text,
// and the kernel declaration line plus the closing brace are missing.
74template <
typename T,
typename U,
int DIM>
76 device
const T* src [[buffer(0)]],
77 device U* dst [[buffer(1)]],
78 constant
const int* src_shape [[buffer(2)]],
79 constant
const int64_t* src_strides [[buffer(3)]],
80 uint3 index [[thread_position_in_grid]],
81 uint3 grid_dim [[threads_per_grid]]) {
// elem_to_loc_nd is defined elsewhere. It presumably maps the thread
// position to a strided source offset; confirm in its definition.
82 auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
// NOTE(review): the declaration head (presumably `int64_t dst_idx =`)
// that this line continues is missing from the listing.
84 index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
85 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Strided-source copy for a rank known only at run time. ndim is read from
// the constant address space (buffer 5) and passed, with src_shape
// (buffer 2) and src_strides (buffer 3), to elem_to_loc. The destination
// is written contiguously at the flattened (x, y, z) grid position.
// NOTE(review): lossy extraction. Line numbers are fused into the text,
// and the kernel declaration line plus the closing brace are missing.
88template <
typename T,
typename U>
90 device
const T* src [[buffer(0)]],
91 device U* dst [[buffer(1)]],
92 constant
const int* src_shape [[buffer(2)]],
93 constant
const int64_t* src_strides [[buffer(3)]],
94 constant
const int& ndim [[buffer(5)]],
95 uint3 index [[thread_position_in_grid]],
96 uint3 grid_dim [[threads_per_grid]]) {
// elem_to_loc is defined elsewhere. How it maps a 3-D thread position onto
// an ndim-dimensional shape is not visible here.
97 auto src_idx =
elem_to_loc(index, src_shape, src_strides, ndim);
// NOTE(review): the declaration head (presumably `int64_t dst_idx =`)
// that this line continues is missing from the listing.
99 index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
100 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// 1-D copy where both source and destination are strided. src_stride
// (buffer 3) and dst_stride (buffer 4) are single int64_t values in the
// constant address space.
// NOTE(review): lossy extraction. Line numbers are fused into the text,
// and the kernel declaration line plus the closing brace are missing.
103template <
typename T,
typename U>
105 device
const T* src [[buffer(0)]],
106 device U* dst [[buffer(1)]],
107 constant
const int64_t& src_stride [[buffer(3)]],
108 constant
const int64_t& dst_stride [[buffer(4)]],
109 uint index [[thread_position_in_grid]]) {
// NOTE(review): the lines defining src_idx and dst_idx are missing from
// this listing. They are presumably index * src_stride and
// index * dst_stride; confirm against the original source.
112 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// 2-D copy where both source and destination are strided. src_strides
// (buffer 3) and dst_strides (buffer 4) are int64_t arrays in the constant
// address space. No grid size is taken because neither offset is a
// contiguous flattening.
// NOTE(review): lossy extraction. Line numbers are fused into the text,
// and the kernel declaration line plus the closing brace are missing.
115template <
typename T,
typename U>
117 device
const T* src [[buffer(0)]],
118 device U* dst [[buffer(1)]],
119 constant
const int64_t* src_strides [[buffer(3)]],
120 constant
const int64_t* dst_strides [[buffer(4)]],
121 uint2 index [[thread_position_in_grid]]) {
// NOTE(review): the lines computing src_idx and dst_idx from index and the
// respective stride arrays are missing from this listing.
124 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// 3-D copy where both source and destination are strided. src_strides
// (buffer 3) and dst_strides (buffer 4) are int64_t arrays in the constant
// address space.
// NOTE(review): lossy extraction. Line numbers are fused into the text,
// and the kernel declaration line plus the closing brace are missing.
127template <
typename T,
typename U>
129 device
const T* src [[buffer(0)]],
130 device U* dst [[buffer(1)]],
131 constant
const int64_t* src_strides [[buffer(3)]],
132 constant
const int64_t* dst_strides [[buffer(4)]],
133 uint3 index [[thread_position_in_grid]]) {
// NOTE(review): the lines computing src_idx and dst_idx from index and the
// respective stride arrays are missing from this listing.
136 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Strided-to-strided copy for a compile-time rank DIM (template
// parameter). Both offsets are computed by elem_to_loc_nd<DIM> from the
// same src_shape (buffer 2), so source and destination share one logical
// shape. They differ only in their strides: src_strides (buffer 3) and
// dst_strides (buffer 4).
// NOTE(review): lossy extraction. Line numbers are fused into the text,
// and the kernel declaration line plus the closing brace are missing.
139template <
typename T,
typename U,
int DIM>
141 device
const T* src [[buffer(0)]],
142 device U* dst [[buffer(1)]],
143 constant
const int* src_shape [[buffer(2)]],
144 constant
const int64_t* src_strides [[buffer(3)]],
145 constant
const int64_t* dst_strides [[buffer(4)]],
146 uint3 index [[thread_position_in_grid]]) {
147 auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
// Same shape as the source, but destination strides.
148 auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
149 dst[dst_idx] =
static_cast<U
>(src[src_idx]);
// Strided-to-strided copy for a rank known only at run time (ndim,
// buffer 5). Both offsets are computed by elem_to_loc from the same
// src_shape (buffer 2), so source and destination share one logical shape.
// They differ only in their strides: src_strides (buffer 3) and
// dst_strides (buffer 4).
// NOTE(review): lossy extraction. Line numbers are fused into the text,
// and the kernel declaration line is missing. The closing brace is not
// visible and may lie past the end of this listing.
152template <
typename T,
typename U>
154 device
const T* src [[buffer(0)]],
155 device U* dst [[buffer(1)]],
156 constant
const int* src_shape [[buffer(2)]],
157 constant
const int64_t* src_strides [[buffer(3)]],
158 constant
const int64_t* dst_strides [[buffer(4)]],
159 constant
const int& ndim [[buffer(5)]],
160 uint3 index [[thread_position_in_grid]]) {
161 auto src_idx =
elem_to_loc(index, src_shape, src_strides, ndim);
// Same shape as the source, but destination strides.
162 auto dst_idx =
elem_to_loc(index, src_shape, dst_strides, ndim);
163 dst[dst_idx] =
static_cast<U
>(src[src_idx]);