// Scalar-broadcast copy over a 1-D grid: every thread writes the single
// element src[0], converted from T to U with static_cast, into dst[index].
//   src   [[buffer(0)]] - read-only input; only element 0 is ever read.
//   dst   [[buffer(1)]] - output, written at the thread's grid position.
//   index               - 1-D thread position in the grid.
// No bounds check is done on index; presumably the host dispatches exactly
// one thread per dst element - TODO confirm with the launcher.
// NOTE(review): this copy of the file is corrupted. The numbers fused onto
// the code ("3", "5".."8") look like leftover source line numbers, and the
// gaps in that numbering (4, 9-10) indicate the function declaration line
// (name / return type) and the closing brace were dropped. Restore from the
// upstream source before compiling.
3template <
typename T, 
typename U>
 
    5    device 
const T* src [[buffer(0)]],
 
    6    device U* dst [[buffer(1)]],
 
    7    uint index [[thread_position_in_grid]]) {
 
    8  dst[index] = 
static_cast<U
>(src[0]);
 
 
   // Element-wise copy over a 1-D grid: thread `index` writes
   // static_cast<U>(src[index]) into dst[index]. Both buffers use the same
   // flat offset, so both are read/written contiguously.
   //   src   [[buffer(0)]] - read-only input of element type T.
   //   dst   [[buffer(1)]] - output of element type U.
   //   index               - 1-D thread position in the grid.
   // No bounds check is done; presumably the dispatch size equals the number
   // of elements - TODO confirm with the launcher.
   // NOTE(review): corrupted extraction. The embedded numbering skips 12 and
   // 17-18, so the function declaration line and closing brace are missing,
   // and the "11", "13".."16" prefixes are stray line numbers.
   11template <
typename T, 
typename U>
 
   13    device 
const T* src [[buffer(0)]],
 
   14    device U* dst [[buffer(1)]],
 
   15    uint index [[thread_position_in_grid]]) {
 
   16  dst[index] = 
static_cast<U
>(src[index]);
 
 
   // Scalar-broadcast copy over a 2-D grid: every thread writes src[0],
   // converted from T to U, into dst at the thread's flattened position.
   //   src      [[buffer(0)]]          - read-only; only element 0 is read.
   //   dst      [[buffer(1)]]          - output of element type U.
   //   index    [[thread_position_in_grid]] - 2-D thread position.
   //   grid_dim [[threads_per_grid]]   - 2-D grid extent; grid_dim.x is used
   //                                     as the row width when flattening.
   // NOTE(review): corrupted extraction. The embedded numbering skips 20 and
   // 27-28, so the function declaration line and closing brace are missing.
   19template <
typename T, 
typename U>
 
   21    device 
const T* src [[buffer(0)]],
 
   22    device U* dst [[buffer(1)]],
 
   23    uint2 index [[thread_position_in_grid]],
 
   24    uint2 grid_dim [[threads_per_grid]]) {
 
   // Flatten (x, y) with x varying fastest. Casting index.y to size_t makes
   // the multiply happen in size_t rather than 32-bit uint, so large grids do
   // not wrap the offset.
   25  size_t offset = index.x + grid_dim.x * size_t(index.y);
 
   26  dst[offset] = 
static_cast<U
>(src[0]);
 
 
   // Element-wise copy over a 2-D grid: the 2-D thread position is flattened
   // to one offset, and dst[offset] = static_cast<U>(src[offset]). Both
   // buffers use the same flat offset.
   //   src      [[buffer(0)]]          - read-only input of element type T.
   //   dst      [[buffer(1)]]          - output of element type U.
   //   index    [[thread_position_in_grid]] - 2-D thread position.
   //   grid_dim [[threads_per_grid]]   - 2-D grid extent; grid_dim.x is used
   //                                     as the row width when flattening.
   // NOTE(review): corrupted extraction. The embedded numbering skips 30 and
   // 37-38, so the function declaration line and closing brace are missing.
   29template <
typename T, 
typename U>
 
   31    device 
const T* src [[buffer(0)]],
 
   32    device U* dst [[buffer(1)]],
 
   33    uint2 index [[thread_position_in_grid]],
 
   34    uint2 grid_dim [[threads_per_grid]]) {
 
   // x varies fastest; size_t(index.y) widens the multiply beyond 32 bits.
   35  size_t offset = index.x + grid_dim.x * size_t(index.y);
 
   36  dst[offset] = 
static_cast<U
>(src[offset]);
 
 
   // 1-D copy from a strided source into a contiguous destination: thread
   // `index` writes static_cast<U>(src[src_idx]) into dst[index].
   //   src        [[buffer(0)]] - read-only input of element type T.
   //   dst        [[buffer(1)]] - output, written contiguously at `index`.
   //   src_stride [[buffer(3)]] - element stride of src (int64_t).
   //   index                    - 1-D thread position in the grid.
   // IdxT (default int64_t) is the integer type intended for offsets.
   // NOTE(review): `src_idx` is used below but never declared in this copy.
   // The embedded numbering skips 45, where its declaration presumably was
   // (likely derived from index and src_stride - TODO restore from upstream).
   // The function declaration line (40) and closing brace (47) are also
   // missing.
   39template <
typename T, 
typename U, 
typename IdxT = 
int64_t>
 
   41    device 
const T* src [[buffer(0)]],
 
   42    device U* dst [[buffer(1)]],
 
   43    constant 
const int64_t& src_stride [[buffer(3)]],
 
   44    uint index [[thread_position_in_grid]]) {
 
   46  dst[index] = 
static_cast<U
>(src[src_idx]);
 
 
   // 2-D copy from a strided source into a contiguous destination. The
   // destination offset is the flattened thread position; the source offset
   // is src_idx.
   //   src         [[buffer(0)]] - read-only input of element type T.
   //   dst         [[buffer(1)]] - output, written contiguously.
   //   src_strides [[buffer(3)]] - per-dimension element strides of src.
   //   index / grid_dim          - 2-D thread position and grid extent.
   // NOTE(review): `src_idx` is never declared in this copy. The embedded
   // numbering skips 56, where its declaration presumably was (likely from
   // index and src_strides - TODO restore). The declaration line (50) and
   // closing brace (59) are also missing.
   49template <
typename T, 
typename U, 
typename IdxT = 
int64_t>
 
   51    device 
const T* src [[buffer(0)]],
 
   52    device U* dst [[buffer(1)]],
 
   53    constant 
const int64_t* src_strides [[buffer(3)]],
 
   54    uint2 index [[thread_position_in_grid]],
 
   55    uint2 grid_dim [[threads_per_grid]]) {
 
   // Flatten (x, y) with x fastest; IdxT(grid_dim.x) makes the multiply run
   // in IdxT (int64_t by default) instead of 32-bit uint.
   57  IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
 
   58  dst[dst_idx] = 
static_cast<U
>(src[src_idx]);
 
 
   // 3-D copy from a strided source into a contiguous destination. The
   // destination offset flattens (x, y, z) with x fastest, using the grid
   // extents; the source offset is src_idx.
   //   src         [[buffer(0)]] - read-only input of element type T.
   //   dst         [[buffer(1)]] - output, written contiguously.
   //   src_strides [[buffer(3)]] - per-dimension element strides of src.
   //   index / grid_dim          - 3-D thread position and grid extent.
   // NOTE(review): this copy is missing the src_idx declaration and the head
   // of the dst_idx declaration (the numbering skips 68-69); only the
   // continuation line of the dst_idx expression survives. The declaration
   // line (62) and closing brace (72) are also missing - TODO restore.
   61template <
typename T, 
typename U, 
typename IdxT = 
int64_t>
 
   63    device 
const T* src [[buffer(0)]],
 
   64    device U* dst [[buffer(1)]],
 
   65    constant 
const int64_t* src_strides [[buffer(3)]],
 
   66    uint3 index [[thread_position_in_grid]],
 
   67    uint3 grid_dim [[threads_per_grid]]) {
 
   // Casting the grid extents to IdxT keeps the flattening multiply in IdxT.
   70      index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
 
   71  dst[dst_idx] = 
static_cast<U
>(src[src_idx]);
 
 
   // N-dimensional copy from a strided source into a contiguous destination.
   // Each thread handles up to N consecutive elements along the last
   // dimension, starting at x-position N * index.x.
   //   src         [[buffer(0)]] - read-only input of element type T.
   //   dst         [[buffer(1)]] - output; written at dst_idx, dst_idx + 1, ...
   //   src_shape   [[buffer(2)]] - source extents; src_shape[ndim - 1] is the
   //                               last-dimension length (xshape).
   //   src_strides [[buffer(3)]] - source element strides per dimension.
   //   ndim        [[buffer(5)]] - number of dimensions.
   //   index / grid_dim          - 3-D thread position and grid extent.
   // NOTE(review): this copy is heavily truncated (embedded numbering skips
   // 75, 83, 85-86, 89-90, 97+). What survives is:
   //   * the tail of a call taking ({N * index.x, index.y, index.z},
   //     src_shape, src_strides, ndim), which presumably initialises src_idx;
   //   * a deeper-indented dst_idx / dst write pair, which looks like a
   //     special-case branch whose condition is missing;
   //   * the general per-thread loop.
   // The function declaration line and closing braces are missing too.
   // TODO restore from upstream.
   74template <
typename T, 
typename U, 
int N = 1, 
typename IdxT = 
int64_t>
 
   76    device 
const T* src [[buffer(0)]],
 
   77    device U* dst [[buffer(1)]],
 
   78    constant 
const int* src_shape [[buffer(2)]],
 
   79    constant 
const int64_t* src_strides [[buffer(3)]],
 
   80    constant 
const int& ndim [[buffer(5)]],
 
   81    uint3 index [[thread_position_in_grid]],
 
   82    uint3 grid_dim [[threads_per_grid]]) {
 
   84      {N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
 
   87        index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
 
   88    dst[dst_idx] = 
static_cast<U
>(src[src_idx]);
 
   // General path. dst is laid out contiguously with last-dimension length
   // xshape, so this thread's first output element is at
   // N * index.x + xshape * (index.y + grid_dim.y * index.z).
   91  auto xshape = src_shape[ndim - 1];
 
   92  IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
 
   93  auto src_xstride = src_strides[ndim - 1];
 
   // Copy up to N elements. The second condition stops at the end of the last
   // dimension, for the final thread when xshape is not a multiple of N.
   94  for (
int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
 
   95    dst[dst_idx + i] = 
static_cast<U
>(src[src_idx]);
 
   // Advance the source by one element along the last dimension.
   96    src_idx += src_xstride;
 
 
  // 1-D copy where both source and destination are strided: the element at
  // src[src_idx] is converted from T to U and stored at dst[dst_idx].
  //   src        [[buffer(0)]] - read-only input of element type T.
  //   dst        [[buffer(1)]] - output of element type U.
  //   src_stride [[buffer(3)]] - element stride of src.
  //   dst_stride [[buffer(4)]] - element stride of dst.
  //   index                    - 1-D thread position in the grid.
  // NOTE(review): neither src_idx nor dst_idx is declared in this copy. The
  // embedded numbering skips 107-108, where they presumably were (likely
  // index scaled by each stride - TODO restore). The declaration line (101)
  // and closing brace (110) are also missing.
  100template <
typename T, 
typename U, 
typename IdxT = 
int64_t>
 
  102    device 
const T* src [[buffer(0)]],
 
  103    device U* dst [[buffer(1)]],
 
  104    constant 
const int64_t& src_stride [[buffer(3)]],
 
  105    constant 
const int64_t& dst_stride [[buffer(4)]],
 
  106    uint index [[thread_position_in_grid]]) {
 
  109  dst[dst_idx] = 
static_cast<U
>(src[src_idx]);
 
 
  // 2-D copy where both source and destination are strided: the element at
  // src[src_idx] is converted from T to U and stored at dst[dst_idx].
  //   src         [[buffer(0)]] - read-only input of element type T.
  //   dst         [[buffer(1)]] - output of element type U.
  //   src_strides [[buffer(3)]] - per-dimension element strides of src.
  //   dst_strides [[buffer(4)]] - per-dimension element strides of dst.
  //   index                     - 2-D thread position in the grid.
  // NOTE(review): src_idx / dst_idx declarations are missing from this copy
  // (numbering skips 119-120). The declaration line (113) and closing brace
  // (122) are also missing - TODO restore from upstream.
  112template <
typename T, 
typename U, 
typename IdxT = 
int64_t>
 
  114    device 
const T* src [[buffer(0)]],
 
  115    device U* dst [[buffer(1)]],
 
  116    constant 
const int64_t* src_strides [[buffer(3)]],
 
  117    constant 
const int64_t* dst_strides [[buffer(4)]],
 
  118    uint2 index [[thread_position_in_grid]]) {
 
  121  dst[dst_idx] = 
static_cast<U
>(src[src_idx]);
 
 
  // 3-D copy where both source and destination are strided: the element at
  // src[src_idx] is converted from T to U and stored at dst[dst_idx].
  //   src         [[buffer(0)]] - read-only input of element type T.
  //   dst         [[buffer(1)]] - output of element type U.
  //   src_strides [[buffer(3)]] - per-dimension element strides of src.
  //   dst_strides [[buffer(4)]] - per-dimension element strides of dst.
  //   index                     - 3-D thread position in the grid.
  // NOTE(review): src_idx / dst_idx declarations are missing from this copy
  // (numbering skips 131-132). The declaration line (125) and closing brace
  // (134) are also missing - TODO restore from upstream.
  124template <
typename T, 
typename U, 
typename IdxT = 
int64_t>
 
  126    device 
const T* src [[buffer(0)]],
 
  127    device U* dst [[buffer(1)]],
 
  128    constant 
const int64_t* src_strides [[buffer(3)]],
 
  129    constant 
const int64_t* dst_strides [[buffer(4)]],
 
  130    uint3 index [[thread_position_in_grid]]) {
 
  133  dst[dst_idx] = 
static_cast<U
>(src[src_idx]);
 
 
  // N-dimensional copy where both source and destination are strided. Each
  // thread handles up to N consecutive elements along the last dimension,
  // starting at x-position N * index.x. `idx` holds a pair of offsets:
  // idx.x indexes src and idx.y indexes dst (see the reads and writes below).
  //   src         [[buffer(0)]] - read-only input of element type T.
  //   dst         [[buffer(1)]] - output of element type U.
  //   src_shape   [[buffer(2)]] - extents; src_shape[ndim - 1] is the
  //                               last-dimension length (xshape).
  //   src_strides [[buffer(3)]] - per-dimension element strides of src.
  //   dst_strides [[buffer(4)]] - per-dimension element strides of dst.
  //   ndim        [[buffer(5)]] - number of dimensions.
  //   index                     - 3-D thread position in the grid.
  // NOTE(review): heavily truncated in this copy (numbering skips 137, 145,
  // 147-151, 153-154, and everything after 161). Only the tail of the call
  // that presumably computes `idx` from ({N * index.x, index.y, index.z},
  // shape, strides, ndim) survives. A deeper-indented single-element write
  // appears to be a special-case branch whose condition is missing. The
  // declaration line and closing braces are also gone. TODO restore from
  // upstream.
  136template <
typename T, 
typename U, 
int N = 1, 
typename IdxT = 
int64_t>
 
  138    device 
const T* src [[buffer(0)]],
 
  139    device U* dst [[buffer(1)]],
 
  140    constant 
const int* src_shape [[buffer(2)]],
 
  141    constant 
const int64_t* src_strides [[buffer(3)]],
 
  142    constant 
const int64_t* dst_strides [[buffer(4)]],
 
  143    constant 
const int& ndim [[buffer(5)]],
 
  144    uint3 index [[thread_position_in_grid]]) {
 
  146      {N * index.x, index.y, index.z},
 
  152    dst[idx.y] = 
static_cast<U
>(src[idx.x]);
 
  // General path: step both offsets by their own last-dimension stride, so
  // dst is not assumed to be contiguous here.
  155  IdxT src_xstride = src_strides[ndim - 1];
 
  156  IdxT dst_xstride = dst_strides[ndim - 1];
 
  157  auto xshape = src_shape[ndim - 1];
 
  // Copy up to N elements. The second condition stops at the end of the last
  // dimension, for the final thread when xshape is not a multiple of N.
  158  for (
int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
 
  159    dst[idx.y] = 
static_cast<U
>(src[idx.x]);
 
  160    idx.x += src_xstride;
 
  161    idx.y += dst_xstride;