#include <metal_simdgroup>
Go to the source code of this file.
|
template<typename T, int D, int V = D>
void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_seq_stride, const constant int &mask_head_stride, uint3 tid, uint simd_gid, uint simd_lid)
|
template<typename T, int D, int V = D>
void sdpa_vector_2pass_1(const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_seq_stride, const constant int &mask_head_stride, uint3 tid, uint simd_gid, uint simd_lid)
|
template<typename T, int D>
void sdpa_vector_2pass_2(const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint simd_gid, uint simd_lid)
|
◆ sdpa_vector()
template<typename T, int D, int V = D>
void sdpa_vector(
    const device T *queries,
    const device T *keys,
    const device T *values,
    device T *out,
    const constant int &gqa_factor,
    const constant int &N,
    const constant size_t &k_stride,
    const constant size_t &v_stride,
    const constant float &scale,
    const device bool *mask,
    const constant int &mask_seq_stride,
    const constant int &mask_head_stride,
    uint3 tid,
    uint simd_gid,
    uint simd_lid)
◆ sdpa_vector_2pass_1()
template<typename T, int D, int V = D>
void sdpa_vector_2pass_1(
    const device T *queries,
    const device T *keys,
    const device T *values,
    device float *out,
    device float *sums,
    device float *maxs,
    const constant int &gqa_factor,
    const constant int &N,
    const constant size_t &k_stride,
    const constant size_t &v_stride,
    const constant float &scale,
    const device bool *mask,
    const constant int &mask_seq_stride,
    const constant int &mask_head_stride,
    uint3 tid,
    uint simd_gid,
    uint simd_lid)
◆ sdpa_vector_2pass_2()
template<typename T, int D>
void sdpa_vector_2pass_2(
    const device float *partials,
    const device float *sums,
    const device float *maxs,
    device T *out,
    uint3 tid,
    uint simd_gid,
    uint simd_lid)
◆ has_mask