8 bit working

This commit is contained in:
Alex Barron 2024-10-22 20:09:27 -07:00
parent ef14b1e9c3
commit 047a584e3d
6 changed files with 127 additions and 48 deletions

View File

@ -25,18 +25,18 @@ def attention(q, k, v):
def sdpa(q, k, v):
k = mx.quantize(k)
v = mx.quantize(v)
k = mx.dequantize(*k)
v = mx.dequantize(*v)
return mx.fast.scaled_dot_product_attention(q, k, v, scale=0.08, mask=None)
k = mx.quantize(k, bits=8)
v = mx.quantize(v, bits=8)
k = mx.dequantize(*k, bits=8)
v = mx.dequantize(*v, bits=8)
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
def quant_sdpa(q, k, v):
k = mx.quantize(k)
v = mx.quantize(v)
k = mx.quantize(k, bits=8)
v = mx.quantize(v, bits=8)
return mx.fast.quantized_scaled_dot_product_attention(
q, *k, *v, scale=0.08, mask=None
q, *k, *v, scale=1.0, mask=None, bits=8
)

View File

@ -939,13 +939,25 @@ instantiate_sdpa_vector_heads(bfloat16_t)
instantiate_sdpa_vector_heads(float16_t)
// Quantized SDPA vector instantiations
#define instantiate_quant_sdpa_vector(type, head_dim) \
instantiate_kernel("quant_sdpa_vector_" #type "_" #head_dim, quant_sdpa_vector, type, head_dim, 64, 4)
#define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \
instantiate_kernel( \
"quant_sdpa_vector_" #type "_" #head_dim "_" #group_size "_" #bits, \
quant_sdpa_vector, type, head_dim, group_size, bits)
#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
instantiate_quant_sdpa_vector(type, heads, group_size, 4) \
instantiate_quant_sdpa_vector(type, heads, group_size, 8)
#define instantiate_quant_sdpa_vector_group_size(type, heads) \
instantiate_quant_sdpa_vector_bits(type, heads, 32) \
instantiate_quant_sdpa_vector_bits(type, heads, 64) \
instantiate_quant_sdpa_vector_bits(type, heads, 128)
#define instantiate_quant_sdpa_vector_heads(type) \
instantiate_quant_sdpa_vector(type, 64) \
instantiate_quant_sdpa_vector(type, 96) \
instantiate_quant_sdpa_vector(type, 128)
instantiate_quant_sdpa_vector_group_size(type, 64) \
instantiate_quant_sdpa_vector_group_size(type, 96) \
instantiate_quant_sdpa_vector_group_size(type, 128)
instantiate_quant_sdpa_vector_heads(float)
instantiate_quant_sdpa_vector_heads(bfloat16_t)

View File

@ -118,6 +118,67 @@ template <typename T, int D>
}
}
template <typename T, typename U, int elem_per_thread, int bits>
METAL_FUNC U load_queries(const device T* queries, thread U* q, U scale) {
U query_sum = 0;
if (bits == 4) {
for (int i = 0; i < elem_per_thread; i += 4) {
q[i] = scale * queries[i];
q[i + 1] = scale * queries[i + 1];
q[i + 2] = scale * queries[i + 2];
q[i + 3] = scale * queries[i + 3];
query_sum += q[i] + q[i + 1] + q[i + 2] + q[i + 3];
q[i + 1] /= 16.0f;
q[i + 2] /= 256.0f;
q[i + 3] /= 4096.0f;
}
} else if (bits == 8) {
for (int i = 0; i < elem_per_thread; i++) {
q[i] = scale * queries[i];
query_sum += q[i];
}
}
return query_sum;
}
template <typename U, int elem_per_thread, int bits>
METAL_FUNC void load_keys(const device uint32_t* keys, thread U* k) {
if (bits == 4) {
auto ks = (const device uint16_t*)keys;
for (int i = 0; i < elem_per_thread / 4; i++) {
k[4 * i] = ks[i] & 0x000f;
k[4 * i + 1] = ks[i] & 0x00f0;
k[4 * i + 2] = ks[i] & 0x0f00;
k[4 * i + 3] = ks[i] & 0xf000;
}
} else if (bits == 8) {
auto ks = (const device uint8_t*)keys;
for (int i = 0; i < elem_per_thread; i++) {
k[i] = ks[i];
}
}
}
template <typename U, int elem_per_thread, int bits>
METAL_FUNC void load_values(
const device uint32_t* values,
thread U* v,
U value_scale,
U value_bias) {
auto vs = (const device uint8_t*)values;
if (bits == 4) {
U s[2] = {value_scale, value_scale / 16.0f};
for (int i = 0; i < elem_per_thread / 2; i++) {
v[2 * i] = s[0] * (vs[i] & 0x0f) + value_bias;
v[2 * i + 1] = s[1] * (vs[i] & 0xf0) + value_bias;
}
} else if (bits == 8) {
for (int i = 0; i < elem_per_thread; i++) {
v[i] = value_scale * vs[i] + value_bias;
}
}
}
template <typename T, int D, int group_size, int bits>
[[kernel]] void quant_sdpa_vector(
const device T* queries [[buffer(0)]],
@ -174,15 +235,8 @@ template <typename T, int D, int group_size, int bits>
out += head_idx * D + simd_gid * elem_per_thread;
// Read the query and 0 the output accumulator
U query_sum = 0;
U shifts[4] = {1, 16, 256, 4096};
for (int i = 0; i < elem_per_thread; i++) {
// Shift by the appropriate amount here
U shift = shifts[i % 4];
q[i] = static_cast<U>(scale) * queries[i];
query_sum += q[i];
q[i] /= shift;
}
U query_sum = load_queries<T, U, elem_per_thread, bits>(
queries, q, static_cast<U>(scale));
for (int i = 0; i < elem_per_thread; i++) {
o[i] = 0;
}
@ -192,15 +246,9 @@ template <typename T, int D, int group_size, int bits>
// For each key
for (int i = quad_gid; i < N; i += BN) {
// Read the key
auto ks = (const device uint16_t*)keys;
for (int i = 0; i < elem_per_thread / 4; i++) {
k[4 * i] = ks[i] & 0x000f;
k[4 * i + 1] = ks[i] & 0x00f0;
k[4 * i + 2] = ks[i] & 0x0f00;
k[4 * i + 3] = ks[i] & 0xf000;
}
// All the keys in a set are in the same group
load_keys<U, elem_per_thread, bits>(keys, k);
// Assume D % group_size == 0 so all the keys are in the same group
U key_scale = key_scales[0];
U key_bias = key_biases[0];
@ -224,18 +272,7 @@ template <typename T, int D, int group_size, int bits>
U value_bias = value_biases[0];
// Load the values
auto vs = (const device uint16_t*)values;
U s[4] = {
value_scale,
value_scale / 16.0f,
value_scale / 256.0f,
value_scale / 4096.0f};
for (int i = 0; i < elem_per_thread / 4; i++) {
v[4 * i] = s[0] * (vs[i] & 0x000f) + value_bias;
v[4 * i + 1] = s[1] * (vs[i] & 0x00f0) + value_bias;
v[4 * i + 2] = s[2] * (vs[i] & 0x0f00) + value_bias;
v[4 * i + 3] = s[3] * (vs[i] & 0xf000) + value_bias;
}
load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {

View File

@ -198,7 +198,9 @@ void quant_sdpa_vector(
const array& v_scales,
const array& v_biases,
array& out,
float scale) {
float scale,
int group_size,
int bits) {
// Set the kernel name
std::string kname;
kname.reserve(96);
@ -206,6 +208,10 @@ void quant_sdpa_vector(
kname += get_type_string(q.dtype());
kname += "_";
kname += std::to_string(q.shape(-1));
kname += "_";
kname += std::to_string(group_size);
kname += "_";
kname += std::to_string(bits);
// Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1);
@ -314,7 +320,19 @@ void ScaledDotProductAttention::eval_gpu(
}
quant_sdpa_vector(
s, d, q, k, k_scales, k_biases, v, v_scales, v_biases, o, scale_);
s,
d,
q,
k,
k_scales,
k_biases,
v,
v_scales,
v_biases,
o,
scale_,
group_size_,
bits_);
}

View File

@ -773,7 +773,13 @@ array quantized_scaled_dot_product_attention(
std::move(out_shape),
queries.dtype(),
std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, /*needs_mask=*/false, /*quantized=*/true),
stream,
fallback,
scale,
/*needs_mask=*/false,
/*quantized=*/true,
group_size,
bits),
{queries,
keys,
key_scales,

View File

@ -191,11 +191,15 @@ class ScaledDotProductAttention : public Custom {
std::function<std::vector<array>(std::vector<array>)> fallback,
const float scale,
const bool needs_mask,
const bool quantized)
const bool quantized,
const int group_size = 64,
const int bits = 4)
: Custom(stream, fallback),
scale_(scale),
needs_mask_(needs_mask),
quantized_(quantized) {}
quantized_(quantized),
group_size_(group_size),
bits_(bits) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
@ -217,6 +221,8 @@ class ScaledDotProductAttention : public Custom {
float scale_;
bool needs_mask_;
bool quantized_;
int group_size_;
int bits_;
};
class AffineQuantize : public Custom {