add fast::quantized_kv_update

This commit is contained in:
Alex Barron
2024-10-26 00:24:49 -07:00
parent b509c2ad76
commit f5b0f11968
10 changed files with 266 additions and 7 deletions

View File

@@ -232,6 +232,44 @@ void init_fast(nb::module_& parent_module) {
array: The quantized version of ``w``
)pbdoc");
m.def(
"quantized_kv_update",
&fast::kv_update,
"new_keys"_a,
"new_values"_a,
"keys"_a,
"key_scales"_a,
"key_biases"_a,
"values"_a,
"value_scales"_a,
"value_biases"_a,
"offset"_a = 64,
"group_size"_a = 64,
"bits"_a = 4,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def quantized_kv_update(new_keys: array, new_values: array, key_scales: array, key_biases: array, values: array, value_scales: array, value_biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Fused update for a quantized KV cache.
.. math::
w_i = s (\hat{w_i} + \beta)
Args:
w (array): Matrix to be quantize
scales (array): The scales to use per ``group_size`` elements of ``w``
biases (array): The biases to use per ``group_size`` elements of ``w``
group_size (int, optional): The size of the group in ``w`` that shares a
scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
Returns:
array: The quantized version of ``w``
)pbdoc");
m.def(
"metal_kernel",
[](const std::string& name,