#include <metal_simdgroup>
Go to the source code of this file.
template<typename T, int D>
void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)
|
◆ sdpa_vector()
template<typename T, int D>
void sdpa_vector(
    const device T *queries,
    const device T *keys,
    const device T *values,
    device T *out,
    const constant int &gqa_factor,
    const constant int &N,
    const constant size_t &k_stride,
    const constant size_t &v_stride,
    const constant float &scale,
    uint3 tid,
    uint simd_gid,
    uint simd_lid)