mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-27 03:11:16 +08:00
working qsdpa
This commit is contained in:
parent
e047fd977d
commit
12a4d89a7c
@ -1,58 +1,94 @@
|
|||||||
import argparse
|
|
||||||
import math
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
from mlx.utils import tree_map
|
||||||
from time_utils import time_fn
|
from time_utils import time_fn
|
||||||
|
|
||||||
L = 16384
|
L = 32768
|
||||||
H = 32
|
H = 32
|
||||||
H_k = H // 4
|
H_k = H // 4
|
||||||
D = 128
|
D = 128
|
||||||
dtype = mx.float16
|
dtype = mx.float16
|
||||||
loops = 10
|
bits = 8
|
||||||
|
|
||||||
|
loops = 20
|
||||||
|
|
||||||
|
|
||||||
def attention(q, k, v):
|
def attention(q, k, v):
|
||||||
def _sdpa(q, k, v):
|
for _ in range(loops):
|
||||||
B, Hq, L, D = q.shape
|
B, Hq, L, D = q.shape
|
||||||
_, Hk, S, _ = k.shape
|
_, Hk, S, _ = k.shape
|
||||||
q = q.reshape(B, Hk, Hq // Hk, L, D)
|
q = q.reshape(B, Hk, Hq // Hk, L, D)
|
||||||
k = k[:, :, None, :, :]
|
ke = k[:, :, None, :, :]
|
||||||
v = v[:, :, None, :, :]
|
ve = v[:, :, None, :, :]
|
||||||
s = q @ k.transpose(0, 1, 2, 4, 3)
|
s = q @ ke.transpose(0, 1, 2, 4, 3)
|
||||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||||
o = p @ v
|
q = p @ ve
|
||||||
return o.reshape(B, Hq, L, D)
|
q = q.reshape(B, Hq, L, D)
|
||||||
|
|
||||||
for i in range(loops):
|
|
||||||
q = _sdpa(q, k, v)
|
|
||||||
return q
|
return q
|
||||||
|
|
||||||
|
|
||||||
def sdpa(q, k, v):
|
def sdpa(q, k, v):
|
||||||
for i in range(loops):
|
for _ in range(loops):
|
||||||
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
|
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
|
||||||
return q
|
return q
|
||||||
|
|
||||||
|
|
||||||
def time_self_attention_primitives():
|
def quant_sdpa(q, k, v, bits=4):
|
||||||
mx.random.seed(3)
|
for _ in range(loops):
|
||||||
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
q = mx.fast.quantized_scaled_dot_product_attention(
|
||||||
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
q, *k, *v, scale=1.0, mask=None, bits=bits
|
||||||
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
)
|
||||||
mx.eval(q, k, v)
|
return q
|
||||||
|
|
||||||
|
|
||||||
|
def quant_attention(q, k, v, bits=4):
|
||||||
|
for _ in range(loops):
|
||||||
|
B, Hq, L, D = q.shape
|
||||||
|
Hk = k[0].shape[1]
|
||||||
|
|
||||||
|
q = q.reshape((B, Hk, Hq // Hk, L, D))
|
||||||
|
ke = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
|
||||||
|
ve = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
|
||||||
|
|
||||||
|
scores = mx.quantized_matmul(q, *ke, transpose=True, bits=bits)
|
||||||
|
scores = mx.softmax(scores, axis=-1)
|
||||||
|
|
||||||
|
q = mx.quantized_matmul(scores, *ve, transpose=False, bits=bits)
|
||||||
|
q = q.reshape((B, Hq, L, D))
|
||||||
|
return q
|
||||||
|
|
||||||
|
|
||||||
|
def time_self_attention_primitives(q, k, v):
|
||||||
time_fn(attention, q, k, v)
|
time_fn(attention, q, k, v)
|
||||||
|
|
||||||
|
|
||||||
def time_self_attention_sdpa():
|
def time_self_attention_sdpa(q, k, v):
|
||||||
mx.random.seed(3)
|
|
||||||
q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
|
|
||||||
k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
|
||||||
v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
|
|
||||||
mx.eval(q, k, v)
|
|
||||||
time_fn(sdpa, q, k, v)
|
time_fn(sdpa, q, k, v)
|
||||||
|
|
||||||
|
|
||||||
|
def time_self_attention_quant_sdpa(q, k, v, bits=4):
|
||||||
|
time_fn(quant_sdpa, q, k, v, bits)
|
||||||
|
|
||||||
|
|
||||||
|
def time_self_attention_quant_primitives(q, k, v, bits=4):
|
||||||
|
time_fn(quant_attention, q, k, v, bits)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
time_self_attention_sdpa()
|
mx.random.seed(3)
|
||||||
time_self_attention_primitives()
|
q = mx.random.uniform(shape=(1, H, 1, D), dtype=dtype)
|
||||||
|
k = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
|
||||||
|
v = mx.random.uniform(shape=(1, H_k, L, D), dtype=dtype)
|
||||||
|
mx.eval(q, k, v)
|
||||||
|
|
||||||
|
k_quant = mx.quantize(k, bits=bits)
|
||||||
|
v_quant = mx.quantize(v, bits=bits)
|
||||||
|
mx.eval(k_quant, v_quant)
|
||||||
|
|
||||||
|
k = mx.dequantize(*k_quant, bits=bits)
|
||||||
|
v = mx.dequantize(*v_quant, bits=bits)
|
||||||
|
|
||||||
|
time_self_attention_sdpa(q, k, v)
|
||||||
|
time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
|
||||||
|
time_self_attention_primitives(q, k, v)
|
||||||
|
time_self_attention_quant_primitives(q, k_quant, v_quant, bits)
|
||||||
|
@ -20,4 +20,33 @@ using namespace metal;
|
|||||||
instantiate_sdpa_vector_heads(float)
|
instantiate_sdpa_vector_heads(float)
|
||||||
instantiate_sdpa_vector_heads(bfloat16_t)
|
instantiate_sdpa_vector_heads(bfloat16_t)
|
||||||
instantiate_sdpa_vector_heads(float16_t)
|
instantiate_sdpa_vector_heads(float16_t)
|
||||||
|
|
||||||
|
// Quantized SDPA vector instantiations
|
||||||
|
#define instantiate_quant_sdpa_vector(name, type, head_dim, group_size, bits) \
|
||||||
|
instantiate_kernel( \
|
||||||
|
#name "_" #type "_" #head_dim "_" #group_size "_" #bits, \
|
||||||
|
name, type, head_dim, group_size, bits)
|
||||||
|
|
||||||
|
#define instantiate_quant_sdpa_vector_passes(type, heads, group_size, bits) \
|
||||||
|
instantiate_quant_sdpa_vector(quant_sdpa_vector, type, heads, group_size, bits) \
|
||||||
|
instantiate_quant_sdpa_vector(quant_sdpa_vector_2pass_1, type, heads, group_size, bits)
|
||||||
|
|
||||||
|
#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
|
||||||
|
instantiate_quant_sdpa_vector_passes(type, heads, group_size, 4) \
|
||||||
|
instantiate_quant_sdpa_vector_passes(type, heads, group_size, 8)
|
||||||
|
|
||||||
|
#define instantiate_quant_sdpa_vector_group_size(type, heads) \
|
||||||
|
instantiate_quant_sdpa_vector_bits(type, heads, 32) \
|
||||||
|
instantiate_quant_sdpa_vector_bits(type, heads, 64) \
|
||||||
|
instantiate_quant_sdpa_vector_bits(type, heads, 128)
|
||||||
|
|
||||||
|
#define instantiate_quant_sdpa_vector_heads(type) \
|
||||||
|
instantiate_quant_sdpa_vector_group_size(type, 64) \
|
||||||
|
instantiate_quant_sdpa_vector_group_size(type, 96) \
|
||||||
|
instantiate_quant_sdpa_vector_group_size(type, 128)
|
||||||
|
|
||||||
|
instantiate_quant_sdpa_vector_heads(float)
|
||||||
|
instantiate_quant_sdpa_vector_heads(bfloat16_t)
|
||||||
|
instantiate_quant_sdpa_vector_heads(float16_t)
|
||||||
|
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
@ -113,6 +113,208 @@ template <typename T, int D>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, int elem_per_thread, int bits>
|
||||||
|
METAL_FUNC U load_queries(const device T* queries, thread U* q, U scale) {
|
||||||
|
U query_sum = 0;
|
||||||
|
if (bits == 4) {
|
||||||
|
for (int i = 0; i < elem_per_thread; i += 4) {
|
||||||
|
q[i] = scale * queries[i];
|
||||||
|
q[i + 1] = scale * queries[i + 1];
|
||||||
|
q[i + 2] = scale * queries[i + 2];
|
||||||
|
q[i + 3] = scale * queries[i + 3];
|
||||||
|
query_sum += q[i] + q[i + 1] + q[i + 2] + q[i + 3];
|
||||||
|
q[i + 1] /= 16.0f;
|
||||||
|
q[i + 2] /= 256.0f;
|
||||||
|
q[i + 3] /= 4096.0f;
|
||||||
|
}
|
||||||
|
} else if (bits == 8) {
|
||||||
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
|
q[i] = scale * queries[i];
|
||||||
|
query_sum += q[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return query_sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U, int elem_per_thread, int bits>
|
||||||
|
METAL_FUNC void load_keys(const device uint32_t* keys, thread U* k) {
|
||||||
|
if (bits == 4) {
|
||||||
|
auto ks = (const device uint16_t*)keys;
|
||||||
|
for (int i = 0; i < elem_per_thread / 4; i++) {
|
||||||
|
k[4 * i] = ks[i] & 0x000f;
|
||||||
|
k[4 * i + 1] = ks[i] & 0x00f0;
|
||||||
|
k[4 * i + 2] = ks[i] & 0x0f00;
|
||||||
|
k[4 * i + 3] = ks[i] & 0xf000;
|
||||||
|
}
|
||||||
|
} else if (bits == 8) {
|
||||||
|
auto ks = (const device uint8_t*)keys;
|
||||||
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
|
k[i] = ks[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U, int elem_per_thread, int bits>
|
||||||
|
METAL_FUNC void load_values(
|
||||||
|
const device uint32_t* values,
|
||||||
|
thread U* v,
|
||||||
|
U value_scale,
|
||||||
|
U value_bias) {
|
||||||
|
auto vs = (const device uint8_t*)values;
|
||||||
|
if (bits == 4) {
|
||||||
|
U s[2] = {value_scale, value_scale / 16.0f};
|
||||||
|
for (int i = 0; i < elem_per_thread / 2; i++) {
|
||||||
|
v[2 * i] = s[0] * (vs[i] & 0x0f) + value_bias;
|
||||||
|
v[2 * i + 1] = s[1] * (vs[i] & 0xf0) + value_bias;
|
||||||
|
}
|
||||||
|
} else if (bits == 8) {
|
||||||
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
|
v[i] = value_scale * vs[i] + value_bias;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int D, int group_size, int bits>
|
||||||
|
[[kernel]] void quant_sdpa_vector(
|
||||||
|
const device T* queries [[buffer(0)]],
|
||||||
|
const device uint32_t* keys [[buffer(1)]],
|
||||||
|
const device T* key_scales [[buffer(2)]],
|
||||||
|
const device T* key_biases [[buffer(3)]],
|
||||||
|
const device uint32_t* values [[buffer(4)]],
|
||||||
|
const device T* value_scales [[buffer(5)]],
|
||||||
|
const device T* value_biases [[buffer(6)]],
|
||||||
|
device T* out [[buffer(7)]],
|
||||||
|
const constant int& gqa_factor,
|
||||||
|
const constant int& N,
|
||||||
|
const constant size_t& k_stride,
|
||||||
|
const constant size_t& group_stride,
|
||||||
|
const constant float& scale,
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]],
|
||||||
|
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
||||||
|
uint quad_lid [[thread_index_in_quadgroup]]) {
|
||||||
|
constexpr int BN = 32;
|
||||||
|
constexpr int BD = 4;
|
||||||
|
constexpr int elem_per_thread = D / BD;
|
||||||
|
constexpr int pack_factor = 32 / bits;
|
||||||
|
|
||||||
|
const int stride = BN * D;
|
||||||
|
|
||||||
|
typedef float U;
|
||||||
|
|
||||||
|
thread U q[elem_per_thread];
|
||||||
|
thread U k[elem_per_thread];
|
||||||
|
thread U v[elem_per_thread];
|
||||||
|
thread U o[elem_per_thread];
|
||||||
|
|
||||||
|
threadgroup U outputs[BN * BD];
|
||||||
|
threadgroup U max_scores[BN];
|
||||||
|
threadgroup U sum_exp_scores[BN];
|
||||||
|
|
||||||
|
// Adjust positions
|
||||||
|
const int head_idx = tid.y;
|
||||||
|
const int kv_head_idx = head_idx / gqa_factor;
|
||||||
|
queries += head_idx * D + quad_lid * elem_per_thread;
|
||||||
|
|
||||||
|
const int kv_idx = quad_gid * D + quad_lid * elem_per_thread;
|
||||||
|
const int packed_idx = kv_head_idx * k_stride + kv_idx / pack_factor;
|
||||||
|
const int group_idx = kv_head_idx * group_stride + kv_idx / group_size;
|
||||||
|
keys += packed_idx;
|
||||||
|
key_scales += group_idx;
|
||||||
|
key_biases += group_idx;
|
||||||
|
values += packed_idx;
|
||||||
|
value_scales += group_idx;
|
||||||
|
value_biases += group_idx;
|
||||||
|
|
||||||
|
out += head_idx * D + simd_gid * elem_per_thread;
|
||||||
|
|
||||||
|
// Read the query and 0 the output accumulator
|
||||||
|
U query_sum = load_queries<T, U, elem_per_thread, bits>(
|
||||||
|
queries, q, static_cast<U>(scale));
|
||||||
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
|
o[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
U max_score = -INFINITY;
|
||||||
|
U sum_exp_score = 0;
|
||||||
|
|
||||||
|
// For each key
|
||||||
|
for (int i = quad_gid; i < N; i += BN) {
|
||||||
|
load_keys<U, elem_per_thread, bits>(keys, k);
|
||||||
|
|
||||||
|
// Assume D % group_size == 0 so all the keys are in the same group
|
||||||
|
U key_scale = key_scales[0];
|
||||||
|
U key_bias = key_biases[0];
|
||||||
|
|
||||||
|
// Compute the i-th score
|
||||||
|
U score = 0;
|
||||||
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
|
score += q[i] * k[i];
|
||||||
|
}
|
||||||
|
score = score * key_scale + query_sum * key_bias;
|
||||||
|
score = quad_sum(score);
|
||||||
|
|
||||||
|
// Update the accumulators
|
||||||
|
U new_max = max(max_score, score);
|
||||||
|
U factor = fast::exp(max_score - new_max);
|
||||||
|
U exp_score = fast::exp(score - new_max);
|
||||||
|
|
||||||
|
max_score = new_max;
|
||||||
|
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||||
|
|
||||||
|
U value_scale = value_scales[0];
|
||||||
|
U value_bias = value_biases[0];
|
||||||
|
|
||||||
|
// Load the values
|
||||||
|
load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);
|
||||||
|
|
||||||
|
// Update the output accumulator
|
||||||
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
|
o[i] = o[i] * factor + exp_score * v[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move the pointers to the next kv
|
||||||
|
keys += stride / pack_factor;
|
||||||
|
key_scales += stride / group_size;
|
||||||
|
key_biases += stride / group_size;
|
||||||
|
values += stride / pack_factor;
|
||||||
|
value_scales += stride / group_size;
|
||||||
|
value_biases += stride / group_size;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Each thread has a partial part of the output so we need to combine them.
|
||||||
|
|
||||||
|
// First let's communicate the max and sum_exp
|
||||||
|
// Each quadgroup communicates it's max score
|
||||||
|
if (quad_lid == 0) {
|
||||||
|
max_scores[quad_gid] = max_score;
|
||||||
|
sum_exp_scores[quad_gid] = sum_exp_score;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
max_score = max_scores[simd_lid];
|
||||||
|
U new_max = simd_max(max_score);
|
||||||
|
U factor = fast::exp(max_score - new_max);
|
||||||
|
sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
|
||||||
|
|
||||||
|
// Now we need to aggregate all the outputs
|
||||||
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
|
// 128 threads with 32 values per thread
|
||||||
|
outputs[simd_gid * BN + simd_lid] = o[i];
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score;
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
|
||||||
|
// And write the output
|
||||||
|
if (simd_lid == 0) {
|
||||||
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
|
out[i] = static_cast<T>(o[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, int D>
|
template <typename T, int D>
|
||||||
[[kernel]] void sdpa_vector_2pass_1(
|
[[kernel]] void sdpa_vector_2pass_1(
|
||||||
const device T* queries [[buffer(0)]],
|
const device T* queries [[buffer(0)]],
|
||||||
@ -290,3 +492,158 @@ template <typename T, int D>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, int D, int group_size, int bits>
|
||||||
|
[[kernel]] void quant_sdpa_vector_2pass_1(
|
||||||
|
const device T* queries [[buffer(0)]],
|
||||||
|
const device uint32_t* keys [[buffer(1)]],
|
||||||
|
const device T* key_scales [[buffer(2)]],
|
||||||
|
const device T* key_biases [[buffer(3)]],
|
||||||
|
const device uint32_t* values [[buffer(4)]],
|
||||||
|
const device T* value_scales [[buffer(5)]],
|
||||||
|
const device T* value_biases [[buffer(6)]],
|
||||||
|
device float* out [[buffer(7)]],
|
||||||
|
device float* sums [[buffer(8)]],
|
||||||
|
device float* maxs [[buffer(9)]],
|
||||||
|
const constant int& gqa_factor,
|
||||||
|
const constant int& N,
|
||||||
|
const constant size_t& k_stride,
|
||||||
|
const constant size_t& v_stride,
|
||||||
|
const constant size_t& k_group_stride,
|
||||||
|
const constant size_t& v_group_stride,
|
||||||
|
const constant float& scale,
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]],
|
||||||
|
uint quad_gid [[quadgroup_index_in_threadgroup]],
|
||||||
|
uint quad_lid [[thread_index_in_quadgroup]]) {
|
||||||
|
constexpr int BN = 8;
|
||||||
|
constexpr int BD = 4;
|
||||||
|
constexpr int elem_per_thread = D / BD;
|
||||||
|
const int stride = BN * D;
|
||||||
|
constexpr int blocks = 32;
|
||||||
|
constexpr int pack_factor = 32 / bits;
|
||||||
|
|
||||||
|
typedef float U;
|
||||||
|
|
||||||
|
thread U q[elem_per_thread];
|
||||||
|
thread U k[elem_per_thread];
|
||||||
|
thread U v[elem_per_thread];
|
||||||
|
thread U o[elem_per_thread];
|
||||||
|
|
||||||
|
threadgroup U outputs[BN * BD];
|
||||||
|
threadgroup U max_scores[BN];
|
||||||
|
threadgroup U sum_exp_scores[BN];
|
||||||
|
|
||||||
|
// Adjust positions
|
||||||
|
const int block_idx = tid.z;
|
||||||
|
const int head_idx = tid.y;
|
||||||
|
const int kv_head_idx = head_idx / gqa_factor;
|
||||||
|
queries += head_idx * D + quad_lid * elem_per_thread;
|
||||||
|
|
||||||
|
const int kv_idx =
|
||||||
|
(block_idx * BN + quad_gid) * D + quad_lid * elem_per_thread;
|
||||||
|
const int packed_idx = kv_idx / pack_factor;
|
||||||
|
const int k_group_idx = kv_head_idx * k_group_stride + kv_idx / group_size;
|
||||||
|
const int v_group_idx = kv_head_idx * v_group_stride + kv_idx / group_size;
|
||||||
|
|
||||||
|
keys += kv_head_idx * k_stride + packed_idx;
|
||||||
|
key_scales += k_group_idx;
|
||||||
|
key_biases += k_group_idx;
|
||||||
|
values += kv_head_idx * v_stride + packed_idx;
|
||||||
|
value_scales += v_group_idx;
|
||||||
|
value_biases += v_group_idx;
|
||||||
|
|
||||||
|
out += head_idx * blocks * D + block_idx * D + quad_lid * elem_per_thread;
|
||||||
|
sums += head_idx * blocks + block_idx;
|
||||||
|
maxs += head_idx * blocks + block_idx;
|
||||||
|
|
||||||
|
// Read the query and 0 the output accumulator
|
||||||
|
U query_sum = load_queries<T, U, elem_per_thread, bits>(
|
||||||
|
queries, q, static_cast<U>(scale));
|
||||||
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
|
o[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
U max_score = -1e9;
|
||||||
|
U sum_exp_score = 0;
|
||||||
|
|
||||||
|
// For each key
|
||||||
|
for (int i = block_idx * BN + quad_gid; i < N; i += blocks * BN) {
|
||||||
|
// Read the key
|
||||||
|
load_keys<U, elem_per_thread, bits>(keys, k);
|
||||||
|
|
||||||
|
// Assume D % group_size == 0 so all the keys are in the same group
|
||||||
|
U key_scale = key_scales[0];
|
||||||
|
U key_bias = key_biases[0];
|
||||||
|
|
||||||
|
// Compute the i-th score
|
||||||
|
U score = 0;
|
||||||
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
|
score += q[i] * k[i];
|
||||||
|
}
|
||||||
|
score = score * key_scale + query_sum * key_bias;
|
||||||
|
score = quad_sum(score);
|
||||||
|
|
||||||
|
// Update the accumulators
|
||||||
|
U new_max = max(max_score, score);
|
||||||
|
U factor = fast::exp(max_score - new_max);
|
||||||
|
U exp_score = fast::exp(score - new_max);
|
||||||
|
|
||||||
|
max_score = new_max;
|
||||||
|
sum_exp_score = sum_exp_score * factor + exp_score;
|
||||||
|
|
||||||
|
U value_scale = value_scales[0];
|
||||||
|
U value_bias = value_biases[0];
|
||||||
|
load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);
|
||||||
|
|
||||||
|
// Update the output accumulator
|
||||||
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
|
o[i] = o[i] * factor + exp_score * v[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move the pointers to the next kv
|
||||||
|
keys += blocks * stride / pack_factor;
|
||||||
|
key_scales += blocks * stride / group_size;
|
||||||
|
key_biases += blocks * stride / group_size;
|
||||||
|
values += blocks * stride / pack_factor;
|
||||||
|
value_scales += blocks * stride / group_size;
|
||||||
|
value_biases += blocks * stride / group_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Each thread has a partial part of the output so we need to combine them.
|
||||||
|
|
||||||
|
// First let's communicate the max and sum_exp
|
||||||
|
if (quad_lid == 0) {
|
||||||
|
max_scores[quad_gid] = max_score;
|
||||||
|
sum_exp_scores[quad_gid] = sum_exp_score;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
|
||||||
|
U new_max = simd_max(max_score);
|
||||||
|
U factor = fast::exp(max_score - new_max);
|
||||||
|
sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
|
||||||
|
sum_exp_score = simd_sum(sum_exp_score * factor);
|
||||||
|
|
||||||
|
// Write the sum and new max
|
||||||
|
if (simd_gid == 0) {
|
||||||
|
sums[0] = sum_exp_score;
|
||||||
|
maxs[0] = new_max;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now we need to aggregate all the outputs
|
||||||
|
for (int i = 0; i < elem_per_thread; i++) {
|
||||||
|
outputs[quad_lid * BN + quad_gid] =
|
||||||
|
o[i] * fast::exp(max_scores[quad_gid] - new_max);
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
if (quad_gid == 0) {
|
||||||
|
U output = outputs[quad_lid * BN];
|
||||||
|
for (int j = 1; j < BN; j++) {
|
||||||
|
output += outputs[quad_lid * BN + j];
|
||||||
|
}
|
||||||
|
out[i] = static_cast<T>(output);
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -242,6 +242,171 @@ void sdpa_vector_2pass(
|
|||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void quant_sdpa_vector(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& q,
|
||||||
|
const array& k,
|
||||||
|
const array& k_scales,
|
||||||
|
const array& k_biases,
|
||||||
|
const array& v,
|
||||||
|
const array& v_scales,
|
||||||
|
const array& v_biases,
|
||||||
|
array& out,
|
||||||
|
float scale,
|
||||||
|
int group_size,
|
||||||
|
int bits) {
|
||||||
|
// Set the kernel name
|
||||||
|
std::string kname;
|
||||||
|
kname.reserve(96);
|
||||||
|
kname += "quant_sdpa_vector_";
|
||||||
|
kname += get_type_string(q.dtype());
|
||||||
|
kname += "_";
|
||||||
|
kname += std::to_string(q.shape(-1));
|
||||||
|
kname += "_";
|
||||||
|
kname += std::to_string(group_size);
|
||||||
|
kname += "_";
|
||||||
|
kname += std::to_string(bits);
|
||||||
|
|
||||||
|
// Compute the necessary sizes
|
||||||
|
int gqa_factor = q.shape(1) / k.shape(1);
|
||||||
|
int N = k.shape(2);
|
||||||
|
int B = q.shape(0) * q.shape(1);
|
||||||
|
size_t stride = k.strides()[1];
|
||||||
|
size_t group_stride = k_scales.strides()[1];
|
||||||
|
MTL::Size group_dims(128, 1, 1);
|
||||||
|
MTL::Size grid_dims(1, B, 1);
|
||||||
|
|
||||||
|
// Get the kernel
|
||||||
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(kname);
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
|
// Set its arguments
|
||||||
|
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
|
||||||
|
compute_encoder.set_input_array(k, 1);
|
||||||
|
compute_encoder.set_input_array(k_scales, 2);
|
||||||
|
compute_encoder.set_input_array(k_biases, 3);
|
||||||
|
compute_encoder.set_input_array(v, 4);
|
||||||
|
compute_encoder.set_input_array(v_scales, 5);
|
||||||
|
compute_encoder.set_input_array(v_biases, 6);
|
||||||
|
compute_encoder.set_output_array(out, 7);
|
||||||
|
compute_encoder.set_bytes(&gqa_factor, sizeof(int), 8);
|
||||||
|
compute_encoder.set_bytes(&N, sizeof(int), 9);
|
||||||
|
compute_encoder.set_bytes(&stride, sizeof(size_t), 10);
|
||||||
|
compute_encoder.set_bytes(&group_stride, sizeof(size_t), 11);
|
||||||
|
compute_encoder.set_bytes(&scale, sizeof(float), 12);
|
||||||
|
|
||||||
|
// Launch
|
||||||
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void quant_sdpa_vector_2pass(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& q,
|
||||||
|
const array& k,
|
||||||
|
const array& k_scales,
|
||||||
|
const array& k_biases,
|
||||||
|
const array& v,
|
||||||
|
const array& v_scales,
|
||||||
|
const array& v_biases,
|
||||||
|
array& out,
|
||||||
|
float scale,
|
||||||
|
int group_size,
|
||||||
|
int bits) {
|
||||||
|
// Set the kernel name
|
||||||
|
std::string kname;
|
||||||
|
kname.reserve(96);
|
||||||
|
kname += "quant_sdpa_vector_2pass_1_";
|
||||||
|
kname += get_type_string(q.dtype());
|
||||||
|
kname += "_";
|
||||||
|
kname += std::to_string(q.shape(-1));
|
||||||
|
kname += "_";
|
||||||
|
kname += std::to_string(group_size);
|
||||||
|
kname += "_";
|
||||||
|
kname += std::to_string(bits);
|
||||||
|
|
||||||
|
// Compute the necessary sizes
|
||||||
|
int gqa_factor = q.shape(1) / k.shape(1);
|
||||||
|
int N = k.shape(2);
|
||||||
|
int blocks = 32;
|
||||||
|
int B = q.shape(0) * q.shape(1);
|
||||||
|
size_t k_stride = k.strides()[1];
|
||||||
|
size_t v_stride = v.strides()[1];
|
||||||
|
size_t k_group_stride = k_scales.strides()[1];
|
||||||
|
size_t v_group_stride = v_scales.strides()[1];
|
||||||
|
MTL::Size group_dims(8 * 4, 1, 1);
|
||||||
|
MTL::Size grid_dims(1, B, blocks);
|
||||||
|
|
||||||
|
// Allocate the intermediates
|
||||||
|
std::vector<int> intermediate_shape;
|
||||||
|
intermediate_shape.reserve(out.ndim() + 1);
|
||||||
|
intermediate_shape.insert(
|
||||||
|
intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1);
|
||||||
|
intermediate_shape.push_back(blocks);
|
||||||
|
intermediate_shape.push_back(out.shape().back());
|
||||||
|
array intermediate(intermediate_shape, float32, nullptr, {});
|
||||||
|
intermediate_shape.pop_back();
|
||||||
|
array sums(intermediate_shape, float32, nullptr, {});
|
||||||
|
array maxs(std::move(intermediate_shape), float32, nullptr, {});
|
||||||
|
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||||
|
sums.set_data(allocator::malloc_or_wait(sums.nbytes()));
|
||||||
|
maxs.set_data(allocator::malloc_or_wait(maxs.nbytes()));
|
||||||
|
d.add_temporary(intermediate, s.index);
|
||||||
|
d.add_temporary(sums, s.index);
|
||||||
|
d.add_temporary(maxs, s.index);
|
||||||
|
|
||||||
|
// Get the kernel
|
||||||
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(kname);
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
|
// Set its arguments
|
||||||
|
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
|
||||||
|
compute_encoder.set_input_array(k, 1);
|
||||||
|
compute_encoder.set_input_array(k_scales, 2);
|
||||||
|
compute_encoder.set_input_array(k_biases, 3);
|
||||||
|
compute_encoder.set_input_array(v, 4);
|
||||||
|
compute_encoder.set_input_array(v_scales, 5);
|
||||||
|
compute_encoder.set_input_array(v_biases, 6);
|
||||||
|
compute_encoder.set_output_array(intermediate, 7);
|
||||||
|
compute_encoder.set_output_array(sums, 8);
|
||||||
|
compute_encoder.set_output_array(maxs, 9);
|
||||||
|
compute_encoder.set_bytes(gqa_factor, 10);
|
||||||
|
compute_encoder.set_bytes(N, 11);
|
||||||
|
compute_encoder.set_bytes(k_stride, 12);
|
||||||
|
compute_encoder.set_bytes(v_stride, 13);
|
||||||
|
compute_encoder.set_bytes(k_group_stride, 14);
|
||||||
|
compute_encoder.set_bytes(v_group_stride, 15);
|
||||||
|
compute_encoder.set_bytes(scale, 16);
|
||||||
|
|
||||||
|
// Launch
|
||||||
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
|
|
||||||
|
// Final pass
|
||||||
|
kname.clear();
|
||||||
|
kname += "sdpa_vector_2pass_2_";
|
||||||
|
kname += get_type_string(q.dtype());
|
||||||
|
kname += "_";
|
||||||
|
kname += std::to_string(q.shape(-1));
|
||||||
|
|
||||||
|
// Get the kernel
|
||||||
|
kernel = d.get_kernel(kname);
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
|
// Set its arguments
|
||||||
|
compute_encoder.set_input_array(intermediate, 0);
|
||||||
|
compute_encoder.set_input_array(sums, 1);
|
||||||
|
compute_encoder.set_input_array(maxs, 2);
|
||||||
|
compute_encoder.set_output_array(out, 3);
|
||||||
|
|
||||||
|
// Launch
|
||||||
|
group_dims = MTL::Size(1024, 1, 1);
|
||||||
|
grid_dims = MTL::Size(1, B, 1);
|
||||||
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void ScaledDotProductAttention::eval_gpu(
|
void ScaledDotProductAttention::eval_gpu(
|
||||||
@ -254,7 +419,6 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
|
|
||||||
auto& q_pre = inputs[0];
|
auto& q_pre = inputs[0];
|
||||||
auto& k_pre = inputs[1];
|
auto& k_pre = inputs[1];
|
||||||
auto& v_pre = inputs[2];
|
|
||||||
auto& o = out;
|
auto& o = out;
|
||||||
|
|
||||||
std::vector<array> copies;
|
std::vector<array> copies;
|
||||||
@ -295,9 +459,7 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
|
|
||||||
// We are in vector mode ie single query
|
// We are in vector mode ie single query
|
||||||
if (q_pre.shape(2) == 1) {
|
if (q_pre.shape(2) == 1) {
|
||||||
const auto& q = copy_unless(is_contiguous, q_pre);
|
auto q = copy_unless(is_contiguous, q_pre);
|
||||||
const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
|
|
||||||
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
|
|
||||||
|
|
||||||
// Donate the query if possible
|
// Donate the query if possible
|
||||||
if (q.is_donatable()) {
|
if (q.is_donatable()) {
|
||||||
@ -306,20 +468,55 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
o.set_data(allocator::malloc_or_wait(o.nbytes()));
|
o.set_data(allocator::malloc_or_wait(o.nbytes()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// We route to the 2 pass fused attention if
|
if (quantized_) {
|
||||||
// - The device is large and the sequence length long
|
auto& k_scales_pre = inputs[2];
|
||||||
// - The sequence length is even longer and we have gqa
|
auto& k_biases_pre = inputs[3];
|
||||||
char devc = d.get_architecture().back();
|
auto& v_pre = inputs[4];
|
||||||
if ((devc == 'd' && k.shape(2) >= 1024) ||
|
auto& v_scales_pre = inputs[5];
|
||||||
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
|
auto& v_biases_pre = inputs[6];
|
||||||
sdpa_vector_2pass(s, d, q, k, v, o, scale_);
|
|
||||||
|
auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
|
||||||
|
auto k_scales = copy_unless(is_contiguous_except_seq_len, k_scales_pre);
|
||||||
|
auto k_biases = copy_unless(is_contiguous_except_seq_len, k_biases_pre);
|
||||||
|
auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
|
||||||
|
auto v_scales = copy_unless(is_contiguous_except_seq_len, v_scales_pre);
|
||||||
|
auto v_biases = copy_unless(is_contiguous_except_seq_len, v_biases_pre);
|
||||||
|
|
||||||
|
quant_sdpa_vector_2pass(
|
||||||
|
s,
|
||||||
|
d,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
k_scales,
|
||||||
|
k_biases,
|
||||||
|
v,
|
||||||
|
v_scales,
|
||||||
|
v_biases,
|
||||||
|
o,
|
||||||
|
scale_,
|
||||||
|
group_size_,
|
||||||
|
bits_);
|
||||||
} else {
|
} else {
|
||||||
sdpa_vector(s, d, q, k, v, o, scale_);
|
auto& k_pre = inputs[1];
|
||||||
|
auto& v_pre = inputs[2];
|
||||||
|
|
||||||
|
const auto& k = copy_unless(is_contiguous_except_seq_len, k_pre);
|
||||||
|
const auto& v = copy_unless(is_contiguous_except_seq_len, v_pre);
|
||||||
|
|
||||||
|
char devc = d.get_architecture().back();
|
||||||
|
if ((devc == 'd' && k.shape(2) >= 1024) ||
|
||||||
|
(k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
|
||||||
|
sdpa_vector_2pass(s, d, q, k, v, o, scale_);
|
||||||
|
} else {
|
||||||
|
sdpa_vector(s, d, q, k, v, o, scale_);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Full attention mode
|
// Full attention mode
|
||||||
else {
|
else {
|
||||||
|
auto& v_pre = inputs[2];
|
||||||
|
|
||||||
const auto& q = copy_unless(is_matrix_contiguous, q_pre);
|
const auto& q = copy_unless(is_matrix_contiguous, q_pre);
|
||||||
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
|
const auto& k = copy_unless(is_matrix_contiguous, k_pre);
|
||||||
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
|
const auto& v = copy_unless(is_matrix_contiguous, v_pre);
|
||||||
|
127
mlx/fast.cpp
127
mlx/fast.cpp
@ -664,7 +664,7 @@ array scaled_dot_product_attention(
|
|||||||
std::move(out_shape),
|
std::move(out_shape),
|
||||||
final_type,
|
final_type,
|
||||||
std::make_shared<ScaledDotProductAttention>(
|
std::make_shared<ScaledDotProductAttention>(
|
||||||
stream, fallback, scale, false),
|
stream, fallback, scale, /*needs_mask=*/false, /*quantized=*/false),
|
||||||
{q, k, v});
|
{q, k, v});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -678,7 +678,130 @@ array scaled_dot_product_attention(
|
|||||||
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
|
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
|
||||||
const ScaledDotProductAttention& a_other =
|
const ScaledDotProductAttention& a_other =
|
||||||
static_cast<const ScaledDotProductAttention&>(other);
|
static_cast<const ScaledDotProductAttention&>(other);
|
||||||
return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
|
return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_ &&
|
||||||
|
quantized_ == a_other.quantized_;
|
||||||
|
}
|
||||||
|
|
||||||
|
array quantized_scaled_dot_product_attention(
|
||||||
|
const array& queries,
|
||||||
|
const array& keys,
|
||||||
|
const array& key_scales,
|
||||||
|
const array& key_biases,
|
||||||
|
const array& values,
|
||||||
|
const array& value_scales,
|
||||||
|
const array& value_biases,
|
||||||
|
const float scale,
|
||||||
|
const std::optional<array>& mask,
|
||||||
|
const int group_size,
|
||||||
|
const int bits,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
int el_per_int = 32 / bits;
|
||||||
|
int out_dim = values.shape(-1) * el_per_int;
|
||||||
|
|
||||||
|
auto n_q_heads = queries.shape(-3);
|
||||||
|
auto n_kv_heads = keys.shape(-3);
|
||||||
|
|
||||||
|
auto out_shape = std::vector<int>(
|
||||||
|
{queries.shape(0), queries.shape(1), queries.shape(2), out_dim});
|
||||||
|
auto stream = to_stream(s);
|
||||||
|
bool needs_mask = mask.has_value();
|
||||||
|
auto fallback =
|
||||||
|
[scale, needs_mask, n_q_heads, n_kv_heads, group_size, bits, &s](
|
||||||
|
const std::vector<array>& inputs) -> std::vector<array> {
|
||||||
|
int n_repeats = n_q_heads / n_kv_heads;
|
||||||
|
|
||||||
|
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
|
||||||
|
|
||||||
|
auto k = inputs[1];
|
||||||
|
auto k_scales = inputs[2];
|
||||||
|
auto k_biases = inputs[3];
|
||||||
|
auto v = inputs[4];
|
||||||
|
auto v_scales = inputs[5];
|
||||||
|
auto v_biases = inputs[6];
|
||||||
|
|
||||||
|
int B = q.shape(0);
|
||||||
|
int L = q.shape(2);
|
||||||
|
|
||||||
|
if (n_repeats > 1) {
|
||||||
|
q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
|
||||||
|
k = expand_dims(k, 2, s);
|
||||||
|
k_scales = expand_dims(k_scales, 2, s);
|
||||||
|
k_biases = expand_dims(k_biases, 2, s);
|
||||||
|
v = expand_dims(v, 2, s);
|
||||||
|
v_scales = expand_dims(v_scales, 2, s);
|
||||||
|
v_biases = expand_dims(v_biases, 2, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array scores = quantized_matmul(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
k_scales,
|
||||||
|
k_biases,
|
||||||
|
/*transpose=*/true,
|
||||||
|
/*group_size=*/group_size,
|
||||||
|
/*bits=*/bits,
|
||||||
|
s);
|
||||||
|
if (needs_mask) {
|
||||||
|
scores = add(scores, inputs[7], s);
|
||||||
|
}
|
||||||
|
scores = softmax(scores, std::vector<int>{-1}, true, s);
|
||||||
|
array out = quantized_matmul(
|
||||||
|
scores,
|
||||||
|
v,
|
||||||
|
v_scales,
|
||||||
|
v_biases,
|
||||||
|
/*transpose=*/false,
|
||||||
|
/*group_size=*/group_size,
|
||||||
|
/*bits=*/bits,
|
||||||
|
s);
|
||||||
|
if (n_repeats > 1) {
|
||||||
|
out = reshape(out, {B, n_q_heads, L, -1}, s);
|
||||||
|
}
|
||||||
|
return std::vector<array>{out};
|
||||||
|
};
|
||||||
|
|
||||||
|
int L = queries.shape(2);
|
||||||
|
if (L > 1) {
|
||||||
|
if (needs_mask) {
|
||||||
|
return fallback(
|
||||||
|
{queries,
|
||||||
|
keys,
|
||||||
|
key_scales,
|
||||||
|
key_biases,
|
||||||
|
values,
|
||||||
|
value_scales,
|
||||||
|
value_biases,
|
||||||
|
mask.value()})[0];
|
||||||
|
} else {
|
||||||
|
return fallback(
|
||||||
|
{queries,
|
||||||
|
keys,
|
||||||
|
key_scales,
|
||||||
|
key_biases,
|
||||||
|
values,
|
||||||
|
value_scales,
|
||||||
|
value_biases})[0];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return array(
|
||||||
|
std::move(out_shape),
|
||||||
|
queries.dtype(),
|
||||||
|
std::make_shared<ScaledDotProductAttention>(
|
||||||
|
stream,
|
||||||
|
fallback,
|
||||||
|
scale,
|
||||||
|
/*needs_mask=*/false,
|
||||||
|
/*quantized=*/true,
|
||||||
|
group_size,
|
||||||
|
bits),
|
||||||
|
{queries,
|
||||||
|
keys,
|
||||||
|
key_scales,
|
||||||
|
key_biases,
|
||||||
|
values,
|
||||||
|
value_scales,
|
||||||
|
value_biases});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array pack_and_quantize(
|
array pack_and_quantize(
|
||||||
|
15
mlx/fast.h
15
mlx/fast.h
@ -41,6 +41,21 @@ array scaled_dot_product_attention(
|
|||||||
const std::optional<int> memory_efficient_threshold = std::nullopt,
|
const std::optional<int> memory_efficient_threshold = std::nullopt,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** Computes: `O = softmax(Q @ K.T) @ V` where K and V are quantized. **/
|
||||||
|
array quantized_scaled_dot_product_attention(
|
||||||
|
const array& queries,
|
||||||
|
const array& keys,
|
||||||
|
const array& key_scales,
|
||||||
|
const array& key_biases,
|
||||||
|
const array& values,
|
||||||
|
const array& value_scales,
|
||||||
|
const array& value_biases,
|
||||||
|
const float scale,
|
||||||
|
const std::optional<array>& mask = std::nullopt,
|
||||||
|
const int group_size = 64,
|
||||||
|
const int bits = 4,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
std::tuple<array, array, array> affine_quantize(
|
std::tuple<array, array, array> affine_quantize(
|
||||||
const array& w,
|
const array& w,
|
||||||
int group_size = 64,
|
int group_size = 64,
|
||||||
|
@ -190,8 +190,16 @@ class ScaledDotProductAttention : public Custom {
|
|||||||
Stream stream,
|
Stream stream,
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||||
const float scale,
|
const float scale,
|
||||||
const bool needs_mask)
|
const bool needs_mask,
|
||||||
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
|
const bool quantized,
|
||||||
|
const int group_size = 64,
|
||||||
|
const int bits = 4)
|
||||||
|
: Custom(stream, fallback),
|
||||||
|
scale_(scale),
|
||||||
|
needs_mask_(needs_mask),
|
||||||
|
quantized_(quantized),
|
||||||
|
group_size_(group_size),
|
||||||
|
bits_(bits) {}
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override {
|
override {
|
||||||
@ -212,6 +220,9 @@ class ScaledDotProductAttention : public Custom {
|
|||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
float scale_;
|
float scale_;
|
||||||
bool needs_mask_;
|
bool needs_mask_;
|
||||||
|
bool quantized_;
|
||||||
|
int group_size_;
|
||||||
|
int bits_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class AffineQuantize : public Custom {
|
class AffineQuantize : public Custom {
|
||||||
|
@ -161,6 +161,45 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
array: The output array.
|
array: The output array.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"quantized_scaled_dot_product_attention",
|
||||||
|
&fast::quantized_scaled_dot_product_attention,
|
||||||
|
"q"_a,
|
||||||
|
"k"_a,
|
||||||
|
"k_scales"_a,
|
||||||
|
"k_biases"_a,
|
||||||
|
"v"_a,
|
||||||
|
"v_scales"_a,
|
||||||
|
"v_biases"_a,
|
||||||
|
nb::kw_only(),
|
||||||
|
"scale"_a,
|
||||||
|
"mask"_a = nb::none(),
|
||||||
|
"group_size"_a = 64,
|
||||||
|
"bits"_a = 4,
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: array, v: array, v_scales: array, v_biases: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
A fast implementation of multi-head attention where the keys and values are quantized.
|
||||||
|
|
||||||
|
see :func:`scaled_dot_product_attention` for more details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q (array): Input query array.
|
||||||
|
k (array): Input keys array.
|
||||||
|
k_scales (array): Scales for the quantized keys array.
|
||||||
|
k_biases (array): Biases for the quantized keys array.
|
||||||
|
v (array): Input values array.
|
||||||
|
v_scales (array): Scales for the quantized values array.
|
||||||
|
v_biases (array): Biases for the quantized values array.
|
||||||
|
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
|
||||||
|
mask (array, optional): An additive mask to apply to the query-key scores.
|
||||||
|
group_size (int): The group size used in the KV quantization.
|
||||||
|
bits (int): The bits used in the KV quantization.
|
||||||
|
Returns:
|
||||||
|
array: The output array.
|
||||||
|
)pbdoc");
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"metal_kernel",
|
"metal_kernel",
|
||||||
[](const std::string& name,
|
[](const std::string& name,
|
||||||
|
Loading…
Reference in New Issue
Block a user