MLX
Loading...
Searching...
No Matches
mlx
backend
metal
jit
gemv_masked.h
Go to the documentation of this file.
1
// Copyright © 2024 Apple Inc.
2
3
// Metal Shading Language source template for the masked GEMV kernel.
//
// The text is an explicit template instantiation of `gemv_{trans}masked<...>`
// (the template itself is defined elsewhere, not in this file). It is tagged
// `[[kernel]]` and carries `[[host_name("{name}")]]`, so the instantiation is
// exported under the host-visible name substituted for `{name}`.
//
// Brace-delimited fields -- {name}, {trans}, {itype}, {outm_t}, {opm_t}, {bm},
// {bn}, {sm}, {sn}, {tm}, {tn}, {nc} -- are placeholders. Presumably they are
// filled in by a format call in the JIT kernel builder before the source is
// handed to the Metal compiler; TODO confirm against the caller. {trans}
// presumably selects a transposed variant of the kernel name.
//
// Buffer slots bound by this signature: 0, 1, 3-6, 9-12 and 20-24 (slots 2,
// 7, 8 and 13-19 are not used here).
//
// NOTE: `marix_ld` (sic) is only a parameter name. Names in an explicit
// instantiation declaration are not significant, so it is deliberately left
// as-is to keep the emitted source text unchanged.
//
// Declared `inline constexpr` (C++17) so that every translation unit including
// this header refers to one shared entity instead of its own internal-linkage
// copy.
inline constexpr std::string_view gemv_masked_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>(
const device {itype}* mat [[buffer(0)]],
const device {itype}* in_vec [[buffer(1)]],
device {itype}* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device {outm_t}* out_mask [[buffer(20)]],
const device {opm_t}* mat_mask [[buffer(21)]],
const device {opm_t}* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";
gemv_masked_kernel
constexpr std::string_view gemv_masked_kernel
Definition
gemv_masked.h:3
Generated by Doxygen 1.10.0