Allow different value dimensions in sdpa_vector (#1811)
This commit is contained in: parent b7c9f1d38f, commit f5cc1eea72.
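This change lets the sdpa_vector Metal kernels use a value head dimension that differs from the query/key head dimension: the kernels are templated on both dims and the kernel names carry both, while the Python front end (see the mlx/fast.cpp hunk further down) still requires the two to match, which is why the new unit test is skipped. As a minimal sketch of the semantics in plain MLX ops (the shapes here are illustrative, not taken from the kernels):

import mlx.core as mx

# Illustrative shapes: query/key head dim Dqk, value head dim Dv, with Dqk != Dv.
B, H, S, Dqk, Dv = 1, 8, 1024, 192, 128
q = mx.random.normal(shape=(B, H, 1, Dqk))
k = mx.random.normal(shape=(B, H, S, Dqk))
v = mx.random.normal(shape=(B, H, S, Dv))

scores = (q * Dqk**-0.5) @ k.transpose(0, 1, 3, 2)   # (B, H, 1, S)
probs = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q.dtype)
out = probs @ v                                       # (B, H, 1, Dv)
print(out.shape)  # the output inherits the value head dim, not the query/key one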
@@ -8,14 +8,23 @@ L = 16384
 H = 32
 H_k = H // 4
 D = 128
+V = 128
 dtype = mx.float16
 loops = 10


-def attention(q, k, v, mask=None):
+def upproject(x, w):
+    if w is None:
+        return x
+    else:
+        return x @ w.T
+
+
+def attention(q, k, v, mask=None, w=None):
     def _sdpa(q, k, v):
         B, Hq, L, D = q.shape
         _, Hk, S, _ = k.shape
+        _, _, _, V = v.shape
         q = q.reshape(B, Hk, Hq // Hk, L, D)
         k = k[:, :, None, :, :]
         v = v[:, :, None, :, :]
@@ -25,16 +34,18 @@ def attention(q, k, v, mask=None):
             s = mx.where(m, s, mx.finfo(s.dtype).min)
         p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
         o = p @ v
-        return o.reshape(B, Hq, L, D)
+        return o.reshape(B, Hq, L, V)

     for i in range(loops):
         q = _sdpa(q, k, v)
+        q = upproject(q, w)
     return q


-def sdpa(q, k, v, mask=None):
+def sdpa(q, k, v, mask=None, w=None):
     for i in range(loops):
         q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
+        q = upproject(q, w)
     return q


@@ -42,34 +53,37 @@ def time_self_attention_primitives():
     mx.random.seed(3)
     q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
     k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    mx.eval(q, k, v)
-    time_fn(attention, q, k, v)
+    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
+    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
+    mx.eval(q, k, v, w)
+    time_fn(attention, q, k, v, w=w)


 def time_self_attention_sdpa():
     mx.random.seed(3)
     q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
     k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    mx.eval(q, k, v)
-    time_fn(sdpa, q, k, v)
+    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
+    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
+    mx.eval(q, k, v, w)
+    time_fn(sdpa, q, k, v, w=w)


 def time_self_attention_sdpa_with_mask():
     mx.random.seed(3)
     q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
     k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
-    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
+    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
+    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
     mask = mx.full((L,), True)
     mask[L // 2 :] = False
-    mx.eval(q, k, v, mask)
+    mx.eval(q, k, v, mask, w)

     def sdpa_mask(*args):
-        return sdpa(*args, mask=mask)
+        return sdpa(*args, mask=mask, w=w)

     def attention_mask(*args):
-        return attention(*args, mask=mask)
+        return attention(*args, mask=mask, w=w)

     time_fn(attention_mask, q, k, v)
     time_fn(sdpa_mask, q, k, v)
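In the benchmark, the attention output is fed back in as the next iteration's query, so when V differs from D the (..., V) output has to be projected back to the (..., D) query/key width; that is all upproject does, and it is skipped when V == D. A small sketch of the shape round-trip (the dims here are made up; the benchmark defaults to D == V == 128):

import mlx.core as mx

D, V = 128, 96                               # hypothetical dims with V != D
w = mx.random.uniform(shape=(D, V))          # projects a (..., V) output back to (..., D)

o = mx.random.uniform(shape=(1, 32, 1, V))   # attention output with value head dim V
q_next = o @ w.T                             # (1, 32, 1, D): usable as the next query
print(q_next.shape)

The hunks that follow move from this Python reference down into the Metal kernel instantiations and the sdpa_vector kernel templates themselves.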
@@ -7,15 +7,34 @@ using namespace metal;

 // clang-format off
 // SDPA vector instantiations
-#define instantiate_sdpa_vector(type, head_dim) \
-  instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) \
-  instantiate_kernel("sdpa_vector_2pass_1_" #type "_" #head_dim, sdpa_vector_2pass_1, type, head_dim) \
-  instantiate_kernel("sdpa_vector_2pass_2_" #type "_" #head_dim, sdpa_vector_2pass_2, type, head_dim)
+#define instantiate_sdpa_vector_aggregation(type, value_dim) \
+  instantiate_kernel( \
+      "sdpa_vector_2pass_2_" #type "_" #value_dim, \
+      sdpa_vector_2pass_2, \
+      type, \
+      value_dim)
+
+#define instantiate_sdpa_vector(type, qk_dim, value_dim) \
+  instantiate_kernel( \
+      "sdpa_vector_" #type "_" #qk_dim "_" #value_dim, \
+      sdpa_vector, \
+      type, \
+      qk_dim, \
+      value_dim) \
+  instantiate_kernel( \
+      "sdpa_vector_2pass_1_" #type "_" #qk_dim "_" #value_dim, \
+      sdpa_vector_2pass_1, \
+      type, \
+      qk_dim, \
+      value_dim)

 #define instantiate_sdpa_vector_heads(type) \
-  instantiate_sdpa_vector(type, 64) \
-  instantiate_sdpa_vector(type, 96) \
-  instantiate_sdpa_vector(type, 128)
+  instantiate_sdpa_vector(type, 64, 64) \
+  instantiate_sdpa_vector(type, 96, 96) \
+  instantiate_sdpa_vector(type, 128, 128) \
+  instantiate_sdpa_vector_aggregation(type, 64) \
+  instantiate_sdpa_vector_aggregation(type, 96) \
+  instantiate_sdpa_vector_aggregation(type, 128)

 instantiate_sdpa_vector_heads(float)
 instantiate_sdpa_vector_heads(bfloat16_t)
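With the reworked macros, the single-pass and 2pass_1 kernels are instantiated per (dtype, qk_dim, value_dim) pair, while the 2pass_2 aggregation kernel only depends on the value dim. The host code (see the kname changes further down) has to compose names with the same layout; roughly, as a Python illustration of the naming scheme (not the actual dispatch code):

def sdpa_vector_kernel_names(dtype: str, qk_dim: int, value_dim: int):
    # Mirrors the macros above: single-pass and 2pass_1 are keyed on both dims,
    # the 2pass_2 aggregation kernel only on the value dim.
    return [
        f"sdpa_vector_{dtype}_{qk_dim}_{value_dim}",
        f"sdpa_vector_2pass_1_{dtype}_{qk_dim}_{value_dim}",
        f"sdpa_vector_2pass_2_{dtype}_{value_dim}",
    ]

print(sdpa_vector_kernel_names("float", 128, 128))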
@@ -6,7 +6,7 @@ using namespace metal;

 constant bool has_mask [[function_constant(20)]];

-template <typename T, int D>
+template <typename T, int D, int V = D>
 [[kernel]] void sdpa_vector(
     const device T* queries [[buffer(0)]],
     const device T* keys [[buffer(1)]],
@@ -25,14 +25,16 @@ template <typename T, int D>
     uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int BN = 32;
   constexpr int BD = 32;
-  constexpr int elem_per_thread = D / BD;
-  constexpr int stride = BN * D;
+  constexpr int qk_per_thread = D / BD;
+  constexpr int v_per_thread = V / BD;
+  constexpr int inner_k_stride = BN * D;
+  constexpr int inner_v_stride = BN * V;

   typedef float U;

-  thread U q[elem_per_thread];
-  thread U k[elem_per_thread];
-  thread U o[elem_per_thread];
+  thread U q[qk_per_thread];
+  thread U k[qk_per_thread];
+  thread U o[v_per_thread];

   threadgroup U outputs[BN * BD];
   threadgroup U max_scores[BN];
@@ -41,19 +43,19 @@ template <typename T, int D>
   // Adjust positions
   const int head_idx = tid.y;
   const int kv_head_idx = head_idx / gqa_factor;
-  queries += head_idx * D + simd_lid * elem_per_thread;
-  keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
-  values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
+  queries += head_idx * D + simd_lid * qk_per_thread;
+  keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread;
+  values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread;
   if (has_mask) {
     mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
   }
-  out += head_idx * D + simd_gid * elem_per_thread;
+  out += head_idx * V + simd_gid * v_per_thread;

   // Read the query and 0 the output accumulator
-  for (int i = 0; i < elem_per_thread; i++) {
+  for (int i = 0; i < qk_per_thread; i++) {
     q[i] = static_cast<U>(scale) * queries[i];
   }
-  for (int i = 0; i < elem_per_thread; i++) {
+  for (int i = 0; i < v_per_thread; i++) {
     o[i] = 0;
   }

@@ -64,13 +66,13 @@ template <typename T, int D>
   for (int i = simd_gid; i < N; i += BN) {
     if (!has_mask || mask[0]) {
       // Read the key
-      for (int j = 0; j < elem_per_thread; j++) {
+      for (int j = 0; j < qk_per_thread; j++) {
         k[j] = keys[j];
       }

       // Compute the i-th score
       U score = 0;
-      for (int j = 0; j < elem_per_thread; j++) {
+      for (int j = 0; j < qk_per_thread; j++) {
         score += q[j] * k[j];
       }
       score = simd_sum(score);
@@ -84,14 +86,14 @@ template <typename T, int D>
       sum_exp_score = sum_exp_score * factor + exp_score;

       // Update the output accumulator
-      for (int j = 0; j < elem_per_thread; j++) {
+      for (int j = 0; j < v_per_thread; j++) {
         o[j] = o[j] * factor + exp_score * values[j];
       }
     }

     // Move the pointers to the next kv
-    keys += stride;
-    values += stride;
+    keys += inner_k_stride;
+    values += inner_v_stride;
     if (has_mask) {
       mask += BN * mask_seq_stride;
     }
@@ -111,7 +113,7 @@ template <typename T, int D>
   sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);

   // Now we need to aggregate all the outputs
-  for (int i = 0; i < elem_per_thread; i++) {
+  for (int i = 0; i < v_per_thread; i++) {
     outputs[simd_lid * BD + simd_gid] = o[i];
     threadgroup_barrier(mem_flags::mem_threadgroup);
     o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
@@ -120,13 +122,13 @@ template <typename T, int D>

   // And write the output
   if (simd_lid == 0) {
-    for (int i = 0; i < elem_per_thread; i++) {
+    for (int i = 0; i < v_per_thread; i++) {
       out[i] = static_cast<T>(o[i]);
     }
   }
 }

-template <typename T, int D>
+template <typename T, int D, int V = D>
 [[kernel]] void sdpa_vector_2pass_1(
     const device T* queries [[buffer(0)]],
     const device T* keys [[buffer(1)]],
@@ -147,15 +149,17 @@ template <typename T, int D>
     uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int BN = 8;
   constexpr int BD = 32;
-  constexpr int elem_per_thread = D / BD;
-  constexpr int stride = BN * D;
+  constexpr int qk_per_thread = D / BD;
+  constexpr int v_per_thread = V / BD;
+  constexpr int inner_k_stride = BN * D;
+  constexpr int inner_v_stride = BN * V;
   constexpr int blocks = 32;

   typedef float U;

-  thread U q[elem_per_thread];
-  thread U k[elem_per_thread];
-  thread U o[elem_per_thread];
+  thread U q[qk_per_thread];
+  thread U k[qk_per_thread];
+  thread U o[v_per_thread];

   threadgroup U outputs[BN * BD];
   threadgroup U max_scores[BN];
@@ -165,12 +169,12 @@ template <typename T, int D>
   const int block_idx = tid.z;
   const int head_idx = tid.y;
   const int kv_head_idx = head_idx / gqa_factor;
-  queries += head_idx * D + simd_lid * elem_per_thread;
+  queries += head_idx * D + simd_lid * qk_per_thread;
   keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
-      simd_lid * elem_per_thread;
-  values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
-      simd_lid * elem_per_thread;
-  out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
+      simd_lid * qk_per_thread;
+  values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V +
+      simd_lid * v_per_thread;
+  out += head_idx * blocks * V + block_idx * V + simd_lid * v_per_thread;
   if (has_mask) {
     mask += head_idx * mask_head_stride +
         (block_idx * BN + simd_gid) * mask_seq_stride;
@@ -179,10 +183,10 @@ template <typename T, int D>
   maxs += head_idx * blocks + block_idx;

   // Read the query and 0 the output accumulator
-  for (int i = 0; i < elem_per_thread; i++) {
+  for (int i = 0; i < qk_per_thread; i++) {
     q[i] = static_cast<U>(scale) * queries[i];
   }
-  for (int i = 0; i < elem_per_thread; i++) {
+  for (int i = 0; i < v_per_thread; i++) {
     o[i] = 0;
   }

@@ -193,13 +197,13 @@ template <typename T, int D>
   for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
     if (!has_mask || mask[0]) {
       // Read the key
-      for (int i = 0; i < elem_per_thread; i++) {
+      for (int i = 0; i < qk_per_thread; i++) {
         k[i] = keys[i];
       }

       // Compute the i-th score
       U score = 0;
-      for (int i = 0; i < elem_per_thread; i++) {
+      for (int i = 0; i < qk_per_thread; i++) {
         score += q[i] * k[i];
       }
       score = simd_sum(score);
@@ -213,14 +217,14 @@ template <typename T, int D>
       sum_exp_score = sum_exp_score * factor + exp_score;

       // Update the output accumulator
-      for (int i = 0; i < elem_per_thread; i++) {
+      for (int i = 0; i < v_per_thread; i++) {
         o[i] = o[i] * factor + exp_score * values[i];
       }
     }

     // Move the pointers to the next kv
-    keys += blocks * stride;
-    values += blocks * stride;
+    keys += blocks * inner_k_stride;
+    values += blocks * inner_v_stride;
     if (has_mask) {
       mask += BN * blocks * mask_seq_stride;
     }
@@ -247,7 +251,7 @@ template <typename T, int D>
   }

   // Now we need to aggregate all the outputs
-  for (int i = 0; i < elem_per_thread; i++) {
+  for (int i = 0; i < v_per_thread; i++) {
     outputs[simd_lid * BN + simd_gid] =
         o[i] * fast::exp(max_scores[simd_gid] - new_max);
     threadgroup_barrier(mem_flags::mem_threadgroup);
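Inside the kernels, each 32-lane simdgroup now splits the query/key dim and the value dim independently, so the per-lane element counts and the strides that advance the key/value pointers come in separate qk/v flavors. A quick back-of-the-envelope check in Python (D and V are borrowed from the skipped test further down; only the equal 64/96/128 pairs are actually instantiated in the file above, and the 2pass_1 kernel uses BN = 8 instead of 32):

BD = 32            # lanes per simdgroup
BN = 32            # simdgroups per threadgroup in the single-pass sdpa_vector

D, V = 192, 128    # query/key head dim and value head dim (illustrative)

qk_per_thread = D // BD    # 6 query/key elements per lane
v_per_thread = V // BD     # 4 value/output elements per lane
inner_k_stride = BN * D    # elements to skip to the next block of keys
inner_v_stride = BN * V    # elements to skip to the next block of values

print(qk_per_thread, v_per_thread, inner_k_stride, inner_v_stride)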
@@ -124,6 +124,8 @@ void sdpa_vector(
   kname += get_type_string(q.dtype());
   kname += "_";
   kname += std::to_string(q.shape(-1));
+  kname += "_";
+  kname += std::to_string(v.shape(-1));

   // Compute the necessary sizes
   int gqa_factor = q.shape(1) / k.shape(1);
@@ -185,6 +187,8 @@ void sdpa_vector_2pass(
   kname += get_type_string(q.dtype());
   kname += "_";
   kname += std::to_string(q.shape(-1));
+  kname += "_";
+  kname += std::to_string(v.shape(-1));

   // Compute the necessary sizes
   int gqa_factor = q.shape(1) / k.shape(1);
@@ -256,7 +260,7 @@ void sdpa_vector_2pass(
   kname += "sdpa_vector_2pass_2_";
   kname += get_type_string(q.dtype());
   kname += "_";
-  kname += std::to_string(q.shape(-1));
+  kname += std::to_string(v.shape(-1));

   // Get the kernel
   kernel = d.get_kernel(kname);
@@ -332,7 +336,7 @@ void ScaledDotProductAttention::eval_gpu(
     const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);

     // Donate the query if possible
-    if (q.is_donatable()) {
+    if (q.is_donatable() && q.size() == o.size()) {
       o.move_shared_buffer(q);
     } else {
       o.set_data(allocator::malloc_or_wait(o.nbytes()));
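The query-donation guard also gains a size check: once the value dim can differ from the query dim, the output's last axis (and therefore its size) may not match the query's, in which case q's buffer cannot be reused. A tiny sketch of the comparison, with hypothetical shapes:

import math

q_shape = (1, 32, 1, 192)   # B, H, L, query head dim
o_shape = (1, 32, 1, 128)   # the output takes the value head dim

print(math.prod(q_shape) == math.prod(o_shape))  # False -> allocate a fresh output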
mlx/fast.cpp
@@ -684,23 +684,20 @@ array scaled_dot_product_attention(
   const size_t query_head_dim = q.shape(-1);
   const size_t query_sequence_length = q.shape(2);

-  bool implementation_supports_use_case = query_head_dim == value_head_dim;
-
   const bool sdpa_vector_supported_head_dim =
-      query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128;
-  const bool sdpa_full_supported_head_dim =
-      query_head_dim == 64 || query_head_dim == 80;
+      query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
+  const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
+      (query_head_dim == 64 || query_head_dim == 80);

-  const bool supports_sdpa_full = query_sequence_length >= threshold &&
-      !mask.has_value() && sdpa_full_supported_head_dim &&
-      stream.device == Device::gpu;
+  const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
+      sdpa_full_supported_head_dim && stream.device == Device::gpu;

-  const bool supported_mask = !mask || (mask->dtype() == bool_);
   const bool supports_sdpa_vector = query_sequence_length == 1 &&
-      supported_mask && sdpa_vector_supported_head_dim &&
+      (!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim &&
       stream.device == Device::gpu;

-  implementation_supports_use_case &=
+  const bool implementation_supports_use_case =
       supports_sdpa_full || supports_sdpa_vector;

   std::vector<array> inputs = {q, k, v};
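Restated as plain Python, the updated front-end checks in scaled_dot_product_attention read roughly as follows (a paraphrase of the C++ above, not the real dispatch path; the threshold argument stands in for the sequence-length cutoff used there):

def supports_fast_sdpa(query_head_dim, value_head_dim, query_sequence_length,
                       has_mask, mask_is_bool, on_gpu, threshold):
    # Vector (single-query) path: bool mask or no mask, and equal
    # query/value head dims of 64, 96, or 128.
    vector_head_dim_ok = (query_head_dim == value_head_dim
                          and query_head_dim in (64, 96, 128))
    supports_vector = (query_sequence_length == 1
                       and (not has_mask or mask_is_bool)
                       and vector_head_dim_ok
                       and on_gpu)

    # Full (long-sequence) path: no mask and equal head dims of 64 or 80.
    full_head_dim_ok = (query_head_dim == value_head_dim
                        and query_head_dim in (64, 80))
    supports_full = (query_sequence_length >= threshold
                     and not has_mask
                     and full_head_dim_ok
                     and on_gpu)

    return supports_full or supports_vector

Note that even though the kernels are now templated on separate D and V, this check still requires query_head_dim == value_head_dim, which is why the test added in the next hunk is skipped.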
@@ -262,6 +262,23 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

+    @unittest.skip("Different head and value dims is not enabled")
+    def test_fast_sdpa_vector_value_dims(self):
+        D = 192
+        V = 128
+        Nq = 4
+        Nkv = 1
+        scale = 1.0
+        mx.random.seed(0)
+
+        for L in [43, 128, 237, 8192]:
+            q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
+            k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
+            v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, V))
+            ref = mlx_primitives_sdpa(q, k, v, scale)
+            out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
+            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
+

 if __name__ == "__main__":
     unittest.main(failfast=True)