MLX
Loading...
Searching...
No Matches
sdpa_vector.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#include <metal_simdgroup>
4
5using namespace metal;
6
7template <typename T, int D>
8[[kernel]] void sdpa_vector(
9 const device T* queries [[buffer(0)]],
10 const device T* keys [[buffer(1)]],
11 const device T* values [[buffer(2)]],
12 device T* out [[buffer(3)]],
13 const constant int& gqa_factor,
14 const constant int& N,
15 const constant size_t& k_stride,
16 const constant size_t& v_stride,
17 const constant float& scale,
18 uint3 tid [[threadgroup_position_in_grid]],
19 uint simd_gid [[simdgroup_index_in_threadgroup]],
20 uint simd_lid [[thread_index_in_simdgroup]]) {
21 constexpr int BN = 32;
22 constexpr int BD = 32;
23 constexpr int elem_per_thread = D / BD;
24
25 const int stride = BN * D;
26
27 typedef float U;
28
29 thread U q[elem_per_thread];
30 thread U k[elem_per_thread];
31 thread U o[elem_per_thread];
32
33 threadgroup U outputs[BN * BD];
34 threadgroup U max_scores[BN];
35 threadgroup U sum_exp_scores[BN];
36
37 // Adjust positions
38 const int head_idx = tid.y;
39 const int kv_head_idx = head_idx / gqa_factor;
40 queries += head_idx * D + simd_lid * elem_per_thread;
41 keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
42 values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
43 out += head_idx * D + simd_gid * elem_per_thread;
44
45 // Read the query and 0 the output accumulator
46 for (int i = 0; i < elem_per_thread; i++) {
47 q[i] = static_cast<U>(scale) * queries[i];
48 }
49 for (int i = 0; i < elem_per_thread; i++) {
50 o[i] = 0;
51 }
52
53 U max_score = -INFINITY;
54 U sum_exp_score = 0;
55
56 // For each key
57 for (int i = simd_gid; i < N; i += BN) {
58 // Read the key
59 for (int i = 0; i < elem_per_thread; i++) {
60 k[i] = keys[i];
61 }
62
63 // Compute the i-th score
64 U score = 0;
65 for (int i = 0; i < elem_per_thread; i++) {
66 score += q[i] * k[i];
67 }
68 score = simd_sum(score);
69
70 // Update the accumulators
71 U new_max = max(max_score, score);
72 U factor = fast::exp(max_score - new_max);
73 U exp_score = fast::exp(score - new_max);
74
75 max_score = new_max;
76 sum_exp_score = sum_exp_score * factor + exp_score;
77
78 // Update the output accumulator
79 for (int i = 0; i < elem_per_thread; i++) {
80 o[i] = o[i] * factor + exp_score * values[i];
81 }
82
83 // Move the pointers to the next kv
84 keys += stride;
85 values += stride;
86 }
87 threadgroup_barrier(mem_flags::mem_threadgroup);
88
89 // Each thread has a partial part of the output so we need to combine them.
90
91 // First let's communicate the max and sum_exp
92 if (simd_lid == 0) {
93 max_scores[simd_gid] = max_score;
94 sum_exp_scores[simd_gid] = sum_exp_score;
95 }
96 threadgroup_barrier(mem_flags::mem_threadgroup);
97 max_score = max_scores[simd_lid];
98 U new_max = simd_max(max_score);
99 U factor = fast::exp(max_score - new_max);
100 sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
101
102 // Now we need to aggregate all the outputs
103 for (int i = 0; i < elem_per_thread; i++) {
104 outputs[simd_lid * BD + simd_gid] = o[i];
105 threadgroup_barrier(mem_flags::mem_threadgroup);
106 o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
107 threadgroup_barrier(mem_flags::mem_threadgroup);
108 }
109
110 // And write the output
111 if (simd_lid == 0) {
112 for (int i = 0; i < elem_per_thread; i++) {
113 out[i] = static_cast<T>(o[i]);
114 }
115 }
116}
METAL_FUNC bfloat16_t exp(bfloat16_t x)
Definition bf16_math.h:242
Definition bf16.h:265
METAL_FUNC bfloat16_t simd_max(bfloat16_t data)
Definition bf16_math.h:392
METAL_FUNC bfloat16_t simd_sum(bfloat16_t data)
Definition bf16_math.h:392
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:8