MLX
 
Loading...
Searching...
No Matches
gemv_masked.h
Go to the documentation of this file.
// Copyright © 2024 Apple Inc.

// Metal source snippet for the masked GEMV kernel, intended for JIT
// compilation.
//
// The text is an explicit template instantiation
// (`template [[host_name(...)]] [[kernel]] void gemv_{trans}masked<...>(...);`).
// The kernel template itself is therefore not defined here; it must be
// defined elsewhere in the Metal source compiled together with this snippet.
//
// The brace-delimited fields ({name}, {trans}, {itype}, {outm_t}, {opm_t},
// {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}) are substitution placeholders.
// NOTE(review): they are presumably filled with fmt/std::format-style named
// arguments by the host-side JIT code (confirm against the caller). If so,
// any literal brace added to this string would need escaping.
//   - `{name}` becomes the [[host_name]] under which the instantiated kernel
//     is exposed.
//   - `{trans}` is spliced directly into the kernel template's identifier.
//
// Buffer bindings are sparse: indices 2, 7, 8 and 13-19 are not used by this
// signature. NOTE(review): this presumably mirrors the host-side encoder's
// binding layout; confirm there before renumbering anything.
//
// `marix_ld` (sic): parameter names in an instantiation do not take part in
// matching the template (only the types do), so the misspelling is harmless
// and is left unchanged.
//
// Declared `inline` (C++17) so every translation unit that includes this
// header refers to one object. A non-inline const namespace-scope variable
// would instead give each TU its own internal-linkage copy.
inline constexpr std::string_view gemv_masked_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>(
    const device {itype}* mat [[buffer(0)]],
    const device {itype}* in_vec [[buffer(1)]],
    device {itype}* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& marix_ld [[buffer(6)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant int64_t* vector_batch_stride [[buffer(11)]],
    const constant int64_t* matrix_batch_stride [[buffer(12)]],
    const device {outm_t}* out_mask [[buffer(20)]],
    const device {opm_t}* mat_mask [[buffer(21)]],
    const device {opm_t}* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    const constant int64_t* mask_batch_strides [[buffer(24)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]);
)";
constexpr std::string_view gemv_masked_kernel
Definition gemv_masked.h:3