MLX
Loading...
Searching...
No Matches
copy.h
Go to the documentation of this file.
// Copyright © 2024 Apple Inc.
2
// Scalar broadcast: every destination element receives src[0], converted
// from T to U.
//
// Launch: 1D grid, one thread per destination element; thread `index` writes
// dst[index]. There is no bounds check, so this presumably relies on the host
// dispatching exactly dst's element count (NOTE(review): confirm at the
// dispatch site).
template <typename T, typename U>
[[kernel]] void copy_s(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  const auto value = static_cast<U>(src[0]);
  dst[index] = value;
}
10
// Contiguous element-wise copy with conversion: dst[index] = U(src[index]).
//
// Launch: 1D grid, one thread per element, with no bounds check (presumably
// the host dispatches exactly the element count — NOTE(review): confirm).
// Because `index` is a 32-bit uint, this form cannot address past 2^32
// elements; copy_v2 is the 2D-grid form with a size_t offset.
template <typename T, typename U>
[[kernel]] void copy_v(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  const auto converted = static_cast<U>(src[index]);
  dst[index] = converted;
}
18
// 2D-grid variant of copy_s: broadcasts src[0] into every element of dst.
//
// The thread position is linearized row-major (x varies fastest) using the
// grid width, so dst is written contiguously. The row index is widened to
// size_t before multiplying so the product does not wrap in 32-bit uint
// arithmetic.
template <typename T, typename U>
[[kernel]] void copy_s2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  const size_t row_start = size_t(index.y) * grid_dim.x;
  dst[row_start + index.x] = static_cast<U>(src[0]);
}
28
// 2D-grid variant of copy_v: contiguous element-wise copy with conversion.
//
// Source and destination share the same flat offset: the row-major
// linearization of the 2D launch grid (x fastest). The row index is widened
// to size_t first so the offset does not wrap in 32-bit uint arithmetic.
template <typename T, typename U>
[[kernel]] void copy_v2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  const size_t flat = size_t(index.y) * grid_dim.x + index.x;
  dst[flat] = static_cast<U>(src[flat]);
}
38
// Strided 1D source -> contiguous destination, with conversion.
//
// elem_to_loc_1 (from utils.h) maps the thread index to a source offset using
// the single source stride. The destination is written at dst[index]. No
// shape buffer is bound for this rank.
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    uint index [[thread_position_in_grid]]) {
  const auto src_loc = elem_to_loc_1(index, src_stride);
  dst[index] = static_cast<U>(src[src_loc]);
}
48
// Strided 2D source -> contiguous destination, with conversion.
//
// The source offset comes from elem_to_loc_2 with the two source strides; no
// shape is needed. The destination offset is the row-major linearization of
// the 2D launch grid (x fastest), computed in int64_t.
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  const auto src_loc = elem_to_loc_2(index, src_strides);
  const int64_t dst_loc = int64_t(grid_dim.x) * index.y + index.x;
  dst[dst_loc] = static_cast<U>(src[src_loc]);
}
60
// Strided 3D source -> contiguous destination, with conversion.
//
// The source offset comes from elem_to_loc_3 with the three source strides.
// The destination offset is the row-major linearization of the 3D launch
// grid (x fastest, then y, then z), carried out in int64_t.
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  const auto src_loc = elem_to_loc_3(index, src_strides);
  const int64_t plane = int64_t(grid_dim.y) * index.z + index.y;
  const int64_t dst_loc = int64_t(grid_dim.x) * plane + index.x;
  dst[dst_loc] = static_cast<U>(src[src_loc]);
}
73
// Strided source of compile-time rank DIM -> contiguous destination.
//
// elem_to_loc_nd<DIM> maps the 3D thread position to a source offset from
// src_shape and src_strides. How a rank above 3 is folded onto the 3D grid is
// decided inside that helper (not visible here). The destination offset is
// the row-major linearization of the launch grid, computed in int64_t.
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  const auto src_loc = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  const int64_t plane = int64_t(grid_dim.y) * index.z + index.y;
  const int64_t dst_loc = int64_t(grid_dim.x) * plane + index.x;
  dst[dst_loc] = static_cast<U>(src[src_loc]);
}
87
// Strided source of runtime rank `ndim` -> contiguous destination.
//
// elem_to_loc maps the 3D thread position to a source offset from src_shape,
// src_strides and ndim. The destination offset is the row-major
// linearization of the launch grid (x fastest), computed in int64_t.
template <typename T, typename U>
[[kernel]] void copy_g(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  const auto src_loc = elem_to_loc(index, src_shape, src_strides, ndim);
  const int64_t plane = int64_t(grid_dim.y) * index.z + index.y;
  const int64_t dst_loc = int64_t(grid_dim.x) * plane + index.x;
  dst[dst_loc] = static_cast<U>(src[src_loc]);
}
102
// Strided 1D source -> strided 1D destination, with conversion.
//
// Both offsets come from elem_to_loc_1 applied to the same thread index, each
// with its own stride. No shape buffer is bound for this rank.
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    constant const int64_t& dst_stride [[buffer(4)]],
    uint index [[thread_position_in_grid]]) {
  const auto out_pos = elem_to_loc_1(index, dst_stride);
  const auto in_pos = elem_to_loc_1(index, src_stride);
  dst[out_pos] = static_cast<U>(src[in_pos]);
}
114
// Strided 2D source -> strided 2D destination, with conversion.
//
// Both offsets come from elem_to_loc_2 applied to the same 2D thread
// position, each with its own pair of strides. Because the destination is
// also addressed through strides, the launch-grid extent is not needed.
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint2 index [[thread_position_in_grid]]) {
  const auto out_pos = elem_to_loc_2(index, dst_strides);
  const auto in_pos = elem_to_loc_2(index, src_strides);
  dst[out_pos] = static_cast<U>(src[in_pos]);
}
126
// Strided 3D source -> strided 3D destination, with conversion.
//
// Both offsets come from elem_to_loc_3 applied to the same 3D thread
// position, each with its own three strides. Because the destination is also
// addressed through strides, the launch-grid extent is not needed.
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]) {
  const auto out_pos = elem_to_loc_3(index, dst_strides);
  const auto in_pos = elem_to_loc_3(index, src_strides);
  dst[out_pos] = static_cast<U>(src[in_pos]);
}
138
// Strided source -> strided destination of compile-time rank DIM.
//
// Source and destination share one logical shape: both offsets are computed
// by elem_to_loc_nd<DIM> with src_shape, differing only in the strides used.
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]) {
  const auto out_pos = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
  const auto in_pos = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
  dst[out_pos] = static_cast<U>(src[in_pos]);
}
151
// Strided source -> strided destination of runtime rank `ndim`.
//
// Source and destination share one logical shape: both offsets are computed
// by elem_to_loc with src_shape and ndim, differing only in the strides used.
template <typename T, typename U>
[[kernel]] void copy_gg(
    device const T* src [[buffer(0)]],
    device U* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]]) {
  const auto out_pos = elem_to_loc(index, src_shape, dst_strides, ndim);
  const auto in_pos = elem_to_loc(index, src_shape, src_strides, ndim);
  dst[out_pos] = static_cast<U>(src[in_pos]);
}
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t &stride)
Definition utils.h:161
METAL_FUNC stride_t elem_to_loc_3(uint3 elem, constant const stride_t strides[3])
Definition utils.h:173
METAL_FUNC stride_t elem_to_loc(uint elem, device const int *shape, device const stride_t *strides, int ndim)
Definition utils.h:87
METAL_FUNC stride_t elem_to_loc_2(uint2 elem, constant const stride_t strides[2])
Definition utils.h:167
void copy_g_nd(device const T *src, device U *dst, constant const int *src_shape, constant const int64_t *src_strides, uint3 index, uint3 grid_dim)
Definition copy.h:75
void copy_g(device const T *src, device U *dst, constant const int *src_shape, constant const int64_t *src_strides, constant const int &ndim, uint3 index, uint3 grid_dim)
Definition copy.h:89
void copy_gg_nd(device const T *src, device U *dst, constant const int *src_shape, constant const int64_t *src_strides, constant const int64_t *dst_strides, uint3 index)
Definition copy.h:140
void copy_gg_nd1(device const T *src, device U *dst, constant const int64_t &src_stride, constant const int64_t &dst_stride, uint index)
Definition copy.h:104
void copy_gg_nd2(device const T *src, device U *dst, constant const int64_t *src_strides, constant const int64_t *dst_strides, uint2 index)
Definition copy.h:116
void copy_gg_nd3(device const T *src, device U *dst, constant const int64_t *src_strides, constant const int64_t *dst_strides, uint3 index)
Definition copy.h:128
void copy_s2(device const T *src, device U *dst, uint2 index, uint2 grid_dim)
Definition copy.h:20
void copy_g_nd3(device const T *src, device U *dst, constant const int64_t *src_strides, uint3 index, uint3 grid_dim)
Definition copy.h:62
void copy_gg(device const T *src, device U *dst, constant const int *src_shape, constant const int64_t *src_strides, constant const int64_t *dst_strides, constant const int &ndim, uint3 index)
Definition copy.h:153
void copy_g_nd1(device const T *src, device U *dst, constant const int64_t &src_stride, uint index)
Definition copy.h:40
void copy_v(device const T *src, device U *dst, uint index)
Definition copy.h:12
void copy_v2(device const T *src, device U *dst, uint2 index, uint2 grid_dim)
Definition copy.h:30
void copy_g_nd2(device const T *src, device U *dst, constant const int64_t *src_strides, uint2 index, uint2 grid_dim)
Definition copy.h:50
void copy_s(device const T *src, device U *dst, uint index)
Definition copy.h:4