Alex Barron 2024-10-22 16:14:29 -07:00
parent 8e88e30d95
commit 5824626c0b
8 changed files with 532 additions and 62 deletions

View File

@ -1,16 +1,18 @@
import argparse
import math
import mlx.core as mx
import numpy as np
from time_utils import time_fn

-L = 1024
L = 30000
H = 32
H_k = 32 // 4
D = 128


def attention(q, k, v):
    k = mx.quantize(k)
    v = mx.quantize(v)
    k = mx.dequantize(*k)
    v = mx.dequantize(*v)
    B, Hq, L, D = q.shape
    _, Hk, S, _ = k.shape
    q = q.reshape(B, Hk, Hq // Hk, L, D)
@ -23,27 +25,54 @@ def attention(q, k, v):
def sdpa(q, k, v):
    k = mx.quantize(k)
    v = mx.quantize(v)
    k = mx.dequantize(*k)
    v = mx.dequantize(*v)
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)


-def time_self_attention_primitives():
-    mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D))
-    k = mx.random.uniform(shape=(1, H_k, L, D))
-    v = mx.random.uniform(shape=(1, H_k, L, D))
-    mx.eval(q, k, v)
def quant_sdpa(q, k, v):
    k = mx.quantize(k)
    v = mx.quantize(v)
    return mx.fast.quantized_scaled_dot_product_attention(q, *k, *v, scale=1.0)


def time_self_attention_primitives(q, k, v):
    time_fn(attention, q, k, v)


-def time_self_attention_sdpa():
-    mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D))
-    k = mx.random.uniform(shape=(1, H_k, L, D))
-    v = mx.random.uniform(shape=(1, H_k, L, D))
-    mx.eval(q, k, v)
def time_self_attention_sdpa(q, k, v):
    time_fn(sdpa, q, k, v)


def time_self_attention_quant_sdpa(q, k, v):
    time_fn(quant_sdpa, q, k, v)


if __name__ == "__main__":
-    time_self_attention_sdpa()
-    time_self_attention_primitives()
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 10, D))
    k = mx.random.uniform(shape=(1, H_k, L, D))
    v = mx.random.uniform(shape=(1, H_k, L, D))
    mx.eval(q, k, v)
    k_quant = mx.quantize(k)
    v_quant = mx.quantize(v)
    mx.eval(k_quant, v_quant)

    # time_self_attention_sdpa(q, k, v)
    # time_self_attention_quant_sdpa(q, k_quant, v_quant)
    # time_self_attention_primitives(q, k, v)
    q_sdpa = quant_sdpa(q, k, v)
    print(q_sdpa)
    o_attention = attention(q, k, v)
    print(o_attention)
    np.testing.assert_allclose(q_sdpa, o_attention, atol=1e-5)
    # o_sdpa = sdpa(q, k, v)
    # print(o_sdpa)
    # np.testing.assert_allclose(q_sdpa, o_sdpa, atol=1e-5)
    # print(o_sdpa[..., :64])
    # print()
    # print(o_attention[..., :64])
    # np.testing.assert_allclose(o_sdpa, o_attention)
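For reference, the benchmark's `quant_sdpa` path relies on `mx.quantize` returning a `(packed, scales, biases)` tuple that is splatted straight into the new fast op. A minimal sketch of that call pattern, with arbitrary illustrative shapes (not the benchmark's exact sizes):

```python
import mlx.core as mx

# Toy decode-step shapes: batch 1, 32 query heads, 8 KV heads, head_dim 128.
q = mx.random.uniform(shape=(1, 32, 1, 128))
k = mx.random.uniform(shape=(1, 8, 512, 128))
v = mx.random.uniform(shape=(1, 8, 512, 128))

# mx.quantize packs along the last axis and returns (packed, scales, biases);
# the defaults (group_size=64, bits=4) match the kernel instantiated below.
k_q = mx.quantize(k)
v_q = mx.quantize(v)

out = mx.fast.quantized_scaled_dot_product_attention(
    q, *k_q, *v_q, scale=1.0 / 128**0.5
)
mx.eval(out)
```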

View File

@ -927,19 +927,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
// SDPA vector instantiations
#define instantiate_sdpa_vector(type, head_dim) \
-  template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \
-  [[kernel]] void sdpa_vector<type, head_dim>( \
-      const device type* queries [[buffer(0)]], \
-      const device type* keys [[buffer(1)]], \
-      const device type* values [[buffer(2)]], \
-      device type* out [[buffer(3)]], \
-      const constant int& gqa_factor, \
-      const constant int& N, \
-      const constant size_t& k_stride, \
-      const constant float& scale, \
-      uint3 tid [[threadgroup_position_in_grid]], \
-      uint simd_gid [[simdgroup_index_in_threadgroup]], \
-      uint simd_lid [[thread_index_in_simdgroup]]);
  instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim)

#define instantiate_sdpa_vector_heads(type) \
  instantiate_sdpa_vector(type, 64) \
@ -949,4 +937,18 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t)
instantiate_sdpa_vector_heads(float16_t)
// Quantized SDPA vector instantiations
#define instantiate_quant_sdpa_vector(type, head_dim) \
instantiate_kernel("quant_sdpa_vector_" #type "_" #head_dim, quant_sdpa_vector, type, head_dim, 64, 4)
#define instantiate_quant_sdpa_vector_heads(type) \
instantiate_quant_sdpa_vector(type, 64) \
instantiate_quant_sdpa_vector(type, 96) \
instantiate_quant_sdpa_vector(type, 128)
instantiate_quant_sdpa_vector_heads(float)
instantiate_quant_sdpa_vector_heads(bfloat16_t)
instantiate_quant_sdpa_vector_heads(float16_t)
// clang-format on
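The new macro instantiates `quant_sdpa_vector` only for `group_size = 64` and `bits = 4`, i.e. eight 4-bit codes packed per `uint32` word and one scale/bias pair shared by every 64 values. A rough numpy sketch of that affine 4-bit layout (just the arithmetic the kernel assumes, not MLX's actual packing code):

```python
import numpy as np

group_size, bits = 64, 4
pack_factor = 32 // bits  # 8 codes per uint32 word

x = np.random.rand(group_size).astype(np.float32)

# Per-group affine quantization: x ≈ code * scale + bias, codes in [0, 15].
bias = x.min()
scale = (x.max() - bias) / (2**bits - 1)
codes = np.round((x - bias) / scale).astype(np.uint32)

# Pack 8 nibbles per 32-bit word, lowest nibble first.
words = np.zeros(group_size // pack_factor, dtype=np.uint32)
for i, c in enumerate(codes):
    words[i // pack_factor] |= c << (bits * (i % pack_factor))

# Unpack, dequantize, and check the round trip stays within one step.
unpacked = np.array(
    [(words[i // pack_factor] >> (bits * (i % pack_factor))) & 0xF for i in range(group_size)]
)
np.testing.assert_allclose(unpacked * scale + bias, x, atol=scale)
```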

View File

@ -16,9 +16,11 @@ template <typename T, int D>
const constant float& scale,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
-uint simd_lid [[thread_index_in_simdgroup]]) {
uint simd_lid [[thread_index_in_simdgroup]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
constexpr int BN = 32;
-constexpr int BD = 32;
constexpr int BD = 4;
constexpr int elem_per_thread = D / BD;
const int stride = BN * D;
@ -36,9 +38,9 @@ template <typename T, int D>
// Adjust positions
const int head_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
-queries += head_idx * D + simd_lid * elem_per_thread;
-keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
-values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
queries += head_idx * D + quad_lid * elem_per_thread;
keys += kv_head_idx * k_stride + quad_gid * D + quad_lid * elem_per_thread;
values += kv_head_idx * k_stride + quad_gid * D + quad_lid * elem_per_thread;
out += head_idx * D + simd_gid * elem_per_thread;
// Read the query and 0 the output accumulator
@ -53,7 +55,7 @@ template <typename T, int D>
U sum_exp_score = 0;
// For each key
-for (int i = simd_gid; i < N; i += BN) {
for (int i = quad_gid; i < N; i += BN) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i];
@ -64,7 +66,7 @@ template <typename T, int D>
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
-score = simd_sum(score);
score = quad_sum(score);
// Update the accumulators
U new_max = max(max_score, score);
@ -88,9 +90,10 @@ template <typename T, int D>
// Each thread has a partial part of the output so we need to combine them.
// First let's communicate the max and sum_exp
-if (simd_lid == 0) {
-max_scores[simd_gid] = max_score;
-sum_exp_scores[simd_gid] = sum_exp_score;
// Each quadgroup communicates its max score
if (quad_lid == 0) {
max_scores[quad_gid] = max_score;
sum_exp_scores[quad_gid] = sum_exp_score;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_score = max_scores[simd_lid];
@ -100,9 +103,174 @@ template <typename T, int D>
// Now we need to aggregate all the outputs
for (int i = 0; i < elem_per_thread; i++) {
-outputs[simd_lid * BD + simd_gid] = o[i];
// 128 threads with 32 values per thread
outputs[simd_gid * BN + simd_lid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup);
-o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// And write the output
if (simd_lid == 0) {
for (int i = 0; i < elem_per_thread; i++) {
out[i] = static_cast<T>(o[i]);
}
}
}
template <typename T, int D, int group_size, int bits>
[[kernel]] void quant_sdpa_vector(
const device T* queries [[buffer(0)]],
const device uint32_t* keys [[buffer(1)]],
const device T* key_scales [[buffer(2)]],
const device T* key_biases [[buffer(3)]],
const device uint32_t* values [[buffer(4)]],
const device T* value_scales [[buffer(5)]],
const device T* value_biases [[buffer(6)]],
device T* out [[buffer(7)]],
const constant int& gqa_factor,
const constant int& N,
const constant size_t& k_stride,
const constant size_t& group_stride,
const constant float& scale,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
constexpr int BN = 32;
constexpr int BD = 4;
constexpr int elem_per_thread = D / BD;
constexpr int pack_factor = 32 / bits;
const int stride = BN * D;
typedef float U;
thread U q[elem_per_thread];
thread U k[elem_per_thread];
thread U v[elem_per_thread];
thread U o[elem_per_thread];
threadgroup U outputs[BN * BD];
threadgroup U max_scores[BN];
threadgroup U sum_exp_scores[BN];
// Adjust positions
const int head_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + quad_lid * elem_per_thread;
const int kv_idx = quad_gid * D + quad_lid * elem_per_thread;
const int packed_idx = kv_head_idx * k_stride + kv_idx / pack_factor;
const int group_idx = kv_head_idx * group_stride + kv_idx / group_size;
keys += packed_idx;
key_scales += group_idx;
key_biases += group_idx;
values += packed_idx;
value_scales += group_idx;
value_biases += group_idx;
out += head_idx * D + simd_gid * elem_per_thread;
// Read the query and 0 the output accumulator
U query_sum = 0;
U shifts[4] = {1, 16, 256, 4096};
for (int i = 0; i < elem_per_thread; i++) {
// Shift by the appropriate amount here
query_sum += queries[i];
U shift = shifts[i % 4];
q[i] = static_cast<U>(scale) * queries[i] / shift;
}
for (int i = 0; i < elem_per_thread; i++) {
o[i] = 0;
}
U max_score = -INFINITY;
U sum_exp_score = 0;
// For each key
for (int i = quad_gid; i < N; i += BN) {
// Read the key
auto ks = (const device uint16_t*)keys;
for (int i = 0; i < elem_per_thread / 4; i++) {
k[4 * i] = ks[i] & 0x000f;
k[4 * i + 1] = ks[i] & 0x00f0;
k[4 * i + 2] = ks[i] & 0x0f00;
k[4 * i + 3] = ks[i] & 0xf000;
}
// All the keys in a set are in the same group
U key_scale = key_scales[0];
U key_bias = key_biases[0];
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = score * key_scale + query_sum * key_bias;
score = quad_sum(score);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
U value_scale = value_scales[0];
U value_bias = value_biases[0];
// Load the values
auto vs = (const device uint16_t*)values;
U s[4] = {
value_scale,
value_scale / 16.0f,
value_scale / 256.0f,
value_scale / 4096.0f};
for (int i = 0; i < elem_per_thread / 4; i++) {
v[4 * i] = s[0] * (vs[i] & 0x000f) + value_bias;
v[4 * i + 1] = s[1] * (vs[i] & 0x00f0) + value_bias;
v[4 * i + 2] = s[2] * (vs[i] & 0x0f00) + value_bias;
v[4 * i + 3] = s[3] * (vs[i] & 0xf000) + value_bias;
}
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * v[i];
}
// Move the pointers to the next kv
keys += stride / pack_factor;
key_scales += stride / group_size;
key_biases += stride / group_size;
values += stride / pack_factor;
value_scales += stride / group_size;
value_biases += stride / group_size;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Each thread has a partial part of the output so we need to combine them.
// First let's communicate the max and sum_exp
// Each quadgroup communicates its max score
if (quad_lid == 0) {
max_scores[quad_gid] = max_score;
sum_exp_scores[quad_gid] = sum_exp_score;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_score = max_scores[simd_lid];
U new_max = simd_max(max_score);
U factor = fast::exp(max_score - new_max);
sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
// Now we need to aggregate all the outputs
for (int i = 0; i < elem_per_thread; i++) {
// 128 threads with 32 values per thread
outputs[simd_gid * BN + simd_lid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
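Two tricks in the quantized kernel are worth spelling out. The unpacked nibbles are never shifted: `ks[i] & 0x00f0` is the second code times 16, and the missing 1/16, 1/256, 1/4096 factors are folded into the pre-scaled query (the `shifts` array) on the key side and into the `s` array on the value side. The affine bias is handled analytically, since dot(q, scale * k_q + bias) = scale * dot(q, k_q) + bias * sum(q), which is why the kernel carries `query_sum`. A small numpy check of both identities (illustrative only, not the kernel's data layout):

```python
import numpy as np

rng = np.random.default_rng(0)
D = 16
q = rng.standard_normal(D).astype(np.float32)
k_q = rng.integers(0, 16, size=D).astype(np.float32)  # unpacked 4-bit codes
k_scale, k_bias = 0.037, -0.21                         # one group's scale/bias

# Reference: dequantize the key, then take the dot product.
reference = np.dot(q, k_q * k_scale + k_bias)

# Kernel-style: dot against the raw codes, fold the scale in afterwards and
# account for the bias with the precomputed sum of the query.
query_sum = q.sum()
folded = np.dot(q, k_q) * k_scale + query_sum * k_bias
np.testing.assert_allclose(reference, folded, rtol=1e-5, atol=1e-5)

# Masking without shifting: `word & 0x00f0` is the second nibble times 16,
# so dividing that lane of the query by 16 up front (the `shifts` array)
# yields the same products with no bit shifts in the inner loop.
word = np.uint16(0xABCD)
nibbles = [(word >> (4 * j)) & 0xF for j in range(4)]
masked = [word & m for m in (0x000F, 0x00F0, 0x0F00, 0xF000)]
assert all(m == n * s for m, n, s in zip(masked, nibbles, (1, 16, 256, 4096)))
```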

View File

@ -9,6 +9,8 @@
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/fast_primitives.h"

#include <iostream>

namespace mlx::core::fast {

namespace {
@ -163,7 +165,7 @@ void sdpa_vector(
int N = k.shape(2);
int B = q.shape(0) * q.shape(1);
size_t stride = k.strides()[1];
-MTL::Size group_dims(1024, 1, 1);
MTL::Size group_dims(128, 1, 1);
MTL::Size grid_dims(1, B, 1);

// Get the kernel
@ -185,19 +187,67 @@ void sdpa_vector(
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void quant_sdpa_vector(
const Stream& s,
metal::Device& d,
const array& q,
const array& k,
const array& k_scales,
const array& k_biases,
const array& v,
const array& v_scales,
const array& v_biases,
array& out,
float scale) {
// Set the kernel name
std::string kname;
kname.reserve(96);
kname += "quant_sdpa_vector_";
kname += get_type_string(q.dtype());
kname += "_";
kname += std::to_string(q.shape(-1));
// Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1);
int N = k.shape(2);
int B = q.shape(0) * q.shape(1);
size_t stride = k.strides()[1];
size_t group_stride = k_scales.strides()[1];
MTL::Size group_dims(128, 1, 1);
MTL::Size grid_dims(1, B, 1);
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname);
compute_encoder->setComputePipelineState(kernel);
// Set its arguments
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(k_scales, 2);
compute_encoder.set_input_array(k_biases, 3);
compute_encoder.set_input_array(v, 4);
compute_encoder.set_input_array(v_scales, 5);
compute_encoder.set_input_array(v_biases, 6);
compute_encoder.set_output_array(out, 7);
compute_encoder->setBytes(&gqa_factor, sizeof(int), 8);
compute_encoder->setBytes(&N, sizeof(int), 9);
compute_encoder->setBytes(&stride, sizeof(size_t), 10);
compute_encoder->setBytes(&group_stride, sizeof(size_t), 11);
compute_encoder->setBytes(&scale, sizeof(float), 12);
// Launch
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
} // namespace

void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
-assert(inputs.size() == 3);
auto& s = stream();
auto& d = metal::device(s.device);
-auto& q_pre = inputs[0];
-auto& k_pre = inputs[1];
-auto& v_pre = inputs[2];
auto& o = out;

std::vector<array> copies;
@ -236,11 +286,25 @@ void ScaledDotProductAttention::eval_gpu(
return strides[3] == 1 && strides[2] == shape[3];
};

-// We are in vector mode ie single query
-if (q_pre.shape(2) == 1) {
if (quantized_) {
auto& q_pre = inputs[0];
auto& k_pre = inputs[1];
auto& k_scales_pre = inputs[2];
auto& k_biases_pre = inputs[3];
auto& v_pre = inputs[4];
auto& v_scales_pre = inputs[5];
auto& v_biases_pre = inputs[6];

// Quantized should only be routed here for single queries
assert(q_pre.shape(2) == 1);

auto q = copy_unless(is_contiguous, q_pre);
auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
auto k_scales = copy_unless(is_contiguous_except_seq_len, k_scales_pre);
auto k_biases = copy_unless(is_contiguous_except_seq_len, k_biases_pre);
auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
auto v_scales = copy_unless(is_contiguous_except_seq_len, v_scales_pre);
auto v_biases = copy_unless(is_contiguous_except_seq_len, v_biases_pre);

// Donate the query if possible
if (q.is_donatable()) {
@ -249,17 +313,42 @@ void ScaledDotProductAttention::eval_gpu(
o.set_data(allocator::malloc_or_wait(o.nbytes()));
}

-sdpa_vector(s, d, q, k, v, o, scale_);
quant_sdpa_vector(
s, d, q, k, k_scales, k_biases, v, v_scales, v_biases, o, scale_);
}
-// Full attention mode
// Non-quantized
else {
-auto q = copy_unless(is_matrix_contiguous, q_pre);
-auto k = copy_unless(is_matrix_contiguous, k_pre);
-auto v = copy_unless(is_matrix_contiguous, v_pre);
-o.set_data(allocator::malloc_or_wait(o.nbytes()));
-sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
assert(inputs.size() == 3);
auto& q_pre = inputs[0];
auto& k_pre = inputs[1];
auto& v_pre = inputs[2];
// We are in vector mode ie single query
if (q_pre.shape(2) == 1) {
auto q = copy_unless(is_contiguous, q_pre);
auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
// Donate the query if possible
if (q.is_donatable()) {
o.move_shared_buffer(q);
} else {
o.set_data(allocator::malloc_or_wait(o.nbytes()));
}
sdpa_vector(s, d, q, k, v, o, scale_);
}
// Full attention mode
else {
auto q = copy_unless(is_matrix_contiguous, q_pre);
auto k = copy_unless(is_matrix_contiguous, k_pre);
auto v = copy_unless(is_matrix_contiguous, v_pre);
o.set_data(allocator::malloc_or_wait(o.nbytes()));
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
}
}
}

d.add_temporaries(std::move(copies), s.index);
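With `group_dims(128, 1, 1)` and `grid_dims(1, B, 1)`, each (batch, query-head) pair gets one 128-thread threadgroup arranged as BN = 32 quadgroups of BD = 4 lanes: quadgroup `g` walks keys `g, g + 32, g + 64, ...` while its four lanes split the 128-wide head dimension. A quick sanity check of that partitioning (plain Python, only illustrating the indexing, not the dispatch code itself):

```python
BN, BD, D, N = 32, 4, 128, 1000  # quadgroups, lanes per quadgroup, head dim, number of keys
elem_per_thread = D // BD        # 32 head-dim elements per lane

# Keys visited by each quadgroup cover all N keys exactly once.
keys_per_quadgroup = [list(range(g, N, BN)) for g in range(BN)]
assert sum(len(ks) for ks in keys_per_quadgroup) == N

# The head-dim slices owned by the four lanes tile the full head dimension.
lane_slices = [range(l * elem_per_thread, (l + 1) * elem_per_thread) for l in range(BD)]
assert sum(len(s) for s in lane_slices) == D
```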

View File

@ -10,6 +10,8 @@
#include "mlx/ops.h"
#include "mlx/transforms.h"

#include <iostream>

namespace mlx::core::fast {

std::vector<array> Custom::vjp(
@ -648,7 +650,7 @@ array scaled_dot_product_attention(
std::move(out_shape),
final_type,
std::make_shared<ScaledDotProductAttention>(
-stream, fallback, scale, false),
stream, fallback, scale, /*needs_mask=*/false, /*quantized=*/false),
{q, k, v});
}
@ -662,7 +664,124 @@ array scaled_dot_product_attention(
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
const ScaledDotProductAttention& a_other =
static_cast<const ScaledDotProductAttention&>(other);
-return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_ &&
quantized_ == a_other.quantized_;
}
array quantized_scaled_dot_product_attention(
const array& queries,
const array& keys,
const array& key_scales,
const array& key_biases,
const array& values,
const array& value_scales,
const array& value_biases,
const float scale,
const std::optional<array>& mask,
const int group_size,
const int bits,
StreamOrDevice s) {
int el_per_int = 32 / bits;
int out_dim = values.shape(-1) * el_per_int;
auto n_q_heads = queries.shape(-3);
auto n_kv_heads = keys.shape(-3);
std::cout << "group bits " << group_size << " " << bits << std::endl;
auto out_shape = std::vector<int>(
{queries.shape(0), queries.shape(1), queries.shape(2), out_dim});
auto stream = to_stream(s);
bool needs_mask = mask.has_value();
auto fallback =
[scale, needs_mask, n_q_heads, n_kv_heads, group_size, bits, &s](
const std::vector<array>& inputs) -> std::vector<array> {
int n_repeats = n_q_heads / n_kv_heads;
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
auto k = inputs[1];
auto k_scales = inputs[2];
auto k_biases = inputs[3];
auto v = inputs[4];
auto v_scales = inputs[5];
auto v_biases = inputs[6];
int B = q.shape(0);
int L = q.shape(2);
if (n_repeats > 1) {
q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
k = expand_dims(k, 2, s);
k_scales = expand_dims(k_scales, 2, s);
k_biases = expand_dims(k_biases, 2, s);
v = expand_dims(v, 2, s);
v_scales = expand_dims(v_scales, 2, s);
v_biases = expand_dims(v_biases, 2, s);
}
array scores = quantized_matmul(
q,
k,
k_scales,
k_biases,
/*transpose=*/true,
/*group_size=*/group_size,
/*bits=*/bits,
s);
if (needs_mask) {
scores = add(scores, inputs[7], s);
}
scores = softmax(scores, std::vector<int>{-1}, true, s);
array out = quantized_matmul(
scores,
v,
v_scales,
v_biases,
/*transpose=*/false,
/*group_size=*/group_size,
/*bits=*/bits,
s);
if (n_repeats > 1) {
out = reshape(out, {B, n_q_heads, L, -1}, s);
}
return std::vector<array>{out};
};
if (true) {
if (needs_mask) {
return fallback(
{queries,
keys,
key_scales,
key_biases,
values,
value_scales,
value_biases,
mask.value()})[0];
} else {
return fallback(
{queries,
keys,
key_scales,
key_biases,
values,
value_scales,
value_biases})[0];
}
} else {
return array(
std::move(out_shape),
queries.dtype(),
std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, /*needs_mask=*/false, /*quantized=*/true),
{queries,
keys,
key_scales,
key_biases,
values,
value_scales,
value_biases});
}
}

array pack_and_quantize(
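Because of the `if (true)` guard, every call currently takes the fallback: scale the queries, `quantized_matmul` against the packed keys with transpose, optionally add the mask, softmax, then `quantized_matmul` against the packed values. Roughly the same composition in Python (a sketch assuming the default `group_size=64`/`bits=4`, no GQA reshape, and hypothetical argument names; the C++ fallback performs the softmax in float32):

```python
import mlx.core as mx

def quant_sdpa_fallback(q, k_q, k_scales, k_biases, v_q, v_scales, v_biases, scale, mask=None):
    # Scale the queries up front, as the C++ fallback does.
    q = scale * q
    # Scores: Q @ dequant(K).T, computed directly on the packed keys.
    scores = mx.quantized_matmul(q, k_q, k_scales, k_biases, transpose=True)
    if mask is not None:
        scores = scores + mask
    scores = mx.softmax(scores, axis=-1)
    # Output: softmax(scores) @ dequant(V), packed values, no transpose.
    return mx.quantized_matmul(scores, v_q, v_scales, v_biases, transpose=False)
```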

View File

@ -41,6 +41,21 @@ array scaled_dot_product_attention(
const std::optional<int> memory_efficient_threshold = std::nullopt,
StreamOrDevice s = {});
/** Computes: `O = softmax(Q @ K.T) @ V` where K and V are quantized. **/
array quantized_scaled_dot_product_attention(
const array& queries,
const array& keys,
const array& key_scales,
const array& key_biases,
const array& values,
const array& value_scales,
const array& value_biases,
const float scale,
const std::optional<array>& mask = std::nullopt,
const int group_size = 64,
const int bits = 4,
StreamOrDevice s = {});
std::tuple<array, array, array> affine_quantize(
const array& w,
int group_size = 64,

View File

@ -190,8 +190,12 @@ class ScaledDotProductAttention : public Custom {
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
const float scale,
-const bool needs_mask)
-: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
const bool needs_mask,
const bool quantized)
: Custom(stream, fallback),
scale_(scale),
needs_mask_(needs_mask),
quantized_(quantized) {}

void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
@ -212,6 +216,7 @@ class ScaledDotProductAttention : public Custom {
std::function<std::vector<array>(std::vector<array>)> fallback_;
float scale_;
bool needs_mask_;
bool quantized_;
};

class AffineQuantize : public Custom {

View File

@ -150,6 +150,49 @@ void init_fast(nb::module_& parent_module) {
array: The output array.
)pbdoc");
m.def(
"quantized_scaled_dot_product_attention",
&fast::quantized_scaled_dot_product_attention,
"q"_a,
"k"_a,
"k_scales"_a,
"k_biases"_a,
"v"_a,
"v_scales"_a,
"v_biases"_a,
nb::kw_only(),
"scale"_a,
"mask"_a = nb::none(),
"group_size"_a = 64,
"bits"_a = 4,
"stream"_a = nb::none(),
nb::sig(
"def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: array, v: array, v_scales: array, v_biases: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
Supports:
* `Multi-Head Attention <https://arxiv.org/abs/1706.03762>`_
* `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_
* `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_
Note: The softmax operation is performed in ``float32`` regardless of
the input precision.
Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
and ``v`` inputs should not be pre-tiled to match ``q``.
Args:
q (array): Input query array.
k (array): Input keys array.
v (array): Input values array.
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1))``)
mask (array, optional): An additive mask to apply to the query-key scores.
Returns:
array: The output array.
)pbdoc");
m.def(
"affine_quantize",
nb::overload_cast<
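At the Python layer the binding is keyword-only past the seven array arguments, so a call with an additive mask looks roughly like the sketch below (assuming the keys and values were quantized with the same `group_size`/`bits` passed to the op, and that the mask broadcasts against the scores):

```python
import math
import mlx.core as mx

q = mx.random.uniform(shape=(1, 32, 1, 128))
k = mx.random.uniform(shape=(1, 8, 256, 128))
v = mx.random.uniform(shape=(1, 8, 256, 128))

k_q, k_scales, k_biases = mx.quantize(k, group_size=64, bits=4)
v_q, v_scales, v_biases = mx.quantize(v, group_size=64, bits=4)

# Additive mask over the 256 key positions (0 keeps a position, -inf drops it).
mask = mx.zeros((1, 1, 1, 256))

out = mx.fast.quantized_scaled_dot_product_attention(
    q, k_q, k_scales, k_biases, v_q, v_scales, v_biases,
    scale=1.0 / math.sqrt(128),
    mask=mask,
    group_size=64,
    bits=4,
)
```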