From 047a584e3d00b398aae7230500ace1c8ab39bc15 Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Tue, 22 Oct 2024 20:09:27 -0700
Subject: [PATCH] 8 bit working

---
 benchmarks/python/sdpa_vector_bench.py          | 16 +--
 .../scaled_dot_product_attention.metal          | 22 ++++-
 mlx/backend/metal/kernels/sdpa_vector.h         | 97 +++++++++++++------
 .../metal/scaled_dot_product_attention.cpp      | 22 ++++-
 mlx/fast.cpp                                    |  8 +-
 mlx/fast_primitives.h                           | 10 +-
 6 files changed, 127 insertions(+), 48 deletions(-)

diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py
index c8c4eea30..2dc2da7f8 100644
--- a/benchmarks/python/sdpa_vector_bench.py
+++ b/benchmarks/python/sdpa_vector_bench.py
@@ -25,18 +25,18 @@ def attention(q, k, v):
 
 
 def sdpa(q, k, v):
-    k = mx.quantize(k)
-    v = mx.quantize(v)
-    k = mx.dequantize(*k)
-    v = mx.dequantize(*v)
-    return mx.fast.scaled_dot_product_attention(q, k, v, scale=0.08, mask=None)
+    k = mx.quantize(k, bits=8)
+    v = mx.quantize(v, bits=8)
+    k = mx.dequantize(*k, bits=8)
+    v = mx.dequantize(*v, bits=8)
+    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
 
 
 def quant_sdpa(q, k, v):
-    k = mx.quantize(k)
-    v = mx.quantize(v)
+    k = mx.quantize(k, bits=8)
+    v = mx.quantize(v, bits=8)
     return mx.fast.quantized_scaled_dot_product_attention(
-        q, *k, *v, scale=0.08, mask=None
+        q, *k, *v, scale=1.0, mask=None, bits=8
     )
 
 
diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
index 3717febfb..daed3ebbc 100644
--- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
+++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
@@ -939,13 +939,25 @@ instantiate_sdpa_vector_heads(bfloat16_t)
 instantiate_sdpa_vector_heads(float16_t)
 
 // Quantized SDPA vector instantiations
-#define instantiate_quant_sdpa_vector(type, head_dim) \
-  instantiate_kernel("quant_sdpa_vector_" #type "_" #head_dim, quant_sdpa_vector, type, head_dim, 64, 4)
+#define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \
+  instantiate_kernel( \
+      "quant_sdpa_vector_" #type "_" #head_dim "_" #group_size "_" #bits, \
+      quant_sdpa_vector, type, head_dim, group_size, bits)
+
+#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
+  instantiate_quant_sdpa_vector(type, heads, group_size, 4) \
+  instantiate_quant_sdpa_vector(type, heads, group_size, 8)
+
+#define instantiate_quant_sdpa_vector_group_size(type, heads) \
+  instantiate_quant_sdpa_vector_bits(type, heads, 32) \
+  instantiate_quant_sdpa_vector_bits(type, heads, 64) \
+  instantiate_quant_sdpa_vector_bits(type, heads, 128)
 
 #define instantiate_quant_sdpa_vector_heads(type) \
-  instantiate_quant_sdpa_vector(type, 64) \
-  instantiate_quant_sdpa_vector(type, 96) \
-  instantiate_quant_sdpa_vector(type, 128)
+  instantiate_quant_sdpa_vector_group_size(type, 64) \
+  instantiate_quant_sdpa_vector_group_size(type, 96) \
+  instantiate_quant_sdpa_vector_group_size(type, 128)
+
 instantiate_quant_sdpa_vector_heads(float)
 instantiate_quant_sdpa_vector_heads(bfloat16_t)
 
diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h
index ac2da6567..38b75413d 100644
--- a/mlx/backend/metal/kernels/sdpa_vector.h
+++ b/mlx/backend/metal/kernels/sdpa_vector.h
@@ -118,6 +118,67 @@ template
   }
 }
 
+template <typename T, typename U, int elem_per_thread, int bits>
+METAL_FUNC U load_queries(const device T* queries, thread U* q, U scale) {
+  U query_sum = 0;
+  if (bits == 4) {
+    for (int i = 0; i < elem_per_thread; i += 4) {
+      q[i] = scale * queries[i];
+      q[i + 1] = scale * queries[i + 1];
+      q[i + 2] = scale * queries[i + 2];
+      q[i + 3] = scale * queries[i + 3];
+      query_sum += q[i] + q[i + 1] + q[i + 2] + q[i + 3];
+      q[i + 1] /= 16.0f;
+      q[i + 2] /= 256.0f;
+      q[i + 3] /= 4096.0f;
+    }
+  } else if (bits == 8) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      q[i] = scale * queries[i];
+      query_sum += q[i];
+    }
+  }
+  return query_sum;
+}
+
+template <typename U, int elem_per_thread, int bits>
+METAL_FUNC void load_keys(const device uint32_t* keys, thread U* k) {
+  if (bits == 4) {
+    auto ks = (const device uint16_t*)keys;
+    for (int i = 0; i < elem_per_thread / 4; i++) {
+      k[4 * i] = ks[i] & 0x000f;
+      k[4 * i + 1] = ks[i] & 0x00f0;
+      k[4 * i + 2] = ks[i] & 0x0f00;
+      k[4 * i + 3] = ks[i] & 0xf000;
+    }
+  } else if (bits == 8) {
+    auto ks = (const device uint8_t*)keys;
+    for (int i = 0; i < elem_per_thread; i++) {
+      k[i] = ks[i];
+    }
+  }
+}
+
+template <typename U, int elem_per_thread, int bits>
+METAL_FUNC void load_values(
+    const device uint32_t* values,
+    thread U* v,
+    U value_scale,
+    U value_bias) {
+  auto vs = (const device uint8_t*)values;
+  if (bits == 4) {
+    U s[2] = {value_scale, value_scale / 16.0f};
+    for (int i = 0; i < elem_per_thread / 2; i++) {
+      v[2 * i] = s[0] * (vs[i] & 0x0f) + value_bias;
+      v[2 * i + 1] = s[1] * (vs[i] & 0xf0) + value_bias;
+    }
+  } else if (bits == 8) {
+    for (int i = 0; i < elem_per_thread; i++) {
+      v[i] = value_scale * vs[i] + value_bias;
+    }
+  }
+}
+
 template <typename T, int D, int group_size, int bits>
 [[kernel]] void quant_sdpa_vector(
     const device T* queries [[buffer(0)]],
@@ -174,15 +235,8 @@ template
   out += head_idx * D + simd_gid * elem_per_thread;
 
   // Read the query and 0 the output accumulator
-  U query_sum = 0;
-  U shifts[4] = {1, 16, 256, 4096};
-  for (int i = 0; i < elem_per_thread; i++) {
-    // Shift by the appropriate amount here
-    U shift = shifts[i % 4];
-    q[i] = static_cast<U>(scale) * queries[i];
-    query_sum += q[i];
-    q[i] /= shift;
-  }
+  U query_sum = load_queries<T, U, elem_per_thread, bits>(
+      queries, q, static_cast<U>(scale));
   for (int i = 0; i < elem_per_thread; i++) {
     o[i] = 0;
   }
@@ -192,15 +246,9 @@ template
 
   // For each key
   for (int i = quad_gid; i < N; i += BN) {
-    // Read the key
-    auto ks = (const device uint16_t*)keys;
-    for (int i = 0; i < elem_per_thread / 4; i++) {
-      k[4 * i] = ks[i] & 0x000f;
-      k[4 * i + 1] = ks[i] & 0x00f0;
-      k[4 * i + 2] = ks[i] & 0x0f00;
-      k[4 * i + 3] = ks[i] & 0xf000;
-    }
-    // All the keys in a set are in the same group
+    load_keys<U, elem_per_thread, bits>(keys, k);
+
+    // Assume D % group_size == 0 so all the keys are in the same group
     U key_scale = key_scales[0];
     U key_bias = key_biases[0];
 
@@ -224,18 +272,7 @@ template
     U value_scale = value_scales[0];
     U value_bias = value_biases[0];
 
     // Load the values
-    auto vs = (const device uint16_t*)values;
-    U s[4] = {
-        value_scale,
-        value_scale / 16.0f,
-        value_scale / 256.0f,
-        value_scale / 4096.0f};
-    for (int i = 0; i < elem_per_thread / 4; i++) {
-      v[4 * i] = s[0] * (vs[i] & 0x000f) + value_bias;
-      v[4 * i + 1] = s[1] * (vs[i] & 0x00f0) + value_bias;
-      v[4 * i + 2] = s[2] * (vs[i] & 0x0f00) + value_bias;
-      v[4 * i + 3] = s[3] * (vs[i] & 0xf000) + value_bias;
-    }
+    load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);
 
     // Update the output accumulator
     for (int i = 0; i < elem_per_thread; i++) {
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index e85c95acc..c5251556f 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -198,7 +198,9 @@ void quant_sdpa_vector(
     const array& v_scales,
     const array& v_biases,
     array& out,
-    float scale) {
+    float scale,
+    int group_size,
+    int bits) {
   // Set the kernel name
   std::string kname;
   kname.reserve(96);
@@ -206,6 +208,10 @@ void quant_sdpa_vector(
   kname += get_type_string(q.dtype());
   kname += "_";
   kname += std::to_string(q.shape(-1));
+  kname += "_";
+  kname += std::to_string(group_size);
+  kname += "_";
+  kname += std::to_string(bits);
 
   // Compute the necessary sizes
   int gqa_factor = q.shape(1) / k.shape(1);
@@ -314,7 +320,19 @@ void ScaledDotProductAttention::eval_gpu(
     }
 
     quant_sdpa_vector(
-        s, d, q, k, k_scales, k_biases, v, v_scales, v_biases, o, scale_);
+        s,
+        d,
+        q,
+        k,
+        k_scales,
+        k_biases,
+        v,
+        v_scales,
+        v_biases,
+        o,
+        scale_,
+        group_size_,
+        bits_);
   }
 
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index a6f377cd6..f01e4fe2c 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -773,7 +773,13 @@ array quantized_scaled_dot_product_attention(
       std::move(out_shape),
       queries.dtype(),
       std::make_shared<ScaledDotProductAttention>(
-          stream, fallback, scale, /*needs_mask=*/false, /*quantized=*/true),
+          stream,
+          fallback,
+          scale,
+          /*needs_mask=*/false,
+          /*quantized=*/true,
+          group_size,
+          bits),
       {queries,
        keys,
        key_scales,
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index f97808d35..cb79aee31 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -191,11 +191,15 @@ class ScaledDotProductAttention : public Custom {
       std::function<std::vector<array>(std::vector<array>)> fallback,
       const float scale,
       const bool needs_mask,
-      const bool quantized)
+      const bool quantized,
+      const int group_size = 64,
+      const int bits = 4)
       : Custom(stream, fallback),
         scale_(scale),
         needs_mask_(needs_mask),
-        quantized_(quantized) {}
+        quantized_(quantized),
+        group_size_(group_size),
+        bits_(bits) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
@@ -217,6 +221,8 @@ class ScaledDotProductAttention : public Custom {
   float scale_;
   bool needs_mask_;
   bool quantized_;
+  int group_size_;
+  int bits_;
 };
 
 class AffineQuantize : public Custom {
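
Usage sketch (not part of the patch): the updated benchmark above exercises the new 8-bit path roughly as follows. This is a minimal sketch assuming an MLX build that contains this change; the array shapes, variable names (q, k, v, k_q, v_q, o), and the scale value are illustrative, not taken from the patch.

    import mlx.core as mx

    # Illustrative single-token decode shapes: (batch, heads, seq, head_dim).
    # head_dim 128 is one of the instantiated sizes (64, 96, 128).
    q = mx.random.normal(shape=(1, 32, 1, 128))
    k = mx.random.normal(shape=(1, 32, 4096, 128))
    v = mx.random.normal(shape=(1, 32, 4096, 128))

    # Quantize the KV cache to 8 bits; each call returns (packed, scales, biases).
    k_q = mx.quantize(k, bits=8)
    v_q = mx.quantize(v, bits=8)

    # Fused quantized attention: *k_q and *v_q expand into the keys/key_scales/
    # key_biases and values/value_scales/value_biases arguments. scale=1.0
    # mirrors the benchmark; a typical choice is 1 / sqrt(head_dim).
    o = mx.fast.quantized_scaled_dot_product_attention(
        q, *k_q, *v_q, scale=1.0, mask=None, bits=8
    )
    mx.eval(o)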