#include <metal_simdgroup>
Go to the source code of this file.
Functions
template<typename T, int D, int V = D>
void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_head_stride, const constant size_t &k_seq_stride, const constant size_t &v_head_stride, const constant size_t &v_seq_stride, const constant float &scale, const device bool *mask, const constant int &mask_kv_seq_stride, const constant int &mask_q_seq_stride, const constant int &mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)
template<typename T, int D, int V = D>
void sdpa_vector_2pass_1(const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_head_stride, const constant size_t &k_seq_stride, const constant size_t &v_head_stride, const constant size_t &v_seq_stride, const constant float &scale, const device bool *mask, const constant int &mask_kv_seq_stride, const constant int &mask_q_seq_stride, const constant int &mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)
template<typename T, int D>
void sdpa_vector_2pass_2(const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)
Variables
constant bool has_mask
constant bool query_transposed
void sdpa_vector(const device T *queries,
                 const device T *keys,
                 const device T *values,
                 device T *out,
                 const constant int &gqa_factor,
                 const constant int &N,
                 const constant size_t &k_head_stride,
                 const constant size_t &k_seq_stride,
                 const constant size_t &v_head_stride,
                 const constant size_t &v_seq_stride,
                 const constant float &scale,
                 const device bool *mask,
                 const constant int &mask_kv_seq_stride,
                 const constant int &mask_q_seq_stride,
                 const constant int &mask_head_stride,
                 uint3 tid,
                 uint3 tpg,
                 uint simd_gid,
                 uint simd_lid)
void sdpa_vector_2pass_1(const device T *queries,
                         const device T *keys,
                         const device T *values,
                         device float *out,
                         device float *sums,
                         device float *maxs,
                         const constant int &gqa_factor,
                         const constant int &N,
                         const constant size_t &k_head_stride,
                         const constant size_t &k_seq_stride,
                         const constant size_t &v_head_stride,
                         const constant size_t &v_seq_stride,
                         const constant float &scale,
                         const device bool *mask,
                         const constant int &mask_kv_seq_stride,
                         const constant int &mask_q_seq_stride,
                         const constant int &mask_head_stride,
                         uint3 tid,
                         uint3 tpg,
                         uint simd_gid,
                         uint simd_lid)
void sdpa_vector_2pass_2(const device float *partials,
                         const device float *sums,
                         const device float *maxs,
                         device T *out,
                         uint3 tid,
                         uint3 tpg,
                         uint simd_gid,
                         uint simd_lid)
constant bool has_mask |
constant bool query_transposed |