mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-28 03:41:14 +08:00
clean
This commit is contained in:
parent
6649244686
commit
852336b8a2
@ -1,18 +1,15 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from mlx.utils import tree_map
|
||||
from time_utils import time_fn
|
||||
|
||||
L = 16
|
||||
L = 65536
|
||||
H = 32
|
||||
H_k = 32 // 4
|
||||
D = 128
|
||||
|
||||
|
||||
def attention(q, k, v):
|
||||
k = mx.quantize(k)
|
||||
v = mx.quantize(v)
|
||||
k = mx.dequantize(*k)
|
||||
v = mx.dequantize(*v)
|
||||
B, Hq, L, D = q.shape
|
||||
_, Hk, S, _ = k.shape
|
||||
q = q.reshape(B, Hk, Hq // Hk, L, D)
|
||||
@ -25,21 +22,31 @@ def attention(q, k, v):
|
||||
|
||||
|
||||
def sdpa(q, k, v):
|
||||
k = mx.quantize(k, bits=8)
|
||||
v = mx.quantize(v, bits=8)
|
||||
k = mx.dequantize(*k, bits=8)
|
||||
v = mx.dequantize(*v, bits=8)
|
||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
|
||||
|
||||
|
||||
def quant_sdpa(q, k, v):
|
||||
k = mx.quantize(k, bits=8)
|
||||
v = mx.quantize(v, bits=8)
|
||||
return mx.fast.quantized_scaled_dot_product_attention(
|
||||
q, *k, *v, scale=1.0, mask=None, bits=8
|
||||
)
|
||||
|
||||
|
||||
def quant_attention(q, k, v):
|
||||
B, Hq, L, D = q.shape
|
||||
Hk = k[0].shape[1]
|
||||
|
||||
q = q.reshape((B, Hk, Hq // Hk, L, D))
|
||||
k = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
|
||||
v = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
|
||||
|
||||
scores = mx.quantized_matmul(q, *k, transpose=True)
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
|
||||
out = mx.quantized_matmul(scores, *v, transpose=False)
|
||||
out = out.reshape((B, Hq, L, D))
|
||||
return out
|
||||
|
||||
|
||||
def time_self_attention_primitives(q, k, v):
|
||||
time_fn(attention, q, k, v)
|
||||
|
||||
@ -52,34 +59,22 @@ def time_self_attention_quant_sdpa(q, k, v):
|
||||
time_fn(quant_sdpa, q, k, v)
|
||||
|
||||
|
||||
def time_self_attention_quant_primitives(q, k, v):
|
||||
time_fn(quant_attention, q, k, v)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mx.random.seed(3)
|
||||
# q = mx.random.uniform(shape=(1, H, 1, D))
|
||||
# k = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
# v = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
q = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/queries.npy"))
|
||||
k = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/keys.npy"))
|
||||
v = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/values.npy"))
|
||||
print(q.dtype)
|
||||
print(q.shape, k.shape, v.shape)
|
||||
q = mx.random.uniform(shape=(1, H, 1, D))
|
||||
k = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
v = mx.random.uniform(shape=(1, H_k, L, D))
|
||||
mx.eval(q, k, v)
|
||||
|
||||
k_quant = mx.quantize(k)
|
||||
v_quant = mx.quantize(v)
|
||||
mx.eval(k_quant, v_quant)
|
||||
|
||||
# time_self_attention_sdpa(q, k, v)
|
||||
# time_self_attention_quant_sdpa(q, k_quant, v_quant)
|
||||
# time_self_attention_primitives(q, k, v)
|
||||
q_sdpa = quant_sdpa(q, k, v)
|
||||
print(q_sdpa)
|
||||
# o_attention = attention(q, k, v)
|
||||
# print(o_attention)
|
||||
# np.testing.assert_allclose(q_sdpa, o_attention, atol=1e-5)
|
||||
o_sdpa = sdpa(q, k, v)
|
||||
print(o_sdpa)
|
||||
np.testing.assert_allclose(q_sdpa, o_sdpa, atol=1e-5)
|
||||
# print(o_sdpa[..., :64])
|
||||
# print()
|
||||
# print(o_attention[..., :64])
|
||||
# np.testing.assert_allclose(o_sdpa, o_attention)
|
||||
time_self_attention_sdpa(q, k, v)
|
||||
time_self_attention_quant_sdpa(q, k_quant, v_quant)
|
||||
time_self_attention_primitives(q, k, v)
|
||||
time_self_attention_quant_primitives(q, k_quant, v_quant)
|
||||
|
@ -9,8 +9,6 @@
|
||||
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
namespace {
|
||||
|
@ -10,8 +10,6 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/transforms.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
std::vector<array> Custom::vjp(
|
||||
|
@ -169,26 +169,22 @@ void init_fast(nb::module_& parent_module) {
|
||||
nb::sig(
|
||||
"def quantized_scaled_dot_product_attention(q: array, k: array, k_scales: array, k_biases: array, v: array, v_scales: array, v_biases: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
|
||||
A fast implementation of multi-head attention where the keys and values are quantized.
|
||||
|
||||
Supports:
|
||||
|
||||
* `Multi-Head Attention <https://arxiv.org/abs/1706.03762>`_
|
||||
* `Grouped Query Attention <https://arxiv.org/abs/2305.13245>`_
|
||||
* `Multi-Query Attention <https://arxiv.org/abs/1911.02150>`_
|
||||
|
||||
Note: The softmax operation is performed in ``float32`` regardless of
|
||||
the input precision.
|
||||
|
||||
Note: For Grouped Query Attention and Multi-Query Attention, the ``k``
|
||||
and ``v`` inputs should not be pre-tiled to match ``q``.
|
||||
see :func:`scaled_dot_product_attention` for more details.
|
||||
|
||||
Args:
|
||||
q (array): Input query array.
|
||||
k (array): Input keys array.
|
||||
k_scales (array): Scales for the quantized keys array.
|
||||
k_biases (array): Biases for the quantized keys array.
|
||||
v (array): Input values array.
|
||||
v_scales (array): Scales for the quantized values array.
|
||||
v_biases (array): Biases for the quantized values array.
|
||||
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
|
||||
mask (array, optional): An additive mask to apply to the query-key scores.
|
||||
group_size (int): The group size used in the KV quantization.
|
||||
bits (int): The bits used in the KV quantization.
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
|
Loading…
Reference in New Issue
Block a user