From f5b0f1196843008a0bbaf0eba1747527e08530d9 Mon Sep 17 00:00:00 2001
From: Alex Barron
Date: Sat, 26 Oct 2024 00:24:49 -0700
Subject: [PATCH] add fast::quantized_kv_update

---
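Example usage of the new Python binding (a rough sketch: the cache geometry, the
128 head dimension, and the use of mx.quantize to build the initial packed cache
are assumptions for illustration, not part of this patch):

    import mlx.core as mx

    B, n_kv_heads, T, head_dim = 1, 8, 1024, 128
    group_size, bits = 64, 4

    # Packed cache plus per-group scales/biases for keys and values.
    keys, key_scales, key_biases = mx.quantize(
        mx.zeros((B, n_kv_heads, T, head_dim)), group_size, bits)
    values, value_scales, value_biases = mx.quantize(
        mx.zeros((B, n_kv_heads, T, head_dim)), group_size, bits)

    # One new token worth of keys/values per head.
    new_keys = mx.random.normal((B, n_kv_heads, 1, head_dim))
    new_values = mx.random.normal((B, n_kv_heads, 1, head_dim))

    offset = 42  # time step to write into
    (keys, key_scales, key_biases,
     values, value_scales, value_biases) = mx.fast.quantized_kv_update(
        new_keys, new_values,
        keys, key_scales, key_biases,
        values, value_scales, value_biases,
        offset, group_size=group_size, bits=bits)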
 mlx/backend/metal/copy.cpp            |  1 +
 mlx/backend/metal/custom_kernel.cpp   |  1 +
 mlx/backend/metal/kernels/quantized.h | 64 +++++++++++++++++++---
 mlx/backend/metal/primitives.cpp      |  2 +
 mlx/backend/metal/quantized.cpp       | 77 +++++++++++++++++++++++++++
 mlx/backend/no_metal/primitives.cpp   |  1 +
 mlx/fast.cpp                          | 45 ++++++++++++++++
 mlx/fast.h                            | 14 +++++
 mlx/fast_primitives.h                 | 30 +++++++++++
 python/src/fast.cpp                   | 38 +++++++++++++
 10 files changed, 266 insertions(+), 7 deletions(-)

diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp
index 49a09483a..4024e8402 100644
--- a/mlx/backend/metal/copy.cpp
+++ b/mlx/backend/metal/copy.cpp
@@ -1,5 +1,6 @@
 // Copyright © 2023-2024 Apple Inc.
 
+#include <iostream>
 #include <sstream>
 
 #include "mlx/backend/metal/copy.h"
diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp
index e2627c87b..d04706e47 100644
--- a/mlx/backend/metal/custom_kernel.cpp
+++ b/mlx/backend/metal/custom_kernel.cpp
@@ -15,6 +15,7 @@ void CustomKernel::eval_gpu(
   std::vector<array> copies;
 
   for (auto& out : outputs) {
+    // Copy from previous kernel
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
     if (init_value_) {
       copies.emplace_back(init_value_.value(), out.dtype());
diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index e8f1c18a2..717a3198a 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -1737,13 +1737,13 @@ template <typename T, const int group_size, const int bits>
 }
 
 template <typename T, const int group_size, const int bits>
-[[kernel]] void affine_quantize(
-    const device T* w [[buffer(0)]],
-    device uint8_t* out [[buffer(1)]],
-    device T* scales [[buffer(2)]],
-    device T* biases [[buffer(3)]],
-    uint2 index [[thread_position_in_grid]],
-    uint2 grid_dim [[threads_per_grid]]) {
+METAL_FUNC void affine_quantize_impl(
+    const device T* w,
+    device uint8_t* out,
+    device T* scales,
+    device T* biases,
+    uint2 index,
+    uint2 grid_dim) {
   constexpr T eps = T(1e-7);
   constexpr int simd_size = 32;
   constexpr int uint8_bits = 8;
@@ -1820,6 +1820,18 @@ template <typename T, const int group_size, const int bits>
   }
 }
 
+template <typename T, const int group_size, const int bits>
+[[kernel]] void affine_quantize(
+    const device T* w [[buffer(0)]],
+    device uint8_t* out [[buffer(1)]],
+    device T* scales [[buffer(2)]],
+    device T* biases [[buffer(3)]],
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  affine_quantize_impl<T, group_size, bits>(
+      w, out, scales, biases, index, grid_dim);
+}
+
 template <typename T, const int group_size, const int bits>
 [[kernel]] void affine_quantize_scales_biases(
     const device T* w [[buffer(0)]],
@@ -1883,3 +1895,41 @@ template <typename T, const int group_size, const int bits>
     out[oindex + i] = scale * d + bias;
   }
 }
+
+template <typename T, const int group_size, const int bits>
+[[kernel]] void kv_update(
+    const device T* new_keys [[buffer(0)]],
+    const device T* new_values [[buffer(1)]],
+    device uint8_t* keys [[buffer(2)]],
+    device T* key_scales [[buffer(3)]],
+    device T* key_biases [[buffer(4)]],
+    device uint8_t* values [[buffer(5)]],
+    device T* value_scales [[buffer(6)]],
+    device T* value_biases [[buffer(7)]],
+    const constant int& offset,
+    const constant int& batch_stride,
+    uint2 index [[thread_position_in_grid]],
+    uint2 grid_dim [[threads_per_grid]]) {
+  // batch_idx is the byte offset of this row's insertion slot in the packed cache.
+  // TODO: the head dim is hard-coded to 128 elements per row for now.
+  constexpr int pack_factor = 8 / bits;
+  uint batch_idx = index.y * batch_stride * 4 + offset;
+  new_keys += index.y * 128;
+  new_values += index.y * 128;
+  // uint batch_idx = offset;
+  // // Index to correct slice
+  uint group_idx = batch_idx * pack_factor / group_size;
+  keys += batch_idx;
+  key_scales += group_idx;
+  key_biases += group_idx;
+  values += batch_idx;
+  value_scales += group_idx;
+  value_biases += group_idx;
+
+  uint2 new_index = {index.x, 0};
+
+  affine_quantize_impl<T, group_size, bits>(
+      new_keys, keys, key_scales, key_biases, new_index, grid_dim);
+  affine_quantize_impl<T, group_size, bits>(
+      new_values, values, value_scales, value_biases, new_index, grid_dim);
+}
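As a point of reference, the kernel above is roughly the fused form of the
following unfused update, shown as a Python sketch (it reuses the names from
the usage example in the notes above; the slice-assignment form and the use of
mx.quantize are illustrative assumptions, and the fused op returns updated
arrays rather than writing in place):

    import mlx.core as mx

    # Quantize the new entries and write them into the cache slot at `offset`.
    # The values side is handled identically.
    kq, ks, kb = mx.quantize(new_keys, group_size, bits)
    keys[:, :, offset : offset + 1, :] = kq
    key_scales[:, :, offset : offset + 1, :] = ks
    key_biases[:, :, offset : offset + 1, :] = kb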
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index e5a7d885b..a536997fd 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -14,6 +14,8 @@
 #include "mlx/scheduler.h"
 #include "mlx/utils.h"
 
+#include <iostream>
+
 namespace mlx::core {
 
 template <typename T>
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index 30828da70..6f8c104d3 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -1,6 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.
 
 #include <cassert>
+#include <iostream>
 
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/copy.h"
@@ -354,4 +355,80 @@ void fast::AffineQuantize::eval_gpu(
   d.add_temporaries(std::move(copies), s.index);
 }
 
+void fast::KVUpdate::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  std::vector<array> copies;
+  auto ensure_row_contiguous = [&copies, &s](const array& arr) {
+    if (arr.flags().row_contiguous) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy_gpu(arr, arr_copy, CopyType::General, s);
+      copies.push_back(arr_copy);
+      return arr_copy;
+    }
+  };
+
+  // Make sure the new keys/values are row contiguous before quantizing them
+  const auto& new_keys = ensure_row_contiguous(inputs[0]);
+  const auto& new_values = ensure_row_contiguous(inputs[1]);
+
+  // Copy the input KV cache to the output.
+  // If the inputs are contiguous, this will be zero-copy.
+  for (int i = 0; i < 6; i++) {
+    auto in = ensure_row_contiguous(inputs[i + 2]);
+    auto out = outputs[i];
+    auto ctype = in.flags().contiguous && in.size() == in.data_size()
+        ? CopyType::Vector
+        : CopyType::General;
+    copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
+  }
+
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder.set_input_array(new_keys, 0);
+  compute_encoder.set_input_array(new_values, 1);
+  int enc_offset = 2;
+  for (auto& out : outputs) {
+    compute_encoder.set_output_array(out, enc_offset);
+    enc_offset++;
+  }
+  int offset = offset_ * inputs[2].strides(-2) * 4;
+  // std::cout << "offset " << offset << std::endl;
+  int batch_stride = inputs[2].shape(-1) * inputs[2].shape(-2);
+  // std::cout << "batch stride " << batch_stride << std::endl;
+  compute_encoder->setBytes(&offset, sizeof(int), enc_offset);
+  compute_encoder->setBytes(&batch_stride, sizeof(int), enc_offset + 1);
+
+  auto type_string = get_type_string(new_keys.dtype());
+  // Now launch the kernel
+  std::ostringstream kname;
+  kname << "kv_update" << "_" << type_string << "_gs_" << group_size_ << "_b_"
+        << bits_;
+  auto template_def = get_template_definition(
+      kname.str(), "kv_update", type_string, group_size_, bits_);
+  auto kernel = get_quantized_kernel(d, kname.str(), template_def);
+  compute_encoder->setComputePipelineState(kernel);
+
+  int per_thread = 8 / bits_;
+  size_t nrows = new_keys.size() / new_keys.shape(-1);
+  size_t ncols = new_keys.shape(-1) / per_thread;
+  size_t nthreads = nrows * ncols;
+  // std::cout << "nthreads " << nthreads << std::endl;
+  // std::cout << "nrows " << nrows << std::endl;
+  // std::cout << "ncols " << ncols << std::endl;
+  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+  if (thread_group_size > nthreads) {
+    thread_group_size = ncols;
+  }
+  auto group_dims = MTL::Size(thread_group_size, 1, 1);
+  MTL::Size grid_dims = MTL::Size(ncols, nrows, 1);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
+
+  d.add_temporaries(std::move(copies), s.index);
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index aaee51d83..d3688f3f1 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -123,6 +123,7 @@ NO_GPU_MULTI(RMSNormVJP)
 NO_GPU_MULTI(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
+NO_GPU_MULTI(KVUpdate)
 NO_GPU_MULTI(CustomKernel)
 
 } // namespace fast
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index cdc594bea..d64e44c80 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -1030,6 +1030,51 @@ array affine_dequantize(
   return fallback({w, scales, biases})[0];
 }
 
+std::vector<array> kv_update(
+    const array& new_keys,
+    const array& new_values,
+    const array& keys,
+    const array& key_scales,
+    const array& key_biases,
+    const array& values,
+    const array& value_scales,
+    const array& value_biases,
+    int offset,
+    int group_size,
+    int bits,
+    StreamOrDevice s_) {
+  auto s = to_stream(s_);
+
+  int el_per_int = 32 / bits;
+  auto out_shape = keys.shape();
+  out_shape.back() = keys.shape(-1) / el_per_int;
+  auto fallback = [](const std::vector<array>& inputs) -> std::vector<array> {
+    return {inputs[0], inputs[1]};
+  };
+  return array::make_arrays(
+      {keys.shape(),
+       key_scales.shape(),
+       key_biases.shape(),
+       values.shape(),
+       value_scales.shape(),
+       value_biases.shape()},
+      {keys.dtype(),
+       key_scales.dtype(),
+       key_biases.dtype(),
+       values.dtype(),
+       value_scales.dtype(),
+       value_biases.dtype()},
+      std::make_shared<KVUpdate>(s, fallback, offset, group_size, bits),
+      {new_keys,
+       new_values,
+       keys,
+       key_scales,
+       key_biases,
+       values,
+       value_scales,
+       value_biases});
+}
+
 std::string write_signature(
     std::string func_name,
     const std::string& header,
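For reference, the packed cache geometry implied by the el_per_int computation
above, as a small Python sketch (head_dim = 128, group_size = 64, and bits = 4
are assumed, matching the kernel's hard-coded row length and the defaults; the
variable names are illustrative only):

    head_dim, group_size, bits = 128, 64, 4
    el_per_int = 32 // bits                   # 8 elements packed per uint32
    packed_dim = head_dim // el_per_int       # 128 // 8  = 16 packed words per row
    groups_per_row = head_dim // group_size   # 128 // 64 = 2 scales and biases per row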
diff --git a/mlx/fast.h b/mlx/fast.h
index 987aa8ce8..edd4fa14e 100644
--- a/mlx/fast.h
+++ b/mlx/fast.h
@@ -78,6 +78,20 @@ array affine_dequantize(
     int bits = 4,
     StreamOrDevice s = {});
 
+std::vector<array> kv_update(
+    const array& new_keys,
+    const array& new_values,
+    const array& keys,
+    const array& key_scales,
+    const array& key_biases,
+    const array& values,
+    const array& value_scales,
+    const array& value_biases,
+    int offset,
+    int group_size = 64,
+    int bits = 4,
+    StreamOrDevice s = {});
+
 typedef std::variant<int, bool, Dtype> TemplateArg;
 
 typedef std::function<std::vector<array>(
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index cb79aee31..ba87fcfd0 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -255,6 +255,36 @@
   bool dequantize_;
 };
 
+class KVUpdate : public Custom {
+ public:
+  explicit KVUpdate(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      int offset,
+      int group_size,
+      int bits)
+      : Custom(stream, fallback),
+        offset_(offset),
+        group_size_(group_size),
+        bits_(bits) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override {
+    throw std::runtime_error("NYI");
+  }
+
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_PRINT(KVUpdate);
+
+ private:
+  std::function<std::vector<array>(std::vector<array>)> fallback_;
+  int offset_;
+  int group_size_;
+  int bits_;
+};
+
 struct CustomKernelShapeInfo {
   bool shape = false;
   bool strides = false;
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index 0b3947567..a60420460 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -232,6 +232,44 @@
           array: The quantized version of ``w``
       )pbdoc");
 
+  m.def(
+      "quantized_kv_update",
+      &fast::kv_update,
+      "new_keys"_a,
+      "new_values"_a,
+      "keys"_a,
+      "key_scales"_a,
+      "key_biases"_a,
+      "values"_a,
+      "value_scales"_a,
+      "value_biases"_a,
+      "offset"_a = 64,
+      "group_size"_a = 64,
+      "bits"_a = 4,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def quantized_kv_update(new_keys: array, new_values: array, keys: array, key_scales: array, key_biases: array, values: array, value_scales: array, value_biases: array, offset: int = 64, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> List[array]"),
+      R"pbdoc(
+        Fused update for a quantized KV cache: quantize ``new_keys`` and ``new_values`` and write them into the cache at position ``offset``.
+
+        Args:
+          new_keys (array): The new keys to quantize and insert.
+          new_values (array): The new values to quantize and insert.
+          keys (array): The packed, quantized keys of the cache.
+          key_scales (array): The scales to use per ``group_size`` elements of ``keys``.
+          key_biases (array): The biases to use per ``group_size`` elements of ``keys``.
+          values (array): The packed, quantized values of the cache.
+          value_scales (array): The scales to use per ``group_size`` elements of ``values``.
+          value_biases (array): The biases to use per ``group_size`` elements of ``values``.
+          offset (int): The position in the cache at which to write the new entries.
+          group_size (int, optional): The group size used to quantize the cache. (default: ``64``)
+          bits (int, optional): The number of bits used to quantize the cache. (default: ``4``)
+
+        Returns:
+          list(array): The updated ``keys``, ``key_scales``, ``key_biases``, ``values``, ``value_scales``, and ``value_biases``.
+      )pbdoc");
+
   m.def(
       "metal_kernel",
       [](const std::string& name,