MLX
Loading...
Searching...
No Matches
sdpa_vector.h File Reference
#include <metal_simdgroup>

Go to the source code of this file.

Functions

template<typename T, int D>
void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)
 
template<typename T, int D>
void sdpa_vector_2pass_1(const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)
 
template<typename T, int D>
void sdpa_vector_2pass_2(const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint simd_gid, uint simd_lid)
 

Function Documentation

◆ sdpa_vector()

template<typename T, int D>
void sdpa_vector(const device T *queries,
const device T *keys,
const device T *values,
device T *out,
const constant int &gqa_factor,
const constant int &N,
const constant size_t &k_stride,
const constant size_t &v_stride,
const constant float &scale,
uint3 tid,
uint simd_gid,
uint simd_lid)

◆ sdpa_vector_2pass_1()

template<typename T, int D>
void sdpa_vector_2pass_1(const device T *queries,
const device T *keys,
const device T *values,
device float *out,
device float *sums,
device float *maxs,
const constant int &gqa_factor,
const constant int &N,
const constant size_t &k_stride,
const constant size_t &v_stride,
const constant float &scale,
uint3 tid,
uint simd_gid,
uint simd_lid)

◆ sdpa_vector_2pass_2()

template<typename T, int D>
void sdpa_vector_2pass_2(const device float *partials,
const device float *sums,
const device float *maxs,
device T *out,
uint3 tid,
uint simd_gid,
uint simd_lid)