3template <
typename T, 
typename U>
 
    5    device 
const T* src [[buffer(0)]],
 
    6    device U* dst [[buffer(1)]],
 
    7    uint index [[thread_position_in_grid]]) {
 
    8  dst[index] = 
static_cast<U
>(src[0]);
 
 
   11template <
typename T, 
typename U>
 
   13    device 
const T* src [[buffer(0)]],
 
   14    device U* dst [[buffer(1)]],
 
   15    uint index [[thread_position_in_grid]]) {
 
   16  dst[index] = 
static_cast<U
>(src[index]);
 
 
   19template <
typename T, 
typename U>
 
   21    device 
const T* src [[buffer(0)]],
 
   22    device U* dst [[buffer(1)]],
 
   23    uint2 index [[thread_position_in_grid]],
 
   24    uint2 grid_dim [[threads_per_grid]]) {
 
   25  auto offset = index.x + grid_dim.x * int64_t(index.y);
 
   26  dst[offset] = 
static_cast<U
>(src[0]);
 
 
   29template <
typename T, 
typename U>
 
   31    device 
const T* src [[buffer(0)]],
 
   32    device U* dst [[buffer(1)]],
 
   33    uint2 index [[thread_position_in_grid]],
 
   34    uint2 grid_dim [[threads_per_grid]]) {
 
   35  auto offset = index.x + grid_dim.x * int64_t(index.y);
 
   36  dst[offset] = 
static_cast<U
>(src[offset]);
 
 
   39template <
typename T, 
typename U, 
typename IdxT = 
int64_t>
 
   41    device 
const T* src [[buffer(0)]],
 
   42    device U* dst [[buffer(1)]],
 
   43    constant 
const int64_t& src_stride [[buffer(3)]],
 
   44    uint index [[thread_position_in_grid]]) {
 
   46  dst[index] = 
static_cast<U
>(src[src_idx]);
 
 
   49template <
typename T, 
typename U, 
typename IdxT = 
int64_t>
 
   51    device 
const T* src [[buffer(0)]],
 
   52    device U* dst [[buffer(1)]],
 
   53    constant 
const int64_t* src_strides [[buffer(3)]],
 
   54    uint2 index [[thread_position_in_grid]],
 
   55    uint2 grid_dim [[threads_per_grid]]) {
 
   57  IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
 
   58  dst[dst_idx] = 
static_cast<U
>(src[src_idx]);
 
 
   61template <
typename T, 
typename U, 
typename IdxT = 
int64_t>
 
   63    device 
const T* src [[buffer(0)]],
 
   64    device U* dst [[buffer(1)]],
 
   65    constant 
const int64_t* src_strides [[buffer(3)]],
 
   66    uint3 index [[thread_position_in_grid]],
 
   67    uint3 grid_dim [[threads_per_grid]]) {
 
   70      index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
 
   71  dst[dst_idx] = 
static_cast<U
>(src[src_idx]);
 
 
   74template <
typename T, 
typename U, 
int N = 1, 
typename IdxT = 
int64_t>
 
   76    device 
const T* src [[buffer(0)]],
 
   77    device U* dst [[buffer(1)]],
 
   78    constant 
const int* src_shape [[buffer(2)]],
 
   79    constant 
const int64_t* src_strides [[buffer(3)]],
 
   80    constant 
const int& ndim [[buffer(5)]],
 
   81    uint3 index [[thread_position_in_grid]],
 
   82    uint3 grid_dim [[threads_per_grid]]) {
 
   84      {N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
 
   87        index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
 
   88    dst[dst_idx] = 
static_cast<U
>(src[src_idx]);
 
   91  auto xshape = src_shape[ndim - 1];
 
   92  IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
 
   93  auto src_xstride = src_strides[ndim - 1];
 
   94  for (
int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
 
   95    dst[dst_idx + i] = 
static_cast<U
>(src[src_idx]);
 
   96    src_idx += src_xstride;
 
 
  100template <
typename T, 
typename U, 
typename IdxT = 
int64_t>
 
  102    device 
const T* src [[buffer(0)]],
 
  103    device U* dst [[buffer(1)]],
 
  104    constant 
const int64_t& src_stride [[buffer(3)]],
 
  105    constant 
const int64_t& dst_stride [[buffer(4)]],
 
  106    uint index [[thread_position_in_grid]]) {
 
  109  dst[dst_idx] = 
static_cast<U
>(src[src_idx]);
 
 
  112template <
typename T, 
typename U, 
typename IdxT = 
int64_t>
 
  114    device 
const T* src [[buffer(0)]],
 
  115    device U* dst [[buffer(1)]],
 
  116    constant 
const int64_t* src_strides [[buffer(3)]],
 
  117    constant 
const int64_t* dst_strides [[buffer(4)]],
 
  118    uint2 index [[thread_position_in_grid]]) {
 
  121  dst[dst_idx] = 
static_cast<U
>(src[src_idx]);
 
 
  124template <
typename T, 
typename U, 
typename IdxT = 
int64_t>
 
  126    device 
const T* src [[buffer(0)]],
 
  127    device U* dst [[buffer(1)]],
 
  128    constant 
const int64_t* src_strides [[buffer(3)]],
 
  129    constant 
const int64_t* dst_strides [[buffer(4)]],
 
  130    uint3 index [[thread_position_in_grid]]) {
 
  133  dst[dst_idx] = 
static_cast<U
>(src[src_idx]);
 
 
  136template <
typename T, 
typename U, 
int N = 1, 
typename IdxT = 
int64_t>
 
  138    device 
const T* src [[buffer(0)]],
 
  139    device U* dst [[buffer(1)]],
 
  140    constant 
const int* src_shape [[buffer(2)]],
 
  141    constant 
const int64_t* src_strides [[buffer(3)]],
 
  142    constant 
const int64_t* dst_strides [[buffer(4)]],
 
  143    constant 
const int& ndim [[buffer(5)]],
 
  144    uint3 index [[thread_position_in_grid]]) {
 
  146      {N * index.x, index.y, index.z},
 
  152    dst[idx.y] = 
static_cast<U
>(src[idx.x]);
 
  155  IdxT src_xstride = src_strides[ndim - 1];
 
  156  IdxT dst_xstride = dst_strides[ndim - 1];
 
  157  auto xshape = src_shape[ndim - 1];
 
  158  for (
int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
 
  159    dst[idx.y] = 
static_cast<U
>(src[idx.x]);
 
  160    idx.x += src_xstride;
 
  161    idx.y += dst_xstride;
 
 
  165template <
typename T, 
typename U, 
typename IdxT = 
int64_t>
 
  167    device 
const T* src [[buffer(0)]],
 
  168    device U* dst [[buffer(1)]],
 
  169    constant 
const int64_t& src_stride [[buffer(3)]],
 
  170    constant 
const int64_t& dst_stride [[buffer(4)]],
 
  171    constant 
const int64_t& src_offset [[buffer(6)]],
 
  172    constant 
const int64_t& dst_offset [[buffer(7)]],
 
  173    uint index [[thread_position_in_grid]]) {
 
  176  dst[dst_idx + dst_offset] = src[src_idx + src_offset];
 
 
  179template <
typename T, 
typename U, 
typename IdxT = 
int64_t>
 
  181    device 
const T* src [[buffer(0)]],
 
  182    device U* dst [[buffer(1)]],
 
  183    constant 
const int64_t* src_strides [[buffer(3)]],
 
  184    constant 
const int64_t* dst_strides [[buffer(4)]],
 
  185    constant 
const int64_t& src_offset [[buffer(6)]],
 
  186    constant 
const int64_t& dst_offset [[buffer(7)]],
 
  187    uint2 index [[thread_position_in_grid]]) {
 
  190  dst[dst_idx + dst_offset] = src[src_idx + src_offset];
 
 
  193template <
typename T, 
typename U, 
typename IdxT = 
int64_t>
 
  195    device 
const T* src [[buffer(0)]],
 
  196    device U* dst [[buffer(1)]],
 
  197    constant 
const int64_t* src_strides [[buffer(3)]],
 
  198    constant 
const int64_t* dst_strides [[buffer(4)]],
 
  199    constant 
const int64_t& src_offset [[buffer(6)]],
 
  200    constant 
const int64_t& dst_offset [[buffer(7)]],
 
  201    uint3 index [[thread_position_in_grid]]) {
 
  204  dst[dst_idx + dst_offset] = src[src_idx + src_offset];
 
 
  207template <
typename T, 
typename U, 
int N = 1, 
typename IdxT = 
int64_t>
 
  209    device 
const T* src [[buffer(0)]],
 
  210    device U* dst [[buffer(1)]],
 
  211    constant 
const int* src_shape [[buffer(2)]],
 
  212    constant 
const int64_t* src_strides [[buffer(3)]],
 
  213    constant 
const int64_t* dst_strides [[buffer(4)]],
 
  214    constant 
const int& ndim [[buffer(5)]],
 
  215    constant 
const int64_t& src_offset [[buffer(6)]],
 
  216    constant 
const int64_t& dst_offset [[buffer(7)]],
 
  217    uint3 index [[thread_position_in_grid]]) {
 
  221      {N * index.x, index.y, index.z},
 
  227    dst[idx.y] = src[idx.x];
 
  230  IdxT src_xstride = src_strides[ndim - 1];
 
  231  IdxT dst_xstride = dst_strides[ndim - 1];
 
  232  auto xshape = src_shape[ndim - 1];
 
  233  for (
int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
 
  234    dst[idx.y] = src[idx.x];
 
  235    idx.x += src_xstride;
 
  236    idx.y += dst_xstride;