This commit is contained in:
Alex Barron
2024-10-22 16:14:29 -07:00
parent 8e88e30d95
commit 5824626c0b
8 changed files with 532 additions and 62 deletions

View File

@@ -150,6 +150,49 @@ void init_fast(nb::module_& parent_module) {
array: The output array.
)pbdoc");
m.def(
"quantized_scaled_dot_product_attention",
&fast::quantized_scaled_dot_product_attention,
"q"_a,
"k"_a,
"k_scales"_a,
"k_biases"_a,
"v"_a,
"v_scales"_a,
"v_biases"_a,
nb::kw_only(),
"scale"_a,
"mask"_a = nb::none(),
"group_size"_a = 64,
"bits"_a = 4,
"stream"_a = nb::none(),
nb::sig(
"def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: array, v: array, v_scales: array, v_biases: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
Supports:
* `Multi-Head Attention <https://arxiv.org/abs/1706.03762>`_
* `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_
* `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_
Note: The softmax operation is performed in ``float32`` regardless of
the input precision.
Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
and ``v`` inputs should not be pre-tiled to match ``q``.
Args:
q (array): Input query array.
k (array): Input keys array.
v (array): Input values array.
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
mask (array, optional): An additive mask to apply to the query-key scores.
Returns:
array: The output array.
)pbdoc");
m.def(
"affine_quantize",
nb::overload_cast<