Compare commits

...

7 Commits

Author SHA1 Message Date
Alex Barron
f5b0f11968 add fast::quantized_kv_update 2024-10-26 00:24:49 -07:00
Alex Barron
b509c2ad76 update bench 2024-10-25 12:10:24 -07:00
Alex Barron
852336b8a2 clean 2024-10-25 12:10:24 -07:00
Alex Barron
6649244686 revert sdpa 2024-10-25 12:10:24 -07:00
Alex Barron
047a584e3d 8 bit working 2024-10-25 12:10:24 -07:00
Alex Barron
ef14b1e9c3 4 bit working 2024-10-25 12:10:24 -07:00
Alex Barron
5824626c0b start 2024-10-25 12:10:24 -07:00
14 changed files with 854 additions and 54 deletions

View File

@@ -1,10 +1,9 @@
import argparse
import math
import mlx.core as mx
import numpy as np
from mlx.utils import tree_map
from time_utils import time_fn
L = 1024
L = 65536
H = 32
H_k = 32 // 4
D = 128
@@ -23,27 +22,60 @@ def attention(q, k, v):
def sdpa(q, k, v):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
def time_self_attention_primitives():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
def quant_sdpa(q, k, v, bits=4):
return mx.fast.quantized_scaled_dot_product_attention(
q, *k, *v, scale=1.0, mask=None, bits=bits
)
def quant_attention(q, k, v, bits=4):
B, Hq, L, D = q.shape
Hk = k[0].shape[1]
q = q.reshape((B, Hk, Hq // Hk, L, D))
k = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
v = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
scores = mx.quantized_matmul(q, *k, transpose=True, bits=bits)
scores = mx.softmax(scores, axis=-1)
out = mx.quantized_matmul(scores, *v, transpose=False, bits=bits)
out = out.reshape((B, Hq, L, D))
return out
def time_self_attention_primitives(q, k, v):
time_fn(attention, q, k, v)
def time_self_attention_sdpa():
def time_self_attention_sdpa(q, k, v):
time_fn(sdpa, q, k, v)
def time_self_attention_quant_sdpa(q, k, v, bits=4):
time_fn(quant_sdpa, q, k, v, bits)
def time_self_attention_quant_primitives(q, k, v, bits=4):
time_fn(quant_attention, q, k, v)
if __name__ == "__main__":
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
time_fn(sdpa, q, k, v)
bits = 4
k_quant = mx.quantize(k, bits=bits)
v_quant = mx.quantize(v, bits=bits)
mx.eval(k_quant, v_quant)
if __name__ == "__main__":
time_self_attention_sdpa()
time_self_attention_primitives()
time_self_attention_sdpa(q, k, v)
time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
time_self_attention_primitives(q, k, v)
time_self_attention_quant_primitives(q, k_quant, v_quant, bits)

View File

@@ -1,5 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
#include <iostream>
#include <sstream>
#include "mlx/backend/metal/copy.h"

View File

@@ -15,6 +15,7 @@ void CustomKernel::eval_gpu(
std::vector<array> copies;
for (auto& out : outputs) {
// Copy from previous kernel
out.set_data(allocator::malloc_or_wait(out.nbytes()));
if (init_value_) {
copies.emplace_back(init_value_.value(), out.dtype());

View File

@@ -1737,13 +1737,13 @@ template <
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
const device T* w [[buffer(0)]],
device uint8_t* out [[buffer(1)]],
device T* scales [[buffer(2)]],
device T* biases [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
METAL_FUNC void affine_quantize_impl(
const device T* w,
device uint8_t* out,
device T* scales,
device T* biases,
uint2 index,
uint2 grid_dim) {
constexpr T eps = T(1e-7);
constexpr int simd_size = 32;
constexpr int uint8_bits = 8;
@@ -1820,6 +1820,18 @@ template <typename T, const int group_size, const int bits>
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
const device T* w [[buffer(0)]],
device uint8_t* out [[buffer(1)]],
device T* scales [[buffer(2)]],
device T* biases [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
affine_quantize_impl<T, group_size, bits>(
w, out, scales, biases, index, grid_dim);
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize_scales_biases(
const device T* w [[buffer(0)]],
@@ -1883,3 +1895,41 @@ template <typename T, const int group_size, const int bits>
out[oindex + i] = scale * d + bias;
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void kv_update(
const device T* new_keys [[buffer(0)]],
const device T* new_values [[buffer(1)]],
device uint8_t* keys [[buffer(2)]],
device T* key_scales [[buffer(3)]],
device T* key_biases [[buffer(4)]],
device uint8_t* values [[buffer(5)]],
device T* value_scales [[buffer(6)]],
device T* value_biases [[buffer(7)]],
const constant int& offset,
const constant int& batch_stride,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
// Get the right offset in the thing
// Need to use the head dim too
constexpr int pack_factor = 8 / bits;
uint batch_idx = index.y * batch_stride * 4 + offset;
new_keys += index.y * 128;
new_values += index.y * 128;
// uint batch_idx = offset;
// // Index to correct slice
uint group_idx = batch_idx * pack_factor / group_size;
keys += batch_idx;
key_scales += group_idx;
key_biases += group_idx;
values += batch_idx;
value_scales += group_idx;
value_biases += group_idx;
uint2 new_index = {index.x, 0};
affine_quantize_impl<T, group_size, bits>(
new_keys, keys, key_scales, key_biases, new_index, grid_dim);
affine_quantize_impl<T, group_size, bits>(
new_values, values, value_scales, value_biases, new_index, grid_dim);
}

View File

@@ -927,19 +927,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
// SDPA vector instantiations
#define instantiate_sdpa_vector(type, head_dim) \
template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \
[[kernel]] void sdpa_vector<type, head_dim>( \
const device type* queries [[buffer(0)]], \
const device type* keys [[buffer(1)]], \
const device type* values [[buffer(2)]], \
device type* out [[buffer(3)]], \
const constant int& gqa_factor, \
const constant int& N, \
const constant size_t& k_stride, \
const constant float& scale, \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim)
#define instantiate_sdpa_vector_heads(type) \
instantiate_sdpa_vector(type, 64) \
@@ -949,4 +937,30 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t)
instantiate_sdpa_vector_heads(float16_t)
// Quantized SDPA vector instantiations
#define instantiate_quant_sdpa_vector(type, head_dim, group_size, bits) \
instantiate_kernel( \
"quant_sdpa_vector_" #type "_" #head_dim "_" #group_size "_" #bits, \
quant_sdpa_vector, type, head_dim, group_size, bits)
#define instantiate_quant_sdpa_vector_bits(type, heads, group_size) \
instantiate_quant_sdpa_vector(type, heads, group_size, 4) \
instantiate_quant_sdpa_vector(type, heads, group_size, 8)
#define instantiate_quant_sdpa_vector_group_size(type, heads) \
instantiate_quant_sdpa_vector_bits(type, heads, 32) \
instantiate_quant_sdpa_vector_bits(type, heads, 64) \
instantiate_quant_sdpa_vector_bits(type, heads, 128)
#define instantiate_quant_sdpa_vector_heads(type) \
instantiate_quant_sdpa_vector_group_size(type, 64) \
instantiate_quant_sdpa_vector_group_size(type, 96) \
instantiate_quant_sdpa_vector_group_size(type, 128)
instantiate_quant_sdpa_vector_heads(float)
instantiate_quant_sdpa_vector_heads(bfloat16_t)
instantiate_quant_sdpa_vector_heads(float16_t)
// clang-format on

View File

@@ -113,3 +113,205 @@ template <typename T, int D>
}
}
}
template <typename T, typename U, int elem_per_thread, int bits>
METAL_FUNC U load_queries(const device T* queries, thread U* q, U scale) {
U query_sum = 0;
if (bits == 4) {
for (int i = 0; i < elem_per_thread; i += 4) {
q[i] = scale * queries[i];
q[i + 1] = scale * queries[i + 1];
q[i + 2] = scale * queries[i + 2];
q[i + 3] = scale * queries[i + 3];
query_sum += q[i] + q[i + 1] + q[i + 2] + q[i + 3];
q[i + 1] /= 16.0f;
q[i + 2] /= 256.0f;
q[i + 3] /= 4096.0f;
}
} else if (bits == 8) {
for (int i = 0; i < elem_per_thread; i++) {
q[i] = scale * queries[i];
query_sum += q[i];
}
}
return query_sum;
}
template <typename U, int elem_per_thread, int bits>
METAL_FUNC void load_keys(const device uint32_t* keys, thread U* k) {
if (bits == 4) {
auto ks = (const device uint16_t*)keys;
for (int i = 0; i < elem_per_thread / 4; i++) {
k[4 * i] = ks[i] & 0x000f;
k[4 * i + 1] = ks[i] & 0x00f0;
k[4 * i + 2] = ks[i] & 0x0f00;
k[4 * i + 3] = ks[i] & 0xf000;
}
} else if (bits == 8) {
auto ks = (const device uint8_t*)keys;
for (int i = 0; i < elem_per_thread; i++) {
k[i] = ks[i];
}
}
}
template <typename U, int elem_per_thread, int bits>
METAL_FUNC void load_values(
const device uint32_t* values,
thread U* v,
U value_scale,
U value_bias) {
auto vs = (const device uint8_t*)values;
if (bits == 4) {
U s[2] = {value_scale, value_scale / 16.0f};
for (int i = 0; i < elem_per_thread / 2; i++) {
v[2 * i] = s[0] * (vs[i] & 0x0f) + value_bias;
v[2 * i + 1] = s[1] * (vs[i] & 0xf0) + value_bias;
}
} else if (bits == 8) {
for (int i = 0; i < elem_per_thread; i++) {
v[i] = value_scale * vs[i] + value_bias;
}
}
}
template <typename T, int D, int group_size, int bits>
[[kernel]] void quant_sdpa_vector(
const device T* queries [[buffer(0)]],
const device uint32_t* keys [[buffer(1)]],
const device T* key_scales [[buffer(2)]],
const device T* key_biases [[buffer(3)]],
const device uint32_t* values [[buffer(4)]],
const device T* value_scales [[buffer(5)]],
const device T* value_biases [[buffer(6)]],
device T* out [[buffer(7)]],
const constant int& gqa_factor,
const constant int& N,
const constant size_t& k_stride,
const constant size_t& group_stride,
const constant float& scale,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
constexpr int BN = 32;
constexpr int BD = 4;
constexpr int elem_per_thread = D / BD;
constexpr int pack_factor = 32 / bits;
const int stride = BN * D;
typedef float U;
thread U q[elem_per_thread];
thread U k[elem_per_thread];
thread U v[elem_per_thread];
thread U o[elem_per_thread];
threadgroup U outputs[BN * BD];
threadgroup U max_scores[BN];
threadgroup U sum_exp_scores[BN];
// Adjust positions
const int head_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + quad_lid * elem_per_thread;
const int kv_idx = quad_gid * D + quad_lid * elem_per_thread;
const int packed_idx = kv_head_idx * k_stride + kv_idx / pack_factor;
const int group_idx = kv_head_idx * group_stride + kv_idx / group_size;
keys += packed_idx;
key_scales += group_idx;
key_biases += group_idx;
values += packed_idx;
value_scales += group_idx;
value_biases += group_idx;
out += head_idx * D + simd_gid * elem_per_thread;
// Read the query and 0 the output accumulator
U query_sum = load_queries<T, U, elem_per_thread, bits>(
queries, q, static_cast<U>(scale));
for (int i = 0; i < elem_per_thread; i++) {
o[i] = 0;
}
U max_score = -INFINITY;
U sum_exp_score = 0;
// For each key
for (int i = quad_gid; i < N; i += BN) {
load_keys<U, elem_per_thread, bits>(keys, k);
// Assume D % group_size == 0 so all the keys are in the same group
U key_scale = key_scales[0];
U key_bias = key_biases[0];
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = score * key_scale + query_sum * key_bias;
score = quad_sum(score);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
U value_scale = value_scales[0];
U value_bias = value_biases[0];
// Load the values
load_values<U, elem_per_thread, bits>(values, v, value_scale, value_bias);
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * v[i];
}
// Move the pointers to the next kv
keys += stride / pack_factor;
key_scales += stride / group_size;
key_biases += stride / group_size;
values += stride / pack_factor;
value_scales += stride / group_size;
value_biases += stride / group_size;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Each thread has a partial part of the output so we need to combine them.
// First let's communicate the max and sum_exp
// Each quadgroup communicates it's max score
if (quad_lid == 0) {
max_scores[quad_gid] = max_score;
sum_exp_scores[quad_gid] = sum_exp_score;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_score = max_scores[simd_lid];
U new_max = simd_max(max_score);
U factor = fast::exp(max_score - new_max);
sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
// Now we need to aggregate all the outputs
for (int i = 0; i < elem_per_thread; i++) {
// 128 threads with 32 values per thread
outputs[simd_gid * BN + simd_lid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_lid * BD + simd_gid] * factor) / sum_exp_score;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// And write the output
if (simd_lid == 0) {
for (int i = 0; i < elem_per_thread; i++) {
out[i] = static_cast<T>(o[i]);
}
}
}

View File

@@ -14,6 +14,8 @@
#include "mlx/scheduler.h"
#include "mlx/utils.h"
#include <iostream>
namespace mlx::core {
template <typename T>

View File

@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <iostream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
@@ -354,4 +355,80 @@ void fast::AffineQuantize::eval_gpu(
d.add_temporaries(std::move(copies), s.index);
}
void fast::KVUpdate::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& d = metal::device(s.device);
std::vector<array> copies;
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
}
};
// Copy from the inputs into the outputs
const auto& new_keys = ensure_row_contiguous(inputs[0]);
const auto& new_values = ensure_row_contiguous(inputs[1]);
// Copy the input KV cache to the output.
// If the inputs are contiguous, this will be zero-copy.
for (int i = 0; i < 6; i++) {
auto in = ensure_row_contiguous(inputs[i + 2]);
auto out = outputs[i];
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
}
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_input_array(new_keys, 0);
compute_encoder.set_input_array(new_values, 1);
int enc_offset = 2;
for (auto& out : outputs) {
compute_encoder.set_output_array(out, enc_offset);
enc_offset++;
}
int offset = offset_ * inputs[2].strides(-2) * 4;
// std::cout << "offset " << offset << std::endl;
int batch_stride = inputs[2].shape(-1) * inputs[2].shape(-2);
// std::cout << "batch stride " << batch_stride << std::endl;
compute_encoder->setBytes(&offset, sizeof(int), enc_offset);
compute_encoder->setBytes(&batch_stride, sizeof(int), enc_offset + 1);
auto type_string = get_type_string(new_keys.dtype());
// Now launch the kernel
std::ostringstream kname;
kname << "kv_update" << "_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
auto template_def = get_template_definition(
kname.str(), "kv_update", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int per_thread = 8 / bits_;
size_t nrows = new_keys.size() / new_keys.shape(-1);
size_t ncols = new_keys.shape(-1) / per_thread;
size_t nthreads = nrows * ncols;
// std::cout << "nthreads " << nthreads << std::endl;
// std::cout << "nrows " << nrows << std::endl;
// std::cout << "ncols " << ncols << std::endl;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = ncols;
}
auto group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = MTL::Size(ncols, nrows, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core

View File

@@ -185,19 +185,73 @@ void sdpa_vector(
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void quant_sdpa_vector(
const Stream& s,
metal::Device& d,
const array& q,
const array& k,
const array& k_scales,
const array& k_biases,
const array& v,
const array& v_scales,
const array& v_biases,
array& out,
float scale,
int group_size,
int bits) {
// Set the kernel name
std::string kname;
kname.reserve(96);
kname += "quant_sdpa_vector_";
kname += get_type_string(q.dtype());
kname += "_";
kname += std::to_string(q.shape(-1));
kname += "_";
kname += std::to_string(group_size);
kname += "_";
kname += std::to_string(bits);
// Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1);
int N = k.shape(2);
int B = q.shape(0) * q.shape(1);
size_t stride = k.strides()[1];
size_t group_stride = k_scales.strides()[1];
MTL::Size group_dims(128, 1, 1);
MTL::Size grid_dims(1, B, 1);
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname);
compute_encoder->setComputePipelineState(kernel);
// Set its arguments
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(k_scales, 2);
compute_encoder.set_input_array(k_biases, 3);
compute_encoder.set_input_array(v, 4);
compute_encoder.set_input_array(v_scales, 5);
compute_encoder.set_input_array(v_biases, 6);
compute_encoder.set_output_array(out, 7);
compute_encoder->setBytes(&gqa_factor, sizeof(int), 8);
compute_encoder->setBytes(&N, sizeof(int), 9);
compute_encoder->setBytes(&stride, sizeof(size_t), 10);
compute_encoder->setBytes(&group_stride, sizeof(size_t), 11);
compute_encoder->setBytes(&scale, sizeof(float), 12);
// Launch
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
} // namespace
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
assert(inputs.size() == 3);
auto& s = stream();
auto& d = metal::device(s.device);
auto& q_pre = inputs[0];
auto& k_pre = inputs[1];
auto& v_pre = inputs[2];
auto& o = out;
std::vector<array> copies;
@@ -236,11 +290,25 @@ void ScaledDotProductAttention::eval_gpu(
return strides[3] == 1 && strides[2] == shape[3];
};
// We are in vector mode ie single query
if (q_pre.shape(2) == 1) {
if (quantized_) {
auto& q_pre = inputs[0];
auto& k_pre = inputs[1];
auto& k_scales_pre = inputs[2];
auto& k_biases_pre = inputs[3];
auto& v_pre = inputs[4];
auto& v_scales_pre = inputs[5];
auto& v_biases_pre = inputs[6];
// Quantized should only be routed here for single queries
assert(q_pre.shape(2) == 1);
auto q = copy_unless(is_contiguous, q_pre);
auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
auto k_scales = copy_unless(is_contiguous_except_seq_len, k_scales_pre);
auto k_biases = copy_unless(is_contiguous_except_seq_len, k_biases_pre);
auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
auto v_scales = copy_unless(is_contiguous_except_seq_len, v_scales_pre);
auto v_biases = copy_unless(is_contiguous_except_seq_len, v_biases_pre);
// Donate the query if possible
if (q.is_donatable()) {
@@ -249,17 +317,54 @@ void ScaledDotProductAttention::eval_gpu(
o.set_data(allocator::malloc_or_wait(o.nbytes()));
}
sdpa_vector(s, d, q, k, v, o, scale_);
quant_sdpa_vector(
s,
d,
q,
k,
k_scales,
k_biases,
v,
v_scales,
v_biases,
o,
scale_,
group_size_,
bits_);
}
// Full attention mode
// Non-quantized
else {
auto q = copy_unless(is_matrix_contiguous, q_pre);
auto k = copy_unless(is_matrix_contiguous, k_pre);
auto v = copy_unless(is_matrix_contiguous, v_pre);
o.set_data(allocator::malloc_or_wait(o.nbytes()));
assert(inputs.size() == 3);
auto& q_pre = inputs[0];
auto& k_pre = inputs[1];
auto& v_pre = inputs[2];
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
// We are in vector mode ie single query
if (q_pre.shape(2) == 1) {
auto q = copy_unless(is_contiguous, q_pre);
auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
// Donate the query if possible
if (q.is_donatable()) {
o.move_shared_buffer(q);
} else {
o.set_data(allocator::malloc_or_wait(o.nbytes()));
}
sdpa_vector(s, d, q, k, v, o, scale_);
}
// Full attention mode
else {
auto q = copy_unless(is_matrix_contiguous, q_pre);
auto k = copy_unless(is_matrix_contiguous, k_pre);
auto v = copy_unless(is_matrix_contiguous, v_pre);
o.set_data(allocator::malloc_or_wait(o.nbytes()));
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
}
}
d.add_temporaries(std::move(copies), s.index);

View File

@@ -123,6 +123,7 @@ NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(KVUpdate)
NO_GPU_MULTI(CustomKernel)
} // namespace fast

View File

@@ -648,7 +648,7 @@ array scaled_dot_product_attention(
std::move(out_shape),
final_type,
std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, false),
stream, fallback, scale, /*needs_mask=*/false, /*quantized=*/false),
{q, k, v});
}
@@ -662,7 +662,130 @@ array scaled_dot_product_attention(
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
const ScaledDotProductAttention& a_other =
static_cast<const ScaledDotProductAttention&>(other);
return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_ &&
quantized_ == a_other.quantized_;
}
array quantized_scaled_dot_product_attention(
const array& queries,
const array& keys,
const array& key_scales,
const array& key_biases,
const array& values,
const array& value_scales,
const array& value_biases,
const float scale,
const std::optional<array>& mask,
const int group_size,
const int bits,
StreamOrDevice s) {
int el_per_int = 32 / bits;
int out_dim = values.shape(-1) * el_per_int;
auto n_q_heads = queries.shape(-3);
auto n_kv_heads = keys.shape(-3);
auto out_shape = std::vector<int>(
{queries.shape(0), queries.shape(1), queries.shape(2), out_dim});
auto stream = to_stream(s);
bool needs_mask = mask.has_value();
auto fallback =
[scale, needs_mask, n_q_heads, n_kv_heads, group_size, bits, &s](
const std::vector<array>& inputs) -> std::vector<array> {
int n_repeats = n_q_heads / n_kv_heads;
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
auto k = inputs[1];
auto k_scales = inputs[2];
auto k_biases = inputs[3];
auto v = inputs[4];
auto v_scales = inputs[5];
auto v_biases = inputs[6];
int B = q.shape(0);
int L = q.shape(2);
if (n_repeats > 1) {
q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
k = expand_dims(k, 2, s);
k_scales = expand_dims(k_scales, 2, s);
k_biases = expand_dims(k_biases, 2, s);
v = expand_dims(v, 2, s);
v_scales = expand_dims(v_scales, 2, s);
v_biases = expand_dims(v_biases, 2, s);
}
array scores = quantized_matmul(
q,
k,
k_scales,
k_biases,
/*transpose=*/true,
/*group_size=*/group_size,
/*bits=*/bits,
s);
if (needs_mask) {
scores = add(scores, inputs[7], s);
}
scores = softmax(scores, std::vector<int>{-1}, true, s);
array out = quantized_matmul(
scores,
v,
v_scales,
v_biases,
/*transpose=*/false,
/*group_size=*/group_size,
/*bits=*/bits,
s);
if (n_repeats > 1) {
out = reshape(out, {B, n_q_heads, L, -1}, s);
}
return std::vector<array>{out};
};
int L = queries.shape(2);
if (L > 1) {
if (needs_mask) {
return fallback(
{queries,
keys,
key_scales,
key_biases,
values,
value_scales,
value_biases,
mask.value()})[0];
} else {
return fallback(
{queries,
keys,
key_scales,
key_biases,
values,
value_scales,
value_biases})[0];
}
} else {
return array(
std::move(out_shape),
queries.dtype(),
std::make_shared<ScaledDotProductAttention>(
stream,
fallback,
scale,
/*needs_mask=*/false,
/*quantized=*/true,
group_size,
bits),
{queries,
keys,
key_scales,
key_biases,
values,
value_scales,
value_biases});
}
}
array pack_and_quantize(
@@ -907,6 +1030,51 @@ array affine_dequantize(
return fallback({w, scales, biases})[0];
}
std::vector<array> kv_update(
const array& new_keys,
const array& new_values,
const array& keys,
const array& key_scales,
const array& key_biases,
const array& values,
const array& value_scales,
const array& value_biases,
int offset,
int group_size,
int bits,
StreamOrDevice s_) {
auto s = to_stream(s_);
int el_per_int = 32 / bits;
auto out_shape = keys.shape();
out_shape.back() = keys.shape(-1) / el_per_int;
auto fallback = [](const std::vector<array>& inputs) -> std::vector<array> {
return {inputs[0], inputs[1]};
};
return array::make_arrays(
{keys.shape(),
key_scales.shape(),
key_biases.shape(),
values.shape(),
value_scales.shape(),
value_biases.shape()},
{keys.dtype(),
key_scales.dtype(),
key_biases.dtype(),
values.dtype(),
value_scales.dtype(),
value_biases.dtype()},
std::make_shared<KVUpdate>(s, fallback, offset, group_size, bits),
{new_keys,
new_values,
keys,
key_scales,
key_biases,
values,
value_scales,
value_biases});
}
std::string write_signature(
std::string func_name,
const std::string& header,

View File

@@ -41,6 +41,21 @@ array scaled_dot_product_attention(
const std::optional<int> memory_efficient_threshold = std::nullopt,
StreamOrDevice s = {});
/** Computes: `O = softmax(Q @ K.T) @ V` where K and V are quantized. **/
array quantized_scaled_dot_product_attention(
const array& queries,
const array& keys,
const array& key_scales,
const array& key_biases,
const array& values,
const array& value_scales,
const array& value_biases,
const float scale,
const std::optional<array>& mask = std::nullopt,
const int group_size = 64,
const int bits = 4,
StreamOrDevice s = {});
std::tuple<array, array, array> affine_quantize(
const array& w,
int group_size = 64,
@@ -63,6 +78,20 @@ array affine_dequantize(
int bits = 4,
StreamOrDevice s = {});
std::vector<array> kv_update(
const array& new_keys,
const array& new_values,
const array& keys,
const array& key_scales,
const array& key_biases,
const array& values,
const array& value_scales,
const array& value_biases,
int offset,
int group_size = 64,
int bits = 4,
StreamOrDevice s = {});
typedef std::variant<int, bool, Dtype> TemplateArg;
typedef std::function<std::vector<array>(

View File

@@ -190,8 +190,16 @@ class ScaledDotProductAttention : public Custom {
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
const float scale,
const bool needs_mask)
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask) {}
const bool needs_mask,
const bool quantized,
const int group_size = 64,
const int bits = 4)
: Custom(stream, fallback),
scale_(scale),
needs_mask_(needs_mask),
quantized_(quantized),
group_size_(group_size),
bits_(bits) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
@@ -212,6 +220,9 @@ class ScaledDotProductAttention : public Custom {
std::function<std::vector<array>(std::vector<array>)> fallback_;
float scale_;
bool needs_mask_;
bool quantized_;
int group_size_;
int bits_;
};
class AffineQuantize : public Custom {
@@ -244,6 +255,36 @@ class AffineQuantize : public Custom {
bool dequantize_;
};
class KVUpdate : public Custom {
public:
explicit KVUpdate(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
int offset,
int group_size,
int bits)
: Custom(stream, fallback),
offset_(offset),
group_size_(group_size),
bits_(bits) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("NYI");
}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(KVUpdate);
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
int offset_;
int group_size_;
int bits_;
};
struct CustomKernelShapeInfo {
bool shape = false;
bool strides = false;

View File

@@ -150,6 +150,45 @@ void init_fast(nb::module_& parent_module) {
array: The output array.
)pbdoc");
m.def(
"quantized_scaled_dot_product_attention",
&fast::quantized_scaled_dot_product_attention,
"q"_a,
"k"_a,
"k_scales"_a,
"k_biases"_a,
"v"_a,
"v_scales"_a,
"v_biases"_a,
nb::kw_only(),
"scale"_a,
"mask"_a = nb::none(),
"group_size"_a = 64,
"bits"_a = 4,
"stream"_a = nb::none(),
nb::sig(
"def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: array, v: array, v_scales: array, v_biases: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A fast implementation of multi-head attention where the keys and values are quantized.
see :func:`scaled_dot_product_attention` for more details.
Args:
q (array): Input query array.
k (array): Input keys array.
k_scales (array): Scales for the quantized keys array.
k_biases (array): Biases for the quantized keys array.
v (array): Input values array.
v_scales (array): Scales for the quantized values array.
v_biases (array): Biases for the quantized values array.
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
mask (array, optional): An additive mask to apply to the query-key scores.
group_size (int): The group size used in the KV quantization.
bits (int): The bits used in the KV quantization.
Returns:
array: The output array.
)pbdoc");
m.def(
"affine_quantize",
nb::overload_cast<
@@ -193,6 +232,44 @@ void init_fast(nb::module_& parent_module) {
array: The quantized version of ``w``
)pbdoc");
m.def(
"quantized_kv_update",
&fast::kv_update,
"new_keys"_a,
"new_values"_a,
"keys"_a,
"key_scales"_a,
"key_biases"_a,
"values"_a,
"value_scales"_a,
"value_biases"_a,
"offset"_a = 64,
"group_size"_a = 64,
"bits"_a = 4,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def quantized_kv_update(new_keys: array, new_values: array, key_scales: array, key_biases: array, values: array, value_scales: array, value_biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Fused update for a quantized KV cache.
.. math::
w_i = s (\hat{w_i} + \beta)
Args:
w (array): Matrix to be quantize
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
group_size (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
Returns:
array: The quantized version of ``w``
)pbdoc");
m.def(
"metal_kernel",
[](const std::string& name,