MLX
 
Loading...
Searching...
No Matches
sdpa_vector.h File Reference
#include <metal_simdgroup>

Go to the source code of this file.

Functions

template<typename T, int D, int V = D>
void sdpa_vector (const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_seq_stride, const constant int &mask_head_stride, uint3 tid, uint simd_gid, uint simd_lid)
 
template<typename T, int D, int V = D>
void sdpa_vector_2pass_1 (const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_seq_stride, const constant int &mask_head_stride, uint3 tid, uint simd_gid, uint simd_lid)
 
template<typename T, int D>
void sdpa_vector_2pass_2 (const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint simd_gid, uint simd_lid)
 

Variables

constant bool has_mask
 

Function Documentation

◆ sdpa_vector()

template<typename T, int D, int V = D>
void sdpa_vector ( const device T * queries,
const device T * keys,
const device T * values,
device T * out,
const constant int & gqa_factor,
const constant int & N,
const constant size_t & k_stride,
const constant size_t & v_stride,
const constant float & scale,
const device bool * mask,
const constant int & mask_seq_stride,
const constant int & mask_head_stride,
uint3 tid,
uint simd_gid,
uint simd_lid )

◆ sdpa_vector_2pass_1()

template<typename T, int D, int V = D>
void sdpa_vector_2pass_1 ( const device T * queries,
const device T * keys,
const device T * values,
device float * out,
device float * sums,
device float * maxs,
const constant int & gqa_factor,
const constant int & N,
const constant size_t & k_stride,
const constant size_t & v_stride,
const constant float & scale,
const device bool * mask,
const constant int & mask_seq_stride,
const constant int & mask_head_stride,
uint3 tid,
uint simd_gid,
uint simd_lid )

◆ sdpa_vector_2pass_2()

template<typename T, int D>
void sdpa_vector_2pass_2 ( const device float * partials,
const device float * sums,
const device float * maxs,
device T * out,
uint3 tid,
uint simd_gid,
uint simd_lid )

Variable Documentation

◆ has_mask

constant bool has_mask