#include <metal_simdgroup>
Go to the source code of this file.
 | 
template<typename T, int D>
void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)

template<typename T, int D>
void sdpa_vector_2pass_1(const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)

template<typename T, int D>
void sdpa_vector_2pass_2(const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint simd_gid, uint simd_lid)
|   | 
◆ sdpa_vector()
template<typename T, int D>
void sdpa_vector(
    const device T *queries,
    const device T *keys,
    const device T *values,
    device T *out,
    const constant int &gqa_factor,
    const constant int &N,
    const constant size_t &k_stride,
    const constant size_t &v_stride,
    const constant float &scale,
    uint3 tid,
    uint simd_gid,
    uint simd_lid)
        
      
 
 
◆ sdpa_vector_2pass_1()
template<typename T, int D>
void sdpa_vector_2pass_1(
    const device T *queries,
    const device T *keys,
    const device T *values,
    device float *out,
    device float *sums,
    device float *maxs,
    const constant int &gqa_factor,
    const constant int &N,
    const constant size_t &k_stride,
    const constant size_t &v_stride,
    const constant float &scale,
    uint3 tid,
    uint simd_gid,
    uint simd_lid)
        
      
 
 
◆ sdpa_vector_2pass_2()
template<typename T, int D>
void sdpa_vector_2pass_2(
    const device float *partials,
    const device float *sums,
    const device float *maxs,
    device T *out,
    uint3 tid,
    uint simd_gid,
    uint simd_lid)