MLX
Loading...
Searching...
No Matches
copy.h
Go to the documentation of this file.
// Copyright © 2024 Apple Inc.

// Metal source fragment that instantiates the copy kernels for one
// (source type, destination type) pair.
//
// The text is a template with three positional placeholders:
//   {0} - suffix appended to every kernel's host_name ("s_{0}", "gg4_{0}", ...)
//   {1} - element type of the `src` buffer
//   {2} - element type of the `dst` buffer
// NOTE(review): the placeholders look like fmt-style replacement fields;
// confirm against the code that formats this string. If so, any literal brace
// added to this text later would need to be escaped.
//
// Every entry is an explicit instantiation of a kernel template whose
// definition is not part of this string. The variants are:
//   copy_s, copy_v                     - src, dst and a 1-D thread index only
//   copy_g_nd1 / _nd2 / _nd3           - source stride(s), 1/2/3-D thread index
//   copy_g_nd<..., 4>, <..., 5>        - source shape + source strides
//   copy_g                             - source shape + strides + runtime ndim
//   copy_gg_*                          - as the matching copy_g_* variant, plus
//                                        destination stride(s)
//
// Buffer slots are the same across all variants, even where a variant omits
// some of them:
//   0 = src, 1 = dst, 2 = src_shape, 3 = src stride(s),
//   4 = dst stride(s), 5 = ndim
constexpr std::string_view copy_kernels = R"(
template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]);
template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    uint index [[thread_position_in_grid]]);

template [[host_name("g4_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 4>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg4_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 4>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 5>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg5_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 5>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg1_{0}")]] [[kernel]] void
copy_gg_nd1<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t& src_stride [[buffer(3)]],
    constant const int64_t& dst_stride [[buffer(4)]],
    uint index [[thread_position_in_grid]]);
template [[host_name("gg2_{0}")]] [[kernel]] void
copy_gg_nd2<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint2 index [[thread_position_in_grid]]);
template [[host_name("gg3_{0}")]] [[kernel]] void
copy_gg_nd3<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    uint3 index [[thread_position_in_grid]]);

template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
    device const {1}* src [[buffer(0)]],
    device {2}* dst [[buffer(1)]],
    constant const int* src_shape [[buffer(2)]],
    constant const int64_t* src_strides [[buffer(3)]],
    constant const int64_t* dst_strides [[buffer(4)]],
    constant const int& ndim [[buffer(5)]],
    uint3 index [[thread_position_in_grid]]);
)";
constexpr std::string_view copy_kernels
Definition copy.h:3