MLX
 
Loading...
Searching...
No Matches
copy.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
// Fill kernel over a 1-D grid: every thread stores the first source element,
// converted to U, at its own grid position in dst.
//
//   src   [[buffer(0)]]  only element 0 is read
//   dst   [[buffer(1)]]  written at `index`
//   index                this thread's position in the grid
template <typename T, typename U>
[[kernel]] void copy_s(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  device const T& scalar = *src;
  device U& out = dst[index];
  out = static_cast<U>(scalar);
}
10
// Element-wise copy over a 1-D grid: thread `index` reads src[index],
// converts it to U and stores it in dst[index].
//
//   src   [[buffer(0)]]  read at `index`
//   dst   [[buffer(1)]]  written at `index`
//   index                this thread's position in the grid
template <typename T, typename U>
[[kernel]] void copy_v(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  device const T& in = src[index];
  device U& out = dst[index];
  out = static_cast<U>(in);
}
18
// Fill kernel over a 2-D grid: every thread stores the first source element,
// converted to U, at the row-major linearisation of its grid position.
//
//   src      [[buffer(0)]]  only element 0 is read
//   dst      [[buffer(1)]]  written at index.x + grid_dim.x * index.y
//   index                   this thread's position in the grid
//   grid_dim                total threads in the grid per dimension
template <typename T, typename U>
[[kernel]] void copy_s2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Widen the row number first so the multiplication happens in 64 bits.
  const int64_t row = int64_t(index.y);
  const int64_t flat = row * grid_dim.x + index.x;
  dst[flat] = static_cast<U>(*src);
}
28
// Element-wise copy over a 2-D grid: each thread linearises its grid position
// as index.x + grid_dim.x * index.y and copies that one element from src to
// dst, converting it to U.
//
//   src      [[buffer(0)]]  read at the linearised offset
//   dst      [[buffer(1)]]  written at the same offset
//   index                   this thread's position in the grid
//   grid_dim                total threads in the grid per dimension
template <typename T, typename U>
[[kernel]] void copy_v2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Widen the row number first so the multiplication happens in 64 bits.
  const int64_t row = int64_t(index.y);
  const int64_t flat = row * grid_dim.x + index.x;
  device const T& in = src[flat];
  device U& out = dst[flat];
  out = static_cast<U>(in);
}
38
// 1-D copy with a strided source: the source location is whatever
// elem_to_loc_1<IdxT>(index, src_stride) returns; the destination is written
// at the thread's own grid position.
//
//   src        [[buffer(0)]]  read at the computed location
//   dst        [[buffer(1)]]  written at `index`
//   src_stride [[buffer(3)]]  stride forwarded to elem_to_loc_1
//   index                     this thread's position in the grid
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    uint index [[thread_position_in_grid]]) {
  device U& out = dst[index];
  out = static_cast<U>(src[elem_to_loc_1<IdxT>(index, src_stride)]);
}
48
// 2-D copy with a strided source: the source location comes from
// elem_to_loc_2<IdxT>(index, src_strides); the destination is the row-major
// linearisation of the thread's grid position.
//
//   src         [[buffer(0)]]  read at the computed location
//   dst         [[buffer(1)]]  written at index.x + grid_dim.x * index.y
//   src_strides [[buffer(3)]]  strides forwarded to elem_to_loc_2
//   index                      this thread's position in the grid
//   grid_dim                   total threads in the grid per dimension
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Grid width is converted to IdxT before it takes part in the arithmetic.
  const IdxT row_width = IdxT(grid_dim.x);
  IdxT out_loc = row_width * index.y + index.x;
  auto in_loc = elem_to_loc_2<IdxT>(index, src_strides);
  dst[out_loc] = static_cast<U>(src[in_loc]);
}
60
// 3-D copy with a strided source: the source location comes from
// elem_to_loc_3<IdxT>(index, src_strides); the destination is the row-major
// linearisation of the thread's grid position (x fastest, then y, then z).
//
//   src         [[buffer(0)]]  read at the computed location
//   dst         [[buffer(1)]]  written at the linearised grid position
//   src_strides [[buffer(3)]]  strides forwarded to elem_to_loc_3
//   index                      this thread's position in the grid
//   grid_dim                   total threads in the grid per dimension
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Combined (y, z) row number; `auto` keeps the promoted arithmetic type.
  const auto yz = index.y + IdxT(grid_dim.y) * index.z;
  IdxT out_loc = index.x + IdxT(grid_dim.x) * yz;
  auto in_loc = elem_to_loc_3<IdxT>(index, src_strides);
  dst[out_loc] = static_cast<U>(src[in_loc]);
}
73
// N-dimensional copy with a strided source, where each thread handles up to N
// consecutive elements along the last axis of the shape.
//
// The starting source location comes from elem_to_loc<IdxT>, called with
// {N * index.x, index.y, index.z} plus the shape, strides and ndim.
//
// When N == 1 the destination is the row-major linearisation of the grid
// position. Otherwise the destination row length is src_shape[ndim - 1], and
// the thread writes consecutive destination elements while stepping the
// source by src_strides[ndim - 1]. It stops after N elements or at the end
// of the last axis, whichever comes first.
//
//   src         [[buffer(0)]]  source data
//   dst         [[buffer(1)]]  destination data
//   src_shape   [[buffer(2)]]  shape array, ndim entries are indexed here
//   src_strides [[buffer(3)]]  source strides, ndim entries are indexed here
//   ndim        [[buffer(5)]]  number of dimensions
//   index                      this thread's position in the grid
//   grid_dim                   total threads in the grid per dimension
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_g(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto in_loc = elem_to_loc<IdxT>(
      {N * index.x, index.y, index.z}, src_shape, src_strides, ndim);

  // Combined (y, z) row number used by both destination layouts; `auto`
  // keeps the promoted arithmetic type of the expression.
  const auto yz = index.y + IdxT(grid_dim.y) * index.z;

  if (N != 1) {
    const int last = ndim - 1;
    auto xshape = src_shape[last];
    auto in_step = src_strides[last];
    IdxT out_loc = N * index.x + xshape * yz;
    const int x_begin = int(N * index.x);
    for (int i = 0; i < N; ++i) {
      // Do not run past the end of the last axis.
      if (x_begin + i >= xshape) {
        break;
      }
      dst[out_loc + i] = static_cast<U>(src[in_loc]);
      in_loc += in_step;
    }
    return;
  }

  // One element per thread.
  IdxT out_loc = index.x + grid_dim.x * yz;
  dst[out_loc] = static_cast<U>(src[in_loc]);
}
99
// 1-D copy where both sides are strided: source and destination locations are
// each obtained from elem_to_loc_1<IdxT> using the thread index and the
// respective stride.
//
//   src        [[buffer(0)]]  read at the computed source location
//   dst        [[buffer(1)]]  written at the computed destination location
//   src_stride [[buffer(3)]]  stride forwarded for the source
//   dst_stride [[buffer(4)]]  stride forwarded for the destination
//   index                     this thread's position in the grid
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    constant const int64_t& dst_stride [[buffer(4)]],
    uint index [[thread_position_in_grid]]) {
  device const T& in = src[elem_to_loc_1<IdxT>(index, src_stride)];
  device U& out = dst[elem_to_loc_1<IdxT>(index, dst_stride)];
  out = static_cast<U>(in);
}
111
// 2-D copy where both sides are strided: source and destination locations are
// each obtained from elem_to_loc_2<IdxT> using the thread's grid position and
// the respective strides array.
//
//   src         [[buffer(0)]]  read at the computed source location
//   dst         [[buffer(1)]]  written at the computed destination location
//   src_strides [[buffer(3)]]  strides forwarded for the source
//   dst_strides [[buffer(4)]]  strides forwarded for the destination
//   index                      this thread's position in the grid
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint2 index [[thread_position_in_grid]]) {
  device const T& in = src[elem_to_loc_2<IdxT>(index, src_strides)];
  device U& out = dst[elem_to_loc_2<IdxT>(index, dst_strides)];
  out = static_cast<U>(in);
}
123
// 3-D copy where both sides are strided: source and destination locations are
// each obtained from elem_to_loc_3<IdxT> using the thread's grid position and
// the respective strides array.
//
//   src         [[buffer(0)]]  read at the computed source location
//   dst         [[buffer(1)]]  written at the computed destination location
//   src_strides [[buffer(3)]]  strides forwarded for the source
//   dst_strides [[buffer(4)]]  strides forwarded for the destination
//   index                      this thread's position in the grid
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]) {
  device const T& in = src[elem_to_loc_3<IdxT>(index, src_strides)];
  device U& out = dst[elem_to_loc_3<IdxT>(index, dst_strides)];
  out = static_cast<U>(in);
}
135
// N-dimensional copy where both sides are strided, with each thread handling
// up to N consecutive elements along the last axis of the shape.
//
// elem_to_loc_2_nd<IdxT> is called with {N * index.x, index.y, index.z} and
// returns a pair: .x is used as the source location, .y as the destination
// location.
//
// When N == 1 a single element is copied. Otherwise the thread copies
// repeatedly, advancing the source by src_strides[ndim - 1] and the
// destination by dst_strides[ndim - 1]. It stops after N elements or at the
// end of the last axis (src_shape[ndim - 1]), whichever comes first.
//
//   src         [[buffer(0)]]  source data
//   dst         [[buffer(1)]]  destination data
//   src_shape   [[buffer(2)]]  shape array, ndim entries are indexed here
//   src_strides [[buffer(3)]]  source strides
//   dst_strides [[buffer(4)]]  destination strides
//   ndim        [[buffer(5)]]  number of dimensions
//   index                      this thread's position in the grid
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_gg(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]]) {
  // loc.x: source location, loc.y: destination location.
  auto loc = elem_to_loc_2_nd<IdxT>(
      {N * index.x, index.y, index.z},
      src_shape,
      src_strides,
      dst_strides,
      ndim);

  if (N != 1) {
    const int last = ndim - 1;
    IdxT in_step = src_strides[last];
    IdxT out_step = dst_strides[last];
    auto xshape = src_shape[last];
    const int x_begin = int(N * index.x);
    for (int i = 0; i < N; ++i) {
      // Do not run past the end of the last axis.
      if (x_begin + i >= xshape) {
        break;
      }
      dst[loc.y] = static_cast<U>(src[loc.x]);
      loc.x += in_step;
      loc.y += out_step;
    }
    return;
  }

  // One element per thread.
  dst[loc.y] = static_cast<U>(src[loc.x]);
}
164
// 1-D copy where both sides are strided and each side additionally has a
// base offset supplied in a buffer.
//
// Source and destination locations are obtained from elem_to_loc_1<IdxT>
// using the thread index and the respective stride; src_offset / dst_offset
// are then added before indexing.
//
// The value is converted with an explicit static_cast<U>, matching the other
// copy kernels in this file, so that T -> U conversion does not depend on an
// implicit conversion being available.
//
//   src        [[buffer(0)]]  source data
//   dst        [[buffer(1)]]  destination data
//   src_stride [[buffer(3)]]  stride forwarded for the source
//   dst_stride [[buffer(4)]]  stride forwarded for the destination
//   src_offset [[buffer(6)]]  element offset added to the source location
//   dst_offset [[buffer(7)]]  element offset added to the destination location
//   index                     this thread's position in the grid
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    constant const int64_t& dst_stride [[buffer(4)]],
    constant const int64_t& src_offset [[buffer(6)]],
    constant const int64_t& dst_offset [[buffer(7)]],
    uint index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
  auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
  dst[dst_idx + dst_offset] = static_cast<U>(src[src_idx + src_offset]);
}
178
// 2-D copy where both sides are strided and each side additionally has a
// base offset supplied in a buffer.
//
// Source and destination locations are obtained from elem_to_loc_2<IdxT>
// using the thread's grid position and the respective strides array;
// src_offset / dst_offset are then added before indexing.
//
// The value is converted with an explicit static_cast<U>, matching the other
// copy kernels in this file, so that T -> U conversion does not depend on an
// implicit conversion being available.
//
//   src         [[buffer(0)]]  source data
//   dst         [[buffer(1)]]  destination data
//   src_strides [[buffer(3)]]  strides forwarded for the source
//   dst_strides [[buffer(4)]]  strides forwarded for the destination
//   src_offset  [[buffer(6)]]  element offset added to the source location
//   dst_offset  [[buffer(7)]]  element offset added to the destination location
//   index                      this thread's position in the grid
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int64_t& src_offset [[buffer(6)]],
    constant const int64_t& dst_offset [[buffer(7)]],
    uint2 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
  auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
  dst[dst_idx + dst_offset] = static_cast<U>(src[src_idx + src_offset]);
}
192
// 3-D copy where both sides are strided and each side additionally has a
// base offset supplied in a buffer.
//
// Source and destination locations are obtained from elem_to_loc_3<IdxT>
// using the thread's grid position and the respective strides array;
// src_offset / dst_offset are then added before indexing.
//
// The value is converted with an explicit static_cast<U>, matching the other
// copy kernels in this file, so that T -> U conversion does not depend on an
// implicit conversion being available.
//
//   src         [[buffer(0)]]  source data
//   dst         [[buffer(1)]]  destination data
//   src_strides [[buffer(3)]]  strides forwarded for the source
//   dst_strides [[buffer(4)]]  strides forwarded for the destination
//   src_offset  [[buffer(6)]]  element offset added to the source location
//   dst_offset  [[buffer(7)]]  element offset added to the destination location
//   index                      this thread's position in the grid
template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int64_t& src_offset [[buffer(6)]],
    constant const int64_t& dst_offset [[buffer(7)]],
    uint3 index [[thread_position_in_grid]]) {
  auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
  auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
  dst[dst_idx + dst_offset] = static_cast<U>(src[src_idx + src_offset]);
}
206
// N-dimensional copy where both sides are strided and each side additionally
// has a base offset supplied in a buffer. Each thread handles up to N
// consecutive elements along the last axis of the shape.
//
// The src and dst pointers are first advanced by src_offset and dst_offset.
// elem_to_loc_2_nd<IdxT> is then called with {N * index.x, index.y, index.z}
// and returns a pair: .x is used as the source location, .y as the
// destination location.
//
// When N == 1 a single element is copied. Otherwise the thread copies
// repeatedly, advancing the source by src_strides[ndim - 1] and the
// destination by dst_strides[ndim - 1]. It stops after N elements or at the
// end of the last axis (src_shape[ndim - 1]), whichever comes first.
//
// Values are converted with an explicit static_cast<U>, matching the other
// copy kernels in this file, so that T -> U conversion does not depend on an
// implicit conversion being available.
//
//   src         [[buffer(0)]]  source data
//   dst         [[buffer(1)]]  destination data
//   src_shape   [[buffer(2)]]  shape array, ndim entries are indexed here
//   src_strides [[buffer(3)]]  source strides
//   dst_strides [[buffer(4)]]  destination strides
//   ndim        [[buffer(5)]]  number of dimensions
//   src_offset  [[buffer(6)]]  element offset applied to src
//   dst_offset  [[buffer(7)]]  element offset applied to dst
//   index                      this thread's position in the grid
template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_gg_dynamic(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int& ndim [[buffer(5)]],
    constant const int64_t& src_offset [[buffer(6)]],
    constant const int64_t& dst_offset [[buffer(7)]],
    uint3 index [[thread_position_in_grid]]) {
  // Rebase both pointers so the locations computed below are relative to
  // the requested offsets.
  src += src_offset;
  dst += dst_offset;
  // idx.x: source location, idx.y: destination location.
  auto idx = elem_to_loc_2_nd<IdxT>(
      {N * index.x, index.y, index.z},
      src_shape,
      src_strides,
      dst_strides,
      ndim);
  if (N == 1) {
    dst[idx.y] = static_cast<U>(src[idx.x]);
    return;
  }
  IdxT src_xstride = src_strides[ndim - 1];
  IdxT dst_xstride = dst_strides[ndim - 1];
  auto xshape = src_shape[ndim - 1];
  // Copy at most N elements, and never past the end of the last axis.
  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
    dst[idx.y] = static_cast<U>(src[idx.x]);
    idx.x += src_xstride;
    idx.y += dst_xstride;
  }
}
METAL_FUNC IdxT elem_to_loc(IdxT elem, constant const int *shape, constant const int64_t *strides, int ndim)
Definition utils.h:93
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t &stride)
Definition utils.h:126
METAL_FUNC vec< IdxT, 2 > elem_to_loc_2_nd(uint3 elem, constant const int *shape, constant const int64_t *a_strides, constant const int64_t *b_strides, int ndim)
Definition utils.h:145
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2])
Definition utils.h:131
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3])
Definition utils.h:136
void copy_g_nd1(device const T *src, device U *dst, constant const int64_t &src_stride, uint index)
Definition copy.h:40
void copy_gg_nd1(device const T *src, device U *dst, constant const int64_t &src_stride, constant const int64_t &dst_stride, uint index)
Definition copy.h:101
void copy_g_nd2(device const T *src, device U *dst, constant const int64_t *src_strides, uint2 index, uint2 grid_dim)
Definition copy.h:50
void copy_gg_nd3(device const T *src, device U *dst, constant const int64_t *src_strides, constant const int64_t *dst_strides, uint3 index)
Definition copy.h:125
void copy_g(device const T *src, device U *dst, constant const int *src_shape, constant const int64_t *src_strides, constant const int &ndim, uint3 index, uint3 grid_dim)
Definition copy.h:75
void copy_s2(device const T *src, device U *dst, uint2 index, uint2 grid_dim)
Definition copy.h:20
void copy_gg_dynamic_nd1(device const T *src, device U *dst, constant const int64_t &src_stride, constant const int64_t &dst_stride, constant const int64_t &src_offset, constant const int64_t &dst_offset, uint index)
Definition copy.h:166
void copy_gg_dynamic_nd2(device const T *src, device U *dst, constant const int64_t *src_strides, constant const int64_t *dst_strides, constant const int64_t &src_offset, constant const int64_t &dst_offset, uint2 index)
Definition copy.h:180
void copy_g_nd3(device const T *src, device U *dst, constant const int64_t *src_strides, uint3 index, uint3 grid_dim)
Definition copy.h:62
void copy_gg_dynamic(device const T *src, device U *dst, constant const int *src_shape, constant const int64_t *src_strides, constant const int64_t *dst_strides, constant const int &ndim, constant const int64_t &src_offset, constant const int64_t &dst_offset, uint3 index)
Definition copy.h:208
void copy_gg(device const T *src, device U *dst, constant const int *src_shape, constant const int64_t *src_strides, constant const int64_t *dst_strides, constant const int &ndim, uint3 index)
Definition copy.h:137
void copy_v(device const T *src, device U *dst, uint index)
Definition copy.h:12
void copy_v2(device const T *src, device U *dst, uint2 index, uint2 grid_dim)
Definition copy.h:30
void copy_s(device const T *src, device U *dst, uint index)
Definition copy.h:4
void copy_gg_nd2(device const T *src, device U *dst, constant const int64_t *src_strides, constant const int64_t *dst_strides, uint2 index)
Definition copy.h:113
void copy_gg_dynamic_nd3(device const T *src, device U *dst, constant const int64_t *src_strides, constant const int64_t *dst_strides, constant const int64_t &src_offset, constant const int64_t &dst_offset, uint3 index)
Definition copy.h:194