// MLX: steel_gemm.h
// Copyright © 2024 Apple Inc.

// Metal source for an explicit instantiation of the fused `gemm` kernel
// template.
//
// The text is itself a template. The brace-enclosed fields ({name}, {itype},
// {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}) are placeholders.
// They are presumably substituted by the JIT path before the source reaches
// the Metal compiler; no substitution happens in this file. The last template
// argument is fixed to `float` (presumably the accumulation type -- confirm
// against the kernel template).
//
// Arguments tagged `function_constant(X)` exist only in specializations where
// the Metal function constant X (use_out_source, do_gather, gather_bias) is
// true. The parameter list and [[buffer(n)]] slots should mirror the kernel
// template's own declaration so that this instantiation resolves to it.
constexpr std::string_view steel_gemm_fused_kernels = R"(
template [[host_name("{name}")]]
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
    const device {itype} *A [[buffer(0)]],
    const device {itype} *B [[buffer(1)]],
    const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
    device {itype} *D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
    const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
    const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
    const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
    const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
    const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";

// Metal source for an explicit instantiation of the `block_masked_gemm`
// kernel template.
//
// The brace-enclosed fields ({name}, {itype}, {outmasktype}, {opmasktype},
// {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, {mn_aligned},
// {k_aligned}) are placeholders, presumably substituted before compilation;
// no substitution happens in this file.
// {outmasktype} is the element type of `out_mask`. {opmasktype} is the
// element type of both operand masks (`lhs_mask`, `rhs_mask`).
//
// This signature has no C/bias operand. D is written through buffer slot 3,
// and slots 2, 5, 8 and 9 are not bound here.
constexpr std::string_view steel_gemm_masked_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
block_masked_gemm<
    {itype},
    {outmasktype},
    {opmasktype},
    {bm},
    {bn},
    {bk},
    {wm},
    {wn},
    {trans_a},
    {trans_b},
    {mn_aligned},
    {k_aligned}>(
    const device {itype}* A [[buffer(0)]],
    const device {itype}* B [[buffer(1)]],
    device {itype}* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const device {outmasktype}* out_mask [[buffer(10)]],
    const device {opmasktype}* lhs_mask [[buffer(11)]],
    const device {opmasktype}* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";

// Metal source for an explicit instantiation of the `gemm_splitk` kernel
// template.
//
// The brace-enclosed fields ({name}, {itype}, {otype}, {bm}, {bn}, {bk}, {wm},
// {wn}, {trans_a}, {trans_b}, {mn_aligned}, {k_aligned}) are placeholders,
// presumably substituted before compilation. {otype} is the element type of
// the output buffer C and may differ from the input type {itype}. Presumably
// it is a wider type holding per-partition partial results -- confirm against
// the host code.
//
// NOTE(review): `GEMMSpiltKParams` (sic) is presumably spelled this way in the
// Metal kernel headers. Do not "fix" the spelling here without renaming it
// there too.
constexpr std::string_view steel_gemm_splitk_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk<
    {itype},
    {otype},
    {bm},
    {bn},
    {bk},
    {wm},
    {wn},
    {trans_a},
    {trans_b},
    {mn_aligned},
    {k_aligned}>(
    const device {itype}* A [[buffer(0)]],
    const device {itype}* B [[buffer(1)]],
    device {otype}* C [[buffer(2)]],
    const constant GEMMSpiltKParams* params [[buffer(3)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";

// Metal source for an explicit instantiation of `gemm_splitk_accum`.
//
// The kernel reads {atype} elements from `C_split` and writes {otype} elements
// to `D`. The brace-enclosed fields are placeholders, presumably substituted
// before compilation.
// Presumably the kernel reduces `k_partitions` partial results spaced
// `partition_stride` apart, with `ldd` as D's leading dimension -- confirm
// against the kernel.
// `gid` is a 2-D (uint2) thread position in the grid.
constexpr std::string_view steel_gemm_splitk_accum_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum<{atype}, {otype}>(
    const device {atype}* C_split [[buffer(0)]],
    device {otype}* D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]]);
)";

// Metal source for an explicit instantiation of `gemm_splitk_accum_axpby`.
//
// This is the split-K accumulation signature extended with an extra {otype}
// operand `C` (buffer slot 5), the integers `ldc` and `fdc`, and float `alpha`
// and `beta` scalars. Presumably the kernel computes
// D = alpha * sum(partials) + beta * C -- confirm against the kernel.
// The brace-enclosed fields are placeholders, presumably substituted before
// compilation.
//
// The variable name says "axbpy" while the kernel is "axpby". The variable
// name is left as-is because renaming it would break code that refers to it.
constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum_axpby<{atype}, {otype}>(
    const device {atype}* C_split [[buffer(0)]],
    device {otype}* D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    const device {otype}* C [[buffer(5)]],
    const constant int& ldc [[buffer(6)]],
    const constant int& fdc [[buffer(7)]],
    const constant float& alpha [[buffer(8)]],
    const constant float& beta [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]]);
)";
// Member index from the generated documentation (line numbers refer to the
// upstream steel_gemm.h):
//   constexpr std::string_view steel_gemm_splitk_accum_kernels        -- steel_gemm.h:81
//   constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels  -- steel_gemm.h:92
//   constexpr std::string_view steel_gemm_fused_kernels               -- steel_gemm.h:3
//   constexpr std::string_view steel_gemm_masked_kernels              -- steel_gemm.h:26
//   constexpr std::string_view steel_gemm_splitk_kernels              -- steel_gemm.h:57