Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-29 13:55:29 +08:00)

Commit: 047a584e3d ("8 bit working")
Parent: ef14b1e9c3
@@ -25,18 +25,18 @@ def attention(q, k, v):


 def sdpa(q, k, v):
-    k = mx.quantize(k)
-    v = mx.quantize(v)
-    k = mx.dequantize(*k)
-    v = mx.dequantize(*v)
-    return mx.fast.scaled_dot_product_attention(q, k, v, scale=0.08, mask=None)
+    k = mx.quantize(k, bits=8)
+    v = mx.quantize(v, bits=8)
+    k = mx.dequantize(*k, bits=8)
+    v = mx.dequantize(*v, bits=8)
+    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)


 def quant_sdpa(q, k, v):
-    k = mx.quantize(k)
-    v = mx.quantize(v)
+    k = mx.quantize(k, bits=8)
+    v = mx.quantize(v, bits=8)
     return mx.fast.quantized_scaled_dot_product_attention(
-        q, *k, *v, scale=0.08, mask=None
+        q, *k, *v, scale=1.0, mask=None, bits=8
     )

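Note: both benchmark paths now quantize to 8 bits, and the baseline round-trips through quantize/dequantize before the regular SDPA. A minimal sketch of that round trip, assuming this commit's build (where quantize accepts the key/value shapes used here) and mlx's default group_size of 64; the shapes are illustrative, not from the commit:

import mlx.core as mx

# mx.quantize returns a (packed_data, scales, biases) tuple, which is why
# the benchmark splats it into mx.dequantize with *k.
k = mx.random.normal((1, 8, 1024, 128)).astype(mx.float16)
k_q = mx.quantize(k, bits=8)           # affine 8-bit quantization per group
k_hat = mx.dequantize(*k_q, bits=8)    # reconstruct: scales * codes + biases
print(mx.abs(k - k_hat).max())         # small reconstruction error at 8 bits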
@@ -939,13 +939,25 @@ instantiate_sdpa_vector_heads(bfloat16_t)
 instantiate_sdpa_vector_heads(float16_t)

 // Quantized SDPA vector instantiations
-#define instantiate_quant_sdpa_vector(type, head_dim) \
-  instantiate_kernel("quant_sdpa_vector_" #type "_" #head_dim, quant_sdpa_vector, type, head_dim, 64, 4)
+#define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \
+  instantiate_kernel( \
+      "quant_sdpa_vector_" #type "_" #head_dim "_" #group_size "_" #bits, \
+      quant_sdpa_vector, type, head_dim, group_size, bits)
+
+#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
+  instantiate_quant_sdpa_vector(type, heads, group_size, 4) \
+  instantiate_quant_sdpa_vector(type, heads, group_size, 8)
+
+#define instantiate_quant_sdpa_vector_group_size(type, heads) \
+  instantiate_quant_sdpa_vector_bits(type, heads, 32) \
+  instantiate_quant_sdpa_vector_bits(type, heads, 64) \
+  instantiate_quant_sdpa_vector_bits(type, heads, 128)

 #define instantiate_quant_sdpa_vector_heads(type) \
-  instantiate_quant_sdpa_vector(type, 64) \
-  instantiate_quant_sdpa_vector(type, 96) \
-  instantiate_quant_sdpa_vector(type, 128)
+  instantiate_quant_sdpa_vector_group_size(type, 64) \
+  instantiate_quant_sdpa_vector_group_size(type, 96) \
+  instantiate_quant_sdpa_vector_group_size(type, 128)

 instantiate_quant_sdpa_vector_heads(float)
 instantiate_quant_sdpa_vector_heads(bfloat16_t)
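Note: the macro ladder now stamps out one kernel per (dtype, head_dim, group_size, bits) combination instead of hard-coding group size 64 and 4 bits, with every parameter baked into the kernel name. A quick sketch of the names this expands to, covering only the two dtypes instantiated in this hunk (any float16_t instantiation would be outside it):

# 2 dtypes x 3 head dims x 3 group sizes x 2 bit widths = 36 kernels
for t in ("float", "bfloat16_t"):
    for d in (64, 96, 128):
        for gs in (32, 64, 128):
            for b in (4, 8):
                print(f"quant_sdpa_vector_{t}_{d}_{gs}_{b}")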
@@ -118,6 +118,67 @@ template <typename T, int D>
   }
 }

+template <typename T, typename U, int elem_per_thread, int bits>
+METAL_FUNC U load_queries(const device T* queries, thread U* q, U scale) {
+  U query_sum = 0;
+  if (bits == 4) {
+    for (int i = 0; i < elem_per_thread; i += 4) {
+      q[i] = scale * queries[i];
+      q[i + 1] = scale * queries[i + 1];
+      q[i + 2] = scale * queries[i + 2];
+      q[i + 3] = scale * queries[i + 3];
+      query_sum += q[i] + q[i + 1] + q[i + 2] + q[i + 3];
+      q[i + 1] /= 16.0f;
+      q[i + 2] /= 256.0f;
+      q[i + 3] /= 4096.0f;
+    }
+  } else if (bits == 8) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      q[i] = scale * queries[i];
+      query_sum += q[i];
+    }
+  }
+  return query_sum;
+}
+
+template <typename U, int elem_per_thread, int bits>
+METAL_FUNC void load_keys(const device uint32_t* keys, thread U* k) {
+  if (bits == 4) {
+    auto ks = (const device uint16_t*)keys;
+    for (int i = 0; i < elem_per_thread / 4; i++) {
+      k[4 * i] = ks[i] & 0x000f;
+      k[4 * i + 1] = ks[i] & 0x00f0;
+      k[4 * i + 2] = ks[i] & 0x0f00;
+      k[4 * i + 3] = ks[i] & 0xf000;
+    }
+  } else if (bits == 8) {
+    auto ks = (const device uint8_t*)keys;
+    for (int i = 0; i < elem_per_thread; i++) {
+      k[i] = ks[i];
+    }
+  }
+}
+
+template <typename U, int elem_per_thread, int bits>
+METAL_FUNC void load_values(
+    const device uint32_t* values,
+    thread U* v,
+    U value_scale,
+    U value_bias) {
+  auto vs = (const device uint8_t*)values;
+  if (bits == 4) {
+    U s[2] = {value_scale, value_scale / 16.0f};
+    for (int i = 0; i < elem_per_thread / 2; i++) {
+      v[2 * i] = s[0] * (vs[i] & 0x0f) + value_bias;
+      v[2 * i + 1] = s[1] * (vs[i] & 0xf0) + value_bias;
+    }
+  } else if (bits == 8) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      v[i] = value_scale * vs[i] + value_bias;
+    }
+  }
+}
+
 template <typename T, int D, int group_size, int bits>
 [[kernel]] void quant_sdpa_vector(
     const device T* queries [[buffer(0)]],
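Note on the 4-bit paths: load_keys masks each nibble in place but never shifts it down, so the nibble at position j carries a spurious factor of 16^j. load_queries compensates by pre-dividing q[i + j] by the same 16^j, so each product q[i + j] * k[i + j] comes out exact while the inner loop saves a shift per element. A worked sketch of the identity in plain Python, with illustrative values:

# The nibble at position j, extracted with a mask but no shift, carries a
# factor of 16**j; dividing the matching query element by 16**j cancels it.
packed = 0xB3A1                                           # four 4-bit values: 1, 10, 3, 11 (low first)
nibbles = [(packed >> (4 * j)) & 0xF for j in range(4)]   # true key values
masked = [packed & (0xF << (4 * j)) for j in range(4)]    # what load_keys produces
q = [0.5, -1.0, 2.0, 0.25]
q_scaled = [qi / 16**j for j, qi in enumerate(q)]         # what load_queries produces
dot_true = sum(a * b for a, b in zip(q, nibbles))
dot_kernel = sum(a * b for a, b in zip(q_scaled, masked))
assert abs(dot_true - dot_kernel) < 1e-9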
@@ -174,15 +235,8 @@ template <typename T, int D, int group_size, int bits>
   out += head_idx * D + simd_gid * elem_per_thread;

   // Read the query and 0 the output accumulator
-  U query_sum = 0;
-  U shifts[4] = {1, 16, 256, 4096};
-  for (int i = 0; i < elem_per_thread; i++) {
-    // Shift by the appropriate amount here
-    U shift = shifts[i % 4];
-    q[i] = static_cast<U>(scale) * queries[i];
-    query_sum += q[i];
-    q[i] /= shift;
-  }
+  U query_sum = load_queries<T, U, elem_per_thread, bits>(
+      queries, q, static_cast<U>(scale));
  for (int i = 0; i < elem_per_thread; i++) {
    o[i] = 0;
  }
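Note: query_sum exists because the keys are affine-quantized. With k_i = key_scale * kq_i + key_bias, the score factors as q . k = key_scale * (q . kq) + key_bias * sum(q), so the bias contribution collapses into a single multiply with the precomputed query sum. A small numeric check in plain Python, with illustrative values:

import random

q = [random.uniform(-1, 1) for _ in range(8)]
kq = [float(random.randrange(16)) for _ in range(8)]      # quantized key codes
key_scale, key_bias = 0.05, -0.4
k = [key_scale * x + key_bias for x in kq]                # dequantized key
lhs = sum(a * b for a, b in zip(q, k))                    # q . k directly
rhs = key_scale * sum(a * b for a, b in zip(q, kq)) + key_bias * sum(q)
assert abs(lhs - rhs) < 1e-9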
@@ -192,15 +246,9 @@ template <typename T, int D, int group_size, int bits>

   // For each key
   for (int i = quad_gid; i < N; i += BN) {
-    // Read the key
-    auto ks = (const device uint16_t*)keys;
-    for (int i = 0; i < elem_per_thread / 4; i++) {
-      k[4 * i] = ks[i] & 0x000f;
-      k[4 * i + 1] = ks[i] & 0x00f0;
-      k[4 * i + 2] = ks[i] & 0x0f00;
-      k[4 * i + 3] = ks[i] & 0xf000;
-    }
-
-    // All the keys in a set are in the same group
+    load_keys<U, elem_per_thread, bits>(keys, k);
+
+    // Assume D % group_size == 0 so all the keys are in the same group
     U key_scale = key_scales[0];
     U key_bias = key_biases[0];
@@ -224,18 +272,7 @@ template <typename T, int D, int group_size, int bits>
     U value_bias = value_biases[0];

     // Load the values
-    auto vs = (const device uint16_t*)values;
-    U s[4] = {
-        value_scale,
-        value_scale / 16.0f,
-        value_scale / 256.0f,
-        value_scale / 4096.0f};
-    for (int i = 0; i < elem_per_thread / 4; i++) {
-      v[4 * i] = s[0] * (vs[i] & 0x000f) + value_bias;
-      v[4 * i + 1] = s[1] * (vs[i] & 0x00f0) + value_bias;
-      v[4 * i + 2] = s[2] * (vs[i] & 0x0f00) + value_bias;
-      v[4 * i + 3] = s[3] * (vs[i] & 0xf000) + value_bias;
-    }
+    load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);

     // Update the output accumulator
     for (int i = 0; i < elem_per_thread; i++) {
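Note: load_values uses the same masked-nibble convention on its 4-bit path, but reading through a uint8_t pointer it only handles two positions, so the high nibble's scale is pre-divided once (s[1] = value_scale / 16) while the bias is added per element. A two-element sketch in plain Python, with illustrative values:

vs = 0xA1                                              # two packed 4-bit values: 1 (low), 10 (high)
value_scale, value_bias = 0.1, -0.2
v0 = value_scale * (vs & 0x0F) + value_bias            # 0.1 * 1 - 0.2 = -0.1
v1 = (value_scale / 16.0) * (vs & 0xF0) + value_bias   # 0.1 * 10 - 0.2 = 0.8
assert abs(v0 - -0.1) < 1e-9 and abs(v1 - 0.8) < 1e-9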
@@ -198,7 +198,9 @@ void quant_sdpa_vector(
     const array& v_scales,
     const array& v_biases,
     array& out,
-    float scale) {
+    float scale,
+    int group_size,
+    int bits) {
   // Set the kernel name
   std::string kname;
   kname.reserve(96);
@@ -206,6 +208,10 @@ void quant_sdpa_vector(
   kname += get_type_string(q.dtype());
   kname += "_";
   kname += std::to_string(q.shape(-1));
+  kname += "_";
+  kname += std::to_string(group_size);
+  kname += "_";
+  kname += std::to_string(bits);

   // Compute the necessary sizes
   int gqa_factor = q.shape(1) / k.shape(1);
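Note: appending group_size and bits here keeps the host-side name in sync with the Metal instantiation names above; asking for a combination that was never instantiated would fail at kernel lookup. A sketch of the resulting name, assuming the "quant_sdpa_vector_" prefix set by code outside this hunk:

def kname(dtype: str, head_dim: int, group_size: int, bits: int) -> str:
    # mirrors: "quant_sdpa_vector_" #type "_" #head_dim "_" #group_size "_" #bits
    return f"quant_sdpa_vector_{dtype}_{head_dim}_{group_size}_{bits}"

assert kname("float", 128, 64, 8) == "quant_sdpa_vector_float_128_64_8"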
@@ -314,7 +320,19 @@ void ScaledDotProductAttention::eval_gpu(
   }

   quant_sdpa_vector(
-      s, d, q, k, k_scales, k_biases, v, v_scales, v_biases, o, scale_);
+      s,
+      d,
+      q,
+      k,
+      k_scales,
+      k_biases,
+      v,
+      v_scales,
+      v_biases,
+      o,
+      scale_,
+      group_size_,
+      bits_);

 }
@@ -773,7 +773,13 @@ array quantized_scaled_dot_product_attention(
       std::move(out_shape),
       queries.dtype(),
       std::make_shared<ScaledDotProductAttention>(
-          stream, fallback, scale, /*needs_mask=*/false, /*quantized=*/true),
+          stream,
+          fallback,
+          scale,
+          /*needs_mask=*/false,
+          /*quantized=*/true,
+          group_size,
+          bits),
       {queries,
        keys,
        key_scales,
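Note: end to end, group_size and bits now flow from the Python call through this primitive into the kernel name. A hedged usage sketch matching the benchmark at the top of this commit; the op and its keywords are as they appear in this commit and may differ in later mlx releases, and group_size falls back to the C++ default of 64 here:

import mlx.core as mx

q = mx.random.normal((1, 32, 1, 128)).astype(mx.float16)
k = mx.random.normal((1, 8, 1024, 128)).astype(mx.float16)
v = mx.random.normal((1, 8, 1024, 128)).astype(mx.float16)

k_q = mx.quantize(k, bits=8)     # (packed codes, scales, biases)
v_q = mx.quantize(v, bits=8)
out = mx.fast.quantized_scaled_dot_product_attention(
    q, *k_q, *v_q, scale=1.0, mask=None, bits=8
)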
@@ -191,11 +191,15 @@ class ScaledDotProductAttention : public Custom {
       std::function<std::vector<array>(std::vector<array>)> fallback,
       const float scale,
       const bool needs_mask,
-      const bool quantized)
+      const bool quantized,
+      const int group_size = 64,
+      const int bits = 4)
       : Custom(stream, fallback),
         scale_(scale),
         needs_mask_(needs_mask),
-        quantized_(quantized) {}
+        quantized_(quantized),
+        group_size_(group_size),
+        bits_(bits) {}

   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
@@ -217,6 +221,8 @@ class ScaledDotProductAttention : public Custom {
   float scale_;
   bool needs_mask_;
   bool quantized_;
+  int group_size_;
+  int bits_;
 };

 class AffineQuantize : public Custom {