MLX
Loading...
Searching...
No Matches
copy.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
// Scalar fill: every thread reads the first element of `src`, converts it
// to `U`, and stores it at its own grid position in `dst`.
//
//   src   - input buffer; only element 0 is read.
//   dst   - output buffer, written at `index`.
//   index - this thread's position in the 1D grid.
template <typename T, typename U>
[[kernel]] void copy_s(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  // The source offset is fixed at zero regardless of the thread position.
  device U* out = dst + index;
  *out = static_cast<U>(*src);
}
10
// Element-wise copy with conversion: thread `index` reads `src[index]`,
// casts it to `U`, and writes the result to `dst[index]`.
//
//   src   - input buffer, read at `index`.
//   dst   - output buffer, written at `index`.
//   index - this thread's position in the 1D grid.
template <typename T, typename U>
[[kernel]] void copy_v(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  // Source and destination share the same offset.
  device const T* in = src + index;
  device U* out = dst + index;
  *out = static_cast<U>(*in);
}
18
// 1D copy with a strided source: the source offset is obtained from
// elem_to_loc_1(index, src_stride), while the destination is written at
// the thread's grid position.
//
//   src        - input buffer.
//   dst        - output buffer, written at `index`.
//   src_stride - source stride, bound at buffer slot 3 (slot 2 is not
//                used by this kernel).
//   index      - this thread's position in the 1D grid.
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    uint index [[thread_position_in_grid]]) {
  // The helper lives outside this file; it maps the element index to a
  // source location using the given stride.
  dst[index] = static_cast<U>(src[elem_to_loc_1(index, src_stride)]);
}
28
// 2D copy with a strided source: the source offset comes from
// elem_to_loc_2(index, src_strides); the destination offset is the
// thread's linear position in the grid (x varies fastest).
//
//   src         - input buffer.
//   dst         - output buffer.
//   src_strides - source strides, bound at buffer slot 3.
//   index       - this thread's (x, y) position in the grid.
//   grid_dim    - total threads per grid along each axis.
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Widen to 64 bits before multiplying so the linear offset is computed
  // in int64_t.
  const int64_t row_width = grid_dim.x;
  const int64_t out_loc = row_width * index.y + index.x;

  const auto in_loc = elem_to_loc_2(index, src_strides);
  dst[out_loc] = static_cast<U>(src[in_loc]);
}
40
// 3D copy with a strided source: the source offset comes from
// elem_to_loc_3(index, src_strides); the destination offset is the
// thread's linear position in the grid (x fastest, then y, then z).
//
//   src         - input buffer.
//   dst         - output buffer.
//   src_strides - source strides, bound at buffer slot 3.
//   index       - this thread's (x, y, z) position in the grid.
//   grid_dim    - total threads per grid along each axis.
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // All linearization arithmetic is carried out in int64_t.
  const int64_t width = grid_dim.x;
  const int64_t height = grid_dim.y;
  const int64_t row = height * index.z + index.y;
  const int64_t out_loc = width * row + index.x;

  const auto in_loc = elem_to_loc_3(index, src_strides);
  dst[out_loc] = static_cast<U>(src[in_loc]);
}
53
// Copy with a strided source whose dimension count DIM is a template
// parameter: the source offset comes from
// elem_to_loc_nd<DIM>(index, src_shape, src_strides); the destination
// offset is the thread's linear position in the 3D grid (x fastest, then
// y, then z).
//
//   src         - input buffer.
//   dst         - output buffer.
//   src_shape   - source shape, bound at buffer slot 2.
//   src_strides - source strides, bound at buffer slot 3.
//   index       - this thread's (x, y, z) position in the grid.
//   grid_dim    - total threads per grid along each axis.
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // All linearization arithmetic is carried out in int64_t.
  const int64_t width = grid_dim.x;
  const int64_t height = grid_dim.y;
  const int64_t row = height * index.z + index.y;
  const int64_t out_loc = width * row + index.x;

  const auto in_loc = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  dst[out_loc] = static_cast<U>(src[in_loc]);
}
67
// Copy with a strided source whose dimension count is supplied at run
// time: the source offset comes from
// elem_to_loc(index, src_shape, src_strides, ndim); the destination offset
// is the thread's linear position in the 3D grid (x fastest, then y, then
// z).
//
//   src         - input buffer.
//   dst         - output buffer.
//   src_shape   - source shape, bound at buffer slot 2.
//   src_strides - source strides, bound at buffer slot 3.
//   ndim        - number of dimensions, bound at buffer slot 5 (slot 4 is
//                 not used by this kernel).
//   index       - this thread's (x, y, z) position in the grid.
//   grid_dim    - total threads per grid along each axis.
template <typename T, typename U>
[[kernel]] void copy_g(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // All linearization arithmetic is carried out in int64_t.
  const int64_t width = grid_dim.x;
  const int64_t height = grid_dim.y;
  const int64_t row = height * index.z + index.y;
  const int64_t out_loc = width * row + index.x;

  const auto in_loc = elem_to_loc(index, src_shape, src_strides, ndim);
  dst[out_loc] = static_cast<U>(src[in_loc]);
}
82
// 1D copy where both sides are strided: the source and destination
// offsets are each obtained from elem_to_loc_1 using the thread position
// and the respective stride.
//
//   src        - input buffer.
//   dst        - output buffer.
//   src_stride - source stride, bound at buffer slot 3.
//   dst_stride - destination stride, bound at buffer slot 4.
//   index      - this thread's position in the 1D grid.
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    constant const int64_t& dst_stride [[buffer(4)]],
    uint index [[thread_position_in_grid]]) {
  // Same element index, two different strides.
  const auto in_loc = elem_to_loc_1(index, src_stride);
  const auto out_loc = elem_to_loc_1(index, dst_stride);
  dst[out_loc] = static_cast<U>(src[in_loc]);
}
94
// 2D copy where both sides are strided: the source and destination
// offsets are each obtained from elem_to_loc_2 using the thread position
// and the respective stride array.
//
//   src         - input buffer.
//   dst         - output buffer.
//   src_strides - source strides, bound at buffer slot 3.
//   dst_strides - destination strides, bound at buffer slot 4.
//   index       - this thread's (x, y) position in the grid.
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint2 index [[thread_position_in_grid]]) {
  // Same grid position, two different stride arrays.
  const auto in_loc = elem_to_loc_2(index, src_strides);
  const auto out_loc = elem_to_loc_2(index, dst_strides);
  dst[out_loc] = static_cast<U>(src[in_loc]);
}
106
// 3D copy where both sides are strided: the source and destination
// offsets are each obtained from elem_to_loc_3 using the thread position
// and the respective stride array.
//
//   src         - input buffer.
//   dst         - output buffer.
//   src_strides - source strides, bound at buffer slot 3.
//   dst_strides - destination strides, bound at buffer slot 4.
//   index       - this thread's (x, y, z) position in the grid.
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]) {
  // Same grid position, two different stride arrays.
  const auto in_loc = elem_to_loc_3(index, src_strides);
  const auto out_loc = elem_to_loc_3(index, dst_strides);
  dst[out_loc] = static_cast<U>(src[in_loc]);
}
118
// Copy where both sides are strided and the dimension count DIM is a
// template parameter: source and destination offsets are each obtained
// from elem_to_loc_nd<DIM>. Both calls receive the same shape array
// (`src_shape`); only the stride array differs.
//
//   src         - input buffer.
//   dst         - output buffer.
//   src_shape   - shape used for both offset computations, buffer slot 2.
//   src_strides - source strides, bound at buffer slot 3.
//   dst_strides - destination strides, bound at buffer slot 4.
//   index       - this thread's (x, y, z) position in the grid.
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]) {
  // Same grid position and shape, two different stride arrays.
  const auto in_loc = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  const auto out_loc = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
  dst[out_loc] = static_cast<U>(src[in_loc]);
}
131
// Copy where both sides are strided and the dimension count is supplied
// at run time: source and destination offsets are each obtained from
// elem_to_loc(..., ndim). Both calls receive the same shape array
// (`src_shape`); only the stride array differs.
//
//   src         - input buffer.
//   dst         - output buffer.
//   src_shape   - shape used for both offset computations, buffer slot 2.
//   src_strides - source strides, bound at buffer slot 3.
//   dst_strides - destination strides, bound at buffer slot 4.
//   ndim        - number of dimensions, bound at buffer slot 5.
//   index       - this thread's (x, y, z) position in the grid.
template <typename T, typename U>
[[kernel]] void copy_gg(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]]) {
  // Same grid position and shape, two different stride arrays.
  const auto in_loc = elem_to_loc(index, src_shape, src_strides, ndim);
  const auto out_loc = elem_to_loc(index, src_shape, dst_strides, ndim);
  dst[out_loc] = static_cast<U>(src[in_loc]);
}
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t &stride)
Definition utils.h:123
METAL_FUNC stride_t elem_to_loc_3(uint3 elem, constant const stride_t strides[3])
Definition utils.h:135
METAL_FUNC stride_t elem_to_loc(uint elem, device const int *shape, device const stride_t *strides, int ndim)
Definition utils.h:77
METAL_FUNC stride_t elem_to_loc_2(uint2 elem, constant const stride_t strides[2])
Definition utils.h:129
void copy_g_nd(device const T *src, device U *dst, constant const int *src_shape, constant const int64_t *src_strides, uint3 index, uint3 grid_dim)
Definition copy.h:55
void copy_g(device const T *src, device U *dst, constant const int *src_shape, constant const int64_t *src_strides, constant const int &ndim, uint3 index, uint3 grid_dim)
Definition copy.h:69
void copy_gg_nd(device const T *src, device U *dst, constant const int *src_shape, constant const int64_t *src_strides, constant const int64_t *dst_strides, uint3 index)
Definition copy.h:120
void copy_gg_nd1(device const T *src, device U *dst, constant const int64_t &src_stride, constant const int64_t &dst_stride, uint index)
Definition copy.h:84
void copy_gg_nd2(device const T *src, device U *dst, constant const int64_t *src_strides, constant const int64_t *dst_strides, uint2 index)
Definition copy.h:96
void copy_gg_nd3(device const T *src, device U *dst, constant const int64_t *src_strides, constant const int64_t *dst_strides, uint3 index)
Definition copy.h:108
void copy_g_nd3(device const T *src, device U *dst, constant const int64_t *src_strides, uint3 index, uint3 grid_dim)
Definition copy.h:42
void copy_gg(device const T *src, device U *dst, constant const int *src_shape, constant const int64_t *src_strides, constant const int64_t *dst_strides, constant const int &ndim, uint3 index)
Definition copy.h:133
void copy_g_nd1(device const T *src, device U *dst, constant const int64_t &src_stride, uint index)
Definition copy.h:20
void copy_v(device const T *src, device U *dst, uint index)
Definition copy.h:12
void copy_g_nd2(device const T *src, device U *dst, constant const int64_t *src_strides, uint2 index, uint2 grid_dim)
Definition copy.h:30
void copy_s(device const T *src, device U *dst, uint index)
Definition copy.h:4