mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Support transposed head/seq for kv (#1950)
* support transposed head/seq for kv
* fix flaky test
* nit
This commit is contained in:
@@ -15,8 +15,10 @@ template <typename T, int D, int V = D>
|
||||
device T* out [[buffer(3)]],
|
||||
const constant int& gqa_factor,
|
||||
const constant int& N,
|
||||
const constant size_t& k_stride,
|
||||
const constant size_t& v_stride,
|
||||
const constant size_t& k_head_stride,
|
||||
const constant size_t& k_seq_stride,
|
||||
const constant size_t& v_head_stride,
|
||||
const constant size_t& v_seq_stride,
|
||||
const constant float& scale,
|
||||
const device bool* mask [[function_constant(has_mask)]],
|
||||
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
|
||||
@@ -30,8 +32,8 @@ template <typename T, int D, int V = D>
|
||||
constexpr int BD = 32;
|
||||
constexpr int qk_per_thread = D / BD;
|
||||
constexpr int v_per_thread = V / BD;
|
||||
constexpr int inner_k_stride = BN * D;
|
||||
constexpr int inner_v_stride = BN * V;
|
||||
int inner_k_stride = BN * int(k_seq_stride);
|
||||
int inner_v_stride = BN * int(v_seq_stride);
|
||||
|
||||
typedef float U;
|
||||
|
||||
@@ -51,8 +53,10 @@ template <typename T, int D, int V = D>
|
||||
const int q_offset =
|
||||
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
|
||||
queries += q_offset * D + simd_lid * qk_per_thread;
|
||||
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread;
|
||||
values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread;
|
||||
keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
|
||||
simd_lid * qk_per_thread;
|
||||
values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
|
||||
simd_lid * v_per_thread;
|
||||
if (has_mask) {
|
||||
mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
|
||||
q_seq_idx * mask_q_seq_stride;
|
||||
@@ -147,8 +151,10 @@ template <typename T, int D, int V = D>
|
||||
device float* maxs [[buffer(5)]],
|
||||
const constant int& gqa_factor,
|
||||
const constant int& N,
|
||||
const constant size_t& k_stride,
|
||||
const constant size_t& v_stride,
|
||||
const constant size_t& k_head_stride,
|
||||
const constant size_t& k_seq_stride,
|
||||
const constant size_t& v_head_stride,
|
||||
const constant size_t& v_seq_stride,
|
||||
const constant float& scale,
|
||||
const device bool* mask [[function_constant(has_mask)]],
|
||||
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
|
||||
@@ -162,8 +168,8 @@ template <typename T, int D, int V = D>
|
||||
constexpr int BD = 32;
|
||||
constexpr int qk_per_thread = D / BD;
|
||||
constexpr int v_per_thread = V / BD;
|
||||
constexpr int inner_k_stride = BN * D;
|
||||
constexpr int inner_v_stride = BN * V;
|
||||
int inner_k_stride = BN * int(k_seq_stride);
|
||||
int inner_v_stride = BN * int(v_seq_stride);
|
||||
constexpr int blocks = 32;
|
||||
|
||||
typedef float U;
|
||||
@@ -186,10 +192,10 @@ template <typename T, int D, int V = D>
|
||||
const int kv_head_idx = head_idx / gqa_factor;
|
||||
|
||||
queries += q_offset * D + simd_lid * qk_per_thread;
|
||||
keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
|
||||
simd_lid * qk_per_thread;
|
||||
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V +
|
||||
simd_lid * v_per_thread;
|
||||
keys += kv_head_idx * k_head_stride +
|
||||
(block_idx * BN + simd_gid) * k_seq_stride + simd_lid * qk_per_thread;
|
||||
values += kv_head_idx * v_head_stride +
|
||||
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
|
||||
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
|
||||
if (has_mask) {
|
||||
mask += head_idx * mask_head_stride +
|
||||
|
||||
Reference in New Issue
Block a user