#include <metal_simdgroup>
Go to the source code of this file.
template<typename T, int D>
void sdpa_vector(const device T* queries, const device T* keys, const device T* values, device T* out, const constant int& gqa_factor, const constant int& N, const constant size_t& k_stride, const constant float& scale, uint3 tid, uint simd_gid, uint simd_lid)
sdpa_vector()

template<typename T, int D>
void sdpa_vector(
    const device T* queries,
    const device T* keys,
    const device T* values,
    device T* out,
    const constant int& gqa_factor,
    const constant int& N,
    const constant size_t& k_stride,
    const constant float& scale,
    uint3 tid,
    uint simd_gid,
    uint simd_lid)