mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-23 10:02:12 +08:00
add fast::quantized_kv_update
This commit is contained in:
parent
b509c2ad76
commit
f5b0f11968
@ -1,5 +1,6 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
|
@ -15,6 +15,7 @@ void CustomKernel::eval_gpu(
|
||||
std::vector<array> copies;
|
||||
|
||||
for (auto& out : outputs) {
|
||||
// Copy from previous kernel
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
if (init_value_) {
|
||||
copies.emplace_back(init_value_.value(), out.dtype());
|
||||
|
@ -1737,13 +1737,13 @@ template <
|
||||
}
|
||||
|
||||
template <typename T, const int group_size, const int bits>
|
||||
[[kernel]] void affine_quantize(
|
||||
const device T* w [[buffer(0)]],
|
||||
device uint8_t* out [[buffer(1)]],
|
||||
device T* scales [[buffer(2)]],
|
||||
device T* biases [[buffer(3)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
METAL_FUNC void affine_quantize_impl(
|
||||
const device T* w,
|
||||
device uint8_t* out,
|
||||
device T* scales,
|
||||
device T* biases,
|
||||
uint2 index,
|
||||
uint2 grid_dim) {
|
||||
constexpr T eps = T(1e-7);
|
||||
constexpr int simd_size = 32;
|
||||
constexpr int uint8_bits = 8;
|
||||
@ -1820,6 +1820,18 @@ template <typename T, const int group_size, const int bits>
|
||||
}
|
||||
}
|
||||
|
||||
// Kernel entry point for affine quantization. Thin wrapper that forwards the
// Metal buffer/thread arguments to affine_quantize_impl so the same
// implementation can be shared with other kernels (e.g. kv_update).
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
    const device T* w [[buffer(0)]],
    device uint8_t* out [[buffer(1)]],
    device T* scales [[buffer(2)]],
    device T* biases [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  affine_quantize_impl<T, group_size, bits>(
      w, out, scales, biases, index, grid_dim);
}
|
||||
|
||||
template <typename T, const int group_size, const int bits>
|
||||
[[kernel]] void affine_quantize_scales_biases(
|
||||
const device T* w [[buffer(0)]],
|
||||
@ -1883,3 +1895,41 @@ template <typename T, const int group_size, const int bits>
|
||||
out[oindex + i] = scale * d + bias;
|
||||
}
|
||||
}
|
||||
|
||||
// Fused quantized KV-cache update: affine-quantize one new row of keys and
// one new row of values per head and write them into the packed
// `keys`/`values` buffers (and their scales/biases) at the position given by
// `offset`.
//
// NOTE(review): the per-head row length is hard-coded to 128 elements below
// (`index.y * 128`) — this assumes head_dim == 128; confirm against the
// host-side dispatch before using other head dims.
template <typename T, const int group_size, const int bits>
[[kernel]] void kv_update(
    const device T* new_keys [[buffer(0)]],
    const device T* new_values [[buffer(1)]],
    device uint8_t* keys [[buffer(2)]],
    device T* key_scales [[buffer(3)]],
    device T* key_biases [[buffer(4)]],
    device uint8_t* values [[buffer(5)]],
    device T* value_scales [[buffer(6)]],
    device T* value_biases [[buffer(7)]],
    const constant int& offset,
    const constant int& batch_stride,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Number of quantized elements packed into each output byte.
  constexpr int pack_factor = 8 / bits;
  // Write position of this head (index.y) in the packed cache plus the
  // host-provided `offset`. NOTE(review): the `* 4` mirrors the host-side
  // scaling of `offset` in KVUpdate::eval_gpu — presumably a 4-byte storage
  // word conversion; keep the two in sync and confirm the intent.
  uint batch_idx = index.y * batch_stride * 4 + offset;
  // Advance the unquantized inputs to this head's row (head_dim == 128).
  new_keys += index.y * 128;
  new_values += index.y * 128;
  // One scale/bias per group_size unpacked elements.
  uint group_idx = batch_idx * pack_factor / group_size;
  keys += batch_idx;
  key_scales += group_idx;
  key_biases += group_idx;
  values += batch_idx;
  value_scales += group_idx;
  value_biases += group_idx;

  // The pointers are already positioned, so quantize as a single row (y = 0).
  uint2 new_index = {index.x, 0};

  affine_quantize_impl<T, group_size, bits>(
      new_keys, keys, key_scales, key_biases, new_index, grid_dim);
  affine_quantize_impl<T, group_size, bits>(
      new_values, values, value_scales, value_biases, new_index, grid_dim);
}
|
||||
|
@ -14,6 +14,8 @@
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
|
@ -1,6 +1,7 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
@ -354,4 +355,80 @@ void fast::AffineQuantize::eval_gpu(
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void fast::KVUpdate::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
std::vector<array> copies;
|
||||
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
|
||||
if (arr.flags().row_contiguous) {
|
||||
return arr;
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
return arr_copy;
|
||||
}
|
||||
};
|
||||
|
||||
// Copy from the inputs into the outputs
|
||||
const auto& new_keys = ensure_row_contiguous(inputs[0]);
|
||||
const auto& new_values = ensure_row_contiguous(inputs[1]);
|
||||
|
||||
// Copy the input KV cache to the output.
|
||||
// If the inputs are contiguous, this will be zero-copy.
|
||||
for (int i = 0; i < 6; i++) {
|
||||
auto in = ensure_row_contiguous(inputs[i + 2]);
|
||||
auto out = outputs[i];
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s);
|
||||
}
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_input_array(new_keys, 0);
|
||||
compute_encoder.set_input_array(new_values, 1);
|
||||
int enc_offset = 2;
|
||||
for (auto& out : outputs) {
|
||||
compute_encoder.set_output_array(out, enc_offset);
|
||||
enc_offset++;
|
||||
}
|
||||
int offset = offset_ * inputs[2].strides(-2) * 4;
|
||||
// std::cout << "offset " << offset << std::endl;
|
||||
int batch_stride = inputs[2].shape(-1) * inputs[2].shape(-2);
|
||||
// std::cout << "batch stride " << batch_stride << std::endl;
|
||||
compute_encoder->setBytes(&offset, sizeof(int), enc_offset);
|
||||
compute_encoder->setBytes(&batch_stride, sizeof(int), enc_offset + 1);
|
||||
|
||||
auto type_string = get_type_string(new_keys.dtype());
|
||||
// Now launch the kernel
|
||||
std::ostringstream kname;
|
||||
kname << "kv_update" << "_" << type_string << "_gs_" << group_size_ << "_b_"
|
||||
<< bits_;
|
||||
auto template_def = get_template_definition(
|
||||
kname.str(), "kv_update", type_string, group_size_, bits_);
|
||||
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int per_thread = 8 / bits_;
|
||||
size_t nrows = new_keys.size() / new_keys.shape(-1);
|
||||
size_t ncols = new_keys.shape(-1) / per_thread;
|
||||
size_t nthreads = nrows * ncols;
|
||||
// std::cout << "nthreads " << nthreads << std::endl;
|
||||
// std::cout << "nrows " << nrows << std::endl;
|
||||
// std::cout << "ncols " << ncols << std::endl;
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = ncols;
|
||||
}
|
||||
auto group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size(ncols, nrows, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -123,6 +123,7 @@ NO_GPU_MULTI(RMSNormVJP)
|
||||
// GPU-less build stubs — presumably each macro defines an eval_gpu that
// throws (NO_GPU_MULTI for multi-output primitives); macro definitions are
// elsewhere in this file — confirm.
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
NO_GPU_MULTI(KVUpdate)
NO_GPU_MULTI(CustomKernel)
|
||||
} // namespace fast
|
||||
|
||||
|
45
mlx/fast.cpp
45
mlx/fast.cpp
@ -1030,6 +1030,51 @@ array affine_dequantize(
|
||||
return fallback({w, scales, biases})[0];
|
||||
}
|
||||
|
||||
std::vector<array> kv_update(
|
||||
const array& new_keys,
|
||||
const array& new_values,
|
||||
const array& keys,
|
||||
const array& key_scales,
|
||||
const array& key_biases,
|
||||
const array& values,
|
||||
const array& value_scales,
|
||||
const array& value_biases,
|
||||
int offset,
|
||||
int group_size,
|
||||
int bits,
|
||||
StreamOrDevice s_) {
|
||||
auto s = to_stream(s_);
|
||||
|
||||
int el_per_int = 32 / bits;
|
||||
auto out_shape = keys.shape();
|
||||
out_shape.back() = keys.shape(-1) / el_per_int;
|
||||
auto fallback = [](const std::vector<array>& inputs) -> std::vector<array> {
|
||||
return {inputs[0], inputs[1]};
|
||||
};
|
||||
return array::make_arrays(
|
||||
{keys.shape(),
|
||||
key_scales.shape(),
|
||||
key_biases.shape(),
|
||||
values.shape(),
|
||||
value_scales.shape(),
|
||||
value_biases.shape()},
|
||||
{keys.dtype(),
|
||||
key_scales.dtype(),
|
||||
key_biases.dtype(),
|
||||
values.dtype(),
|
||||
value_scales.dtype(),
|
||||
value_biases.dtype()},
|
||||
std::make_shared<KVUpdate>(s, fallback, offset, group_size, bits),
|
||||
{new_keys,
|
||||
new_values,
|
||||
keys,
|
||||
key_scales,
|
||||
key_biases,
|
||||
values,
|
||||
value_scales,
|
||||
value_biases});
|
||||
}
|
||||
|
||||
std::string write_signature(
|
||||
std::string func_name,
|
||||
const std::string& header,
|
||||
|
14
mlx/fast.h
14
mlx/fast.h
@ -78,6 +78,20 @@ array affine_dequantize(
|
||||
int bits = 4,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
// Fused update of a quantized KV cache. Quantizes `new_keys`/`new_values`
// using `group_size`/`bits` and writes them into the packed `keys`/`values`
// cache at the position derived from `offset`. Returns the six updated cache
// arrays: {keys, key_scales, key_biases, values, value_scales, value_biases}.
std::vector<array> kv_update(
    const array& new_keys,
    const array& new_values,
    const array& keys,
    const array& key_scales,
    const array& key_biases,
    const array& values,
    const array& value_scales,
    const array& value_biases,
    int offset,
    int group_size = 64,
    int bits = 4,
    StreamOrDevice s = {});
|
||||
|
||||
typedef std::variant<int, bool, Dtype> TemplateArg;
|
||||
|
||||
typedef std::function<std::vector<array>(
|
||||
|
@ -255,6 +255,36 @@ class AffineQuantize : public Custom {
|
||||
bool dequantize_;
|
||||
};
|
||||
|
||||
class KVUpdate : public Custom {
|
||||
public:
|
||||
explicit KVUpdate(
|
||||
Stream stream,
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback,
|
||||
int offset,
|
||||
int group_size,
|
||||
int bits)
|
||||
: Custom(stream, fallback),
|
||||
offset_(offset),
|
||||
group_size_(group_size),
|
||||
bits_(bits) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override {
|
||||
throw std::runtime_error("NYI");
|
||||
}
|
||||
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(KVUpdate);
|
||||
|
||||
private:
|
||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||
int offset_;
|
||||
int group_size_;
|
||||
int bits_;
|
||||
};
|
||||
|
||||
struct CustomKernelShapeInfo {
|
||||
bool shape = false;
|
||||
bool strides = false;
|
||||
|
@ -232,6 +232,44 @@ void init_fast(nb::module_& parent_module) {
|
||||
array: The quantized version of ``w``
|
||||
)pbdoc");
|
||||
|
||||
// Python binding for fast::kv_update. The signature string and docstring now
// describe the actual parameters (the previous versions omitted `keys` and
// `offset` and reused the affine_quantize docstring).
m.def(
    "quantized_kv_update",
    &fast::kv_update,
    "new_keys"_a,
    "new_values"_a,
    "keys"_a,
    "key_scales"_a,
    "key_biases"_a,
    "values"_a,
    "value_scales"_a,
    "value_biases"_a,
    // NOTE(review): a default of 64 for a cache write offset looks like a
    // copy/paste from group_size — confirm whether `offset` should instead
    // be a required argument.
    "offset"_a = 64,
    "group_size"_a = 64,
    "bits"_a = 4,
    nb::kw_only(),
    "stream"_a = nb::none(),
    nb::sig(
        "def quantized_kv_update(new_keys: array, new_values: array, keys: array, key_scales: array, key_biases: array, values: array, value_scales: array, value_biases: array, offset: int = 64, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array, array, array, array]"),
    R"pbdoc(
      Fused update for a quantized KV cache.

      Quantizes ``new_keys`` and ``new_values`` with the given
      ``group_size`` and ``bits`` and writes them into the packed, quantized
      ``keys``/``values`` cache at position ``offset``.

      Args:
          new_keys (array): Unquantized keys to add to the cache.
          new_values (array): Unquantized values to add to the cache.
          keys (array): The packed quantized key cache.
          key_scales (array): Scales for the quantized key cache.
          key_biases (array): Biases for the quantized key cache.
          values (array): The packed quantized value cache.
          value_scales (array): Scales for the quantized value cache.
          value_biases (array): Biases for the quantized value cache.
          offset (int): Position in the cache at which to write.
          group_size (int, optional): The size of the group of elements that
            shares a scale and bias. (default: ``64``)
          bits (int, optional): The number of bits occupied by each quantized
            element. (default: ``4``)

      Returns:
          tuple(array): The updated ``(keys, key_scales, key_biases, values,
          value_scales, value_biases)``.
      )pbdoc");
|
||||
|
||||
m.def(
|
||||
"metal_kernel",
|
||||
[](const std::string& name,
|
||||
|
Loading…
Reference in New Issue
Block a user