import mlx.core as mx
import numpy as np
from mlx.utils import tree_map

from time_utils import time_fn

# Benchmark configuration: one decode step (query length 1) attending to a long
# K/V cache, with grouped-query attention (32 query heads sharing 8 K/V heads).
L = 65536  # K/V sequence length (context size)
H = 32  # number of query heads
H_k = 32 // 4  # number of K/V heads (4 query heads per K/V head)
D = 128  # head dimension


def attention(q, k, v):
    # Reference attention built from unfused ops. The query heads are reshaped
    # into (K/V head, group) so the grouped-query structure broadcasts against
    # the expanded K and V.
    B, Hq, L, D = q.shape
    _, Hk, S, _ = k.shape
    q = q.reshape(B, Hk, Hq // Hk, L, D)
    k = k[:, :, None, :, :]
    v = v[:, :, None, :, :]
    s = q @ k.transpose(0, 1, 2, 4, 3)
    # Softmax in float32 for stability, then cast back to the score dtype.
    p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
    o = p @ v
    return o.reshape(B, Hq, L, D)
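

# Shape walk-through for the inputs built in __main__ (B=1, Hq=32, Hk=8, D=128,
# query length 1, context S=65536); note the local L above is the query length,
# not the module-level constant L:
#   q: (1, 32, 1, 128)    -> reshape             -> (1, 8, 4, 1, 128)
#   k: (1, 8, 65536, 128) -> k[:, :, None, :, :] -> (1, 8, 1, 65536, 128)
#   s = q @ k^T           -> (1, 8, 4, 1, 65536)
#   o = p @ v             -> (1, 8, 4, 1, 128)   -> reshape -> (1, 32, 1, 128)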


def sdpa(q, k, v):
    # Fused attention kernel; handles the grouped-query layout directly.
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)


def quant_sdpa(q, k, v):
    # Fused kernel on quantized K/V; k and v are the (weights, scales, biases)
    # tuples returned by mx.quantize, hence the unpacking.
    return mx.fast.quantized_scaled_dot_product_attention(
        q, *k, *v, scale=1.0, mask=None, bits=8
    )
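

# For reference (a sketch of mx.quantize's packed layout, assuming the default
# group size of 64): quantizing an array of shape (1, H_k, S, 128) at 8 bits
# gives a tuple whose first element quant_attention indexes via k[0]:
#   weights: (1, H_k, S, 32)  uint32, four 8-bit values packed per element
#   scales:  (1, H_k, S, 2)   one scale per group of 64 along the last axis
#   biases:  (1, H_k, S, 2)   one bias per group of 64 along the last axis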


def quant_attention(q, k, v):
    # Unfused quantized attention; k and v are (weights, scales, biases) tuples
    # and the matmuls against them use mx.quantized_matmul.
    B, Hq, L, D = q.shape
    Hk = k[0].shape[1]

    q = q.reshape((B, Hk, Hq // Hk, L, D))
    k = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
    v = tree_map(lambda x: mx.expand_dims(x, axis=2), v)

    # bits=8 matches the quantization used in __main__ and in quant_sdpa.
    scores = mx.quantized_matmul(q, *k, transpose=True, bits=8)
    scores = mx.softmax(scores, axis=-1)

    out = mx.quantized_matmul(scores, *v, transpose=False, bits=8)
    out = out.reshape((B, Hq, L, D))
    return out
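

# A minimal sanity check (a sketch, not part of the original benchmark): with
# the same 8-bit quantized K/V, the fused and unfused quantized paths should
# produce closely matching outputs. The tolerance here is a guess.
def check_quant_paths_agree(atol=1e-2):
    q = mx.random.uniform(shape=(1, H, 1, D))
    k = mx.random.uniform(shape=(1, H_k, L, D))
    v = mx.random.uniform(shape=(1, H_k, L, D))
    k_q = mx.quantize(k, bits=8)
    v_q = mx.quantize(v, bits=8)
    fused = quant_sdpa(q, k_q, v_q)
    unfused = quant_attention(q, k_q, v_q)
    return mx.allclose(fused, unfused, atol=atol).item()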


def time_self_attention_primitives(q, k, v):
    time_fn(attention, q, k, v)


def time_self_attention_sdpa(q, k, v):
    time_fn(sdpa, q, k, v)


def time_self_attention_quant_sdpa(q, k, v):
    time_fn(quant_sdpa, q, k, v)


def time_self_attention_quant_primitives(q, k, v):
    time_fn(quant_attention, q, k, v)


if __name__ == "__main__":
    mx.random.seed(3)
    # A single-token query attending to an L-token K/V cache (one decode step).
    q = mx.random.uniform(shape=(1, H, 1, D))
    k = mx.random.uniform(shape=(1, H_k, L, D))
    v = mx.random.uniform(shape=(1, H_k, L, D))
    mx.eval(q, k, v)

    # Quantize K and V to 8 bits so they match the bits=8 used by quant_sdpa
    # and quant_attention.
    k_quant = mx.quantize(k, bits=8)
    v_quant = mx.quantize(v, bits=8)
    mx.eval(k_quant, v_quant)

    time_self_attention_sdpa(q, k, v)
    time_self_attention_quant_sdpa(q, k_quant, v_quant)
    time_self_attention_primitives(q, k, v)
    time_self_attention_quant_primitives(q, k_quant, v_quant)