9    const device T* queries [[buffer(0)]],
 
   10    const device T* keys [[buffer(1)]],
 
   11    const device T* values [[buffer(2)]],
 
   12    device T* out [[buffer(3)]],
 
   13    const constant 
int& gqa_factor,
 
   14    const constant 
int& N,
 
   15    const constant 
size_t& k_stride,
 
   16    const constant 
size_t& v_stride,
 
   17    const constant 
float& scale,
 
   18    uint3 tid [[threadgroup_position_in_grid]],
 
   19    uint simd_gid [[simdgroup_index_in_threadgroup]],
 
   20    uint simd_lid [[thread_index_in_simdgroup]]) {
 
   21  constexpr int BN = 32;
 
   22  constexpr int BD = 32;
 
   23  constexpr int elem_per_thread = D / BD;
 
   24  constexpr int stride = BN * D;
 
   28  thread U q[elem_per_thread];
 
   29  thread U k[elem_per_thread];
 
   30  thread U o[elem_per_thread];
 
   32  threadgroup U outputs[BN * BD];
 
   33  threadgroup U max_scores[BN];
 
   34  threadgroup U sum_exp_scores[BN];
 
   37  const int head_idx = tid.y;
 
   38  const int kv_head_idx = head_idx / gqa_factor;
 
   39  queries += head_idx * D + simd_lid * elem_per_thread;
 
   40  keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
 
   41  values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
 
   42  out += head_idx * D + simd_gid * elem_per_thread;
 
   45  for (
int i = 0; i < elem_per_thread; i++) {
 
   46    q[i] = 
static_cast<U
>(scale) * queries[i];
 
   48  for (
int i = 0; i < elem_per_thread; i++) {
 
   52  U max_score = -INFINITY;
 
   56  for (
int i = simd_gid; i < N; i += BN) {
 
   58    for (
int i = 0; i < elem_per_thread; i++) {
 
   64    for (
int i = 0; i < elem_per_thread; i++) {
 
   70    U new_max = 
max(max_score, score);
 
   71    U factor = 
fast::exp(max_score - new_max);
 
   75    sum_exp_score = sum_exp_score * factor + exp_score;
 
   78    for (
int i = 0; i < elem_per_thread; i++) {
 
   79      o[i] = o[i] * factor + exp_score * values[i];
 
   91    max_scores[simd_gid] = max_score;
 
   92    sum_exp_scores[simd_gid] = sum_exp_score;
 
   94  threadgroup_barrier(mem_flags::mem_threadgroup);
 
   95  max_score = max_scores[simd_lid];
 
   97  U factor = 
fast::exp(max_score - new_max);
 
   98  sum_exp_score = 
simd_sum(sum_exp_scores[simd_lid] * factor);
 
  101  for (
int i = 0; i < elem_per_thread; i++) {
 
  102    outputs[simd_lid * BD + simd_gid] = o[i];
 
  103    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  104    o[i] = 
simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
 
  105    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  110    for (
int i = 0; i < elem_per_thread; i++) {
 
  111      out[i] = 
static_cast<T
>(o[i]);
 
 
  118    const device T* queries [[buffer(0)]],
 
  119    const device T* keys [[buffer(1)]],
 
  120    const device T* values [[buffer(2)]],
 
  121    device 
float* out [[buffer(3)]],
 
  122    device 
float* sums [[buffer(4)]],
 
  123    device 
float* maxs [[buffer(5)]],
 
  124    const constant 
int& gqa_factor,
 
  125    const constant 
int& N,
 
  126    const constant 
size_t& k_stride,
 
  127    const constant 
size_t& v_stride,
 
  128    const constant 
float& scale,
 
  129    uint3 tid [[threadgroup_position_in_grid]],
 
  130    uint simd_gid [[simdgroup_index_in_threadgroup]],
 
  131    uint simd_lid [[thread_index_in_simdgroup]]) {
 
  132  constexpr int BN = 8;
 
  133  constexpr int BD = 32;
 
  134  constexpr int elem_per_thread = D / BD;
 
  135  constexpr int stride = BN * D;
 
  136  constexpr int blocks = 32;
 
  140  thread U q[elem_per_thread];
 
  141  thread U k[elem_per_thread];
 
  142  thread U o[elem_per_thread];
 
  144  threadgroup U outputs[BN * BD];
 
  145  threadgroup U max_scores[BN];
 
  146  threadgroup U sum_exp_scores[BN];
 
  149  const int block_idx = tid.z;
 
  150  const int head_idx = tid.y;
 
  151  const int kv_head_idx = head_idx / gqa_factor;
 
  152  queries += head_idx * D + simd_lid * elem_per_thread;
 
  153  keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
 
  154      simd_lid * elem_per_thread;
 
  155  values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
 
  156      simd_lid * elem_per_thread;
 
  157  out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
 
  158  sums += head_idx * blocks + block_idx;
 
  159  maxs += head_idx * blocks + block_idx;
 
  162  for (
int i = 0; i < elem_per_thread; i++) {
 
  163    q[i] = 
static_cast<U
>(scale) * queries[i];
 
  165  for (
int i = 0; i < elem_per_thread; i++) {
 
  173  for (
int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
 
  175    for (
int i = 0; i < elem_per_thread; i++) {
 
  181    for (
int i = 0; i < elem_per_thread; i++) {
 
  182      score += q[i] * k[i];
 
  187    U new_max = 
max(max_score, score);
 
  188    U factor = 
fast::exp(max_score - new_max);
 
  189    U exp_score = 
fast::exp(score - new_max);
 
  192    sum_exp_score = sum_exp_score * factor + exp_score;
 
  195    for (
int i = 0; i < elem_per_thread; i++) {
 
  196      o[i] = o[i] * factor + exp_score * values[i];
 
  200    keys += blocks * stride;
 
  201    values += blocks * stride;
 
  208    max_scores[simd_gid] = max_score;
 
  209    sum_exp_scores[simd_gid] = sum_exp_score;
 
  211  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  212  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
 
  214  U factor = 
fast::exp(max_score - new_max);
 
  215  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
 
  216  sum_exp_score = 
simd_sum(sum_exp_score * factor);
 
  220    sums[0] = sum_exp_score;
 
  225  for (
int i = 0; i < elem_per_thread; i++) {
 
  226    outputs[simd_lid * BN + simd_gid] =
 
  227        o[i] * 
fast::exp(max_scores[simd_gid] - new_max);
 
  228    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  232      U output = outputs[simd_lid * BN];
 
  233      for (
int j = 1; j < BN; j++) {
 
  234        output += outputs[simd_lid * BN + j];
 
  236      out[i] = 
static_cast<T
>(output);
 
  238    threadgroup_barrier(mem_flags::mem_threadgroup);
 
 
  244    const device 
float* partials [[buffer(0)]],
 
  245    const device 
float* sums [[buffer(1)]],
 
  246    const device 
float* maxs [[buffer(2)]],
 
  247    device T* out [[buffer(3)]],
 
  248    uint3 tid [[threadgroup_position_in_grid]],
 
  249    uint simd_gid [[simdgroup_index_in_threadgroup]],
 
  250    uint simd_lid [[thread_index_in_simdgroup]]) {
 
  251  constexpr int BN = 32;
 
  252  constexpr int BD = 32;
 
  253  constexpr int elem_per_thread = D / BD;
 
  254  constexpr int blocks = 32;
 
  258  thread U o[elem_per_thread];
 
  259  threadgroup U outputs[BN * BD];
 
  262  const int head_idx = tid.y;
 
  263  partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
 
  264  sums += head_idx * blocks;
 
  265  maxs += head_idx * blocks;
 
  266  out += head_idx * D + simd_gid * elem_per_thread;
 
  269  U max_score = maxs[simd_lid];
 
  271  U factor = 
fast::exp(max_score - new_max);
 
  272  U sum_exp_score = 
simd_sum(sums[simd_lid] * factor);
 
  276  for (
int i = 0; i < elem_per_thread; i++) {
 
  279  for (
int i = 0; i < elem_per_thread; i++) {
 
  280    outputs[simd_lid * BD + simd_gid] = o[i];
 
  281    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  282    o[i] = 
simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
 
  283    threadgroup_barrier(mem_flags::mem_threadgroup);
 
  288    for (
int i = 0; i < elem_per_thread; i++) {
 
  289      out[i] = 
static_cast<T
>(o[i]);
 
 
void sdpa_vector_2pass_2(const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:243
 
void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:8
 
void sdpa_vector_2pass_1(const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:117