MLX
Loading...
Searching...
No Matches
sdpa_vector.h File Reference
#include <metal_simdgroup>

Go to the source code of this file.

Functions

template <typename T, int D>
void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)
 

Function Documentation

◆ sdpa_vector()

template <typename T, int D>
void sdpa_vector(const device T *queries,
                 const device T *keys,
                 const device T *values,
                 device T *out,
                 const constant int &gqa_factor,
                 const constant int &N,
                 const constant size_t &k_stride,
                 const constant size_t &v_stride,
                 const constant float &scale,
                 uint3 tid,
                 uint simd_gid,
                 uint simd_lid)