mirror of https://github.com/ml-explore/mlx.git
start
@@ -1,16 +1,18 @@
 import argparse
 import math

 import mlx.core as mx
 import numpy as np
 from time_utils import time_fn

-L = 1024
+L = 30000
 H = 32
 H_k = 32 // 4
 D = 128


 def attention(q, k, v):
+    k = mx.quantize(k)
+    v = mx.quantize(v)
+    k = mx.dequantize(*k)
+    v = mx.dequantize(*v)
     B, Hq, L, D = q.shape
     _, Hk, S, _ = k.shape
     q = q.reshape(B, Hk, Hq // Hk, L, D)
@@ -23,27 +25,54 @@ def attention(q, k, v):


 def sdpa(q, k, v):
+    k = mx.quantize(k)
+    v = mx.quantize(v)
+    k = mx.dequantize(*k)
+    v = mx.dequantize(*v)
     return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)


-def time_self_attention_primitives():
-    mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D))
-    k = mx.random.uniform(shape=(1, H_k, L, D))
-    v = mx.random.uniform(shape=(1, H_k, L, D))
-    mx.eval(q, k, v)
+def quant_sdpa(q, k, v):
+    k = mx.quantize(k)
+    v = mx.quantize(v)
+    return mx.fast.quantized_scaled_dot_product_attention(q, *k, *v, scale=1.0)
+
+
+def time_self_attention_primitives(q, k, v):
     time_fn(attention, q, k, v)


-def time_self_attention_sdpa():
-    mx.random.seed(3)
-    q = mx.random.uniform(shape=(1, H, 1, D))
-    k = mx.random.uniform(shape=(1, H_k, L, D))
-    v = mx.random.uniform(shape=(1, H_k, L, D))
-    mx.eval(q, k, v)
+def time_self_attention_sdpa(q, k, v):
     time_fn(sdpa, q, k, v)


+def time_self_attention_quant_sdpa(q, k, v):
+    time_fn(quant_sdpa, q, k, v)
+
+
 if __name__ == "__main__":
-    time_self_attention_sdpa()
-    time_self_attention_primitives()
+    mx.random.seed(3)
+    q = mx.random.uniform(shape=(1, H, 10, D))
+    k = mx.random.uniform(shape=(1, H_k, L, D))
+    v = mx.random.uniform(shape=(1, H_k, L, D))
+    mx.eval(q, k, v)
+
+    k_quant = mx.quantize(k)
+    v_quant = mx.quantize(v)
+    mx.eval(k_quant, v_quant)
+
+    # time_self_attention_sdpa(q, k, v)
+    # time_self_attention_quant_sdpa(q, k_quant, v_quant)
+    # time_self_attention_primitives(q, k, v)
+    q_sdpa = quant_sdpa(q, k, v)
+    print(q_sdpa)
+    o_attention = attention(q, k, v)
+    print(o_attention)
+    np.testing.assert_allclose(q_sdpa, o_attention, atol=1e-5)
+    # o_sdpa = sdpa(q, k, v)
+    # print(o_sdpa)
+    # np.testing.assert_allclose(q_sdpa, o_sdpa, atol=1e-5)
+    # print(o_sdpa[..., :64])
+    # print()
+    # print(o_attention[..., :64])
+    # np.testing.assert_allclose(o_sdpa, o_attention)
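Note on the reference paths: both attention and sdpa quantize K and V and then immediately dequantize them before attending. That makes the assert_allclose check meaningful, since every implementation then operates on the same lossy values and the comparison measures kernel correctness rather than quantization error. A minimal sketch of that round trip, assuming MLX's default quantization parameters (group_size=64, bits=4):

import mlx.core as mx

x = mx.random.uniform(shape=(128, 128))
x_q = mx.quantize(x)            # returns a (w_q, scales, biases) tuple
x_hat = mx.dequantize(*x_q)     # hence the * unpacking in the benchmark
print(mx.abs(x - x_hat).max())  # reconstruction error: small but nonzero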
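The hunk for attention cuts off right after q = q.reshape(B, Hk, Hq // Hk, L, D), which regroups the Hq query heads into Hq // Hk queries per key/value head (here 32 query heads share 32 // 4 = 8 KV heads). What follows is a hypothetical completion of that reference path, assuming a standard broadcast-based grouped-query attention; the name attention_reference and the body are illustrative, not the commit's actual code:

import mlx.core as mx

def attention_reference(q, k, v, scale=1.0):
    # q: (B, Hq, L, D); k, v: (B, Hk, S, D) with Hq a multiple of Hk
    B, Hq, L, D = q.shape
    _, Hk, S, _ = k.shape
    q = q.reshape(B, Hk, Hq // Hk, L, D)
    k = k[:, :, None]  # (B, Hk, 1, S, D): broadcasts over the group axis
    v = v[:, :, None]
    scores = (q * scale) @ k.transpose(0, 1, 2, 4, 3)  # (B, Hk, Hq // Hk, L, S)
    weights = mx.softmax(scores, axis=-1)
    out = weights @ v  # (B, Hk, Hq // Hk, L, D)
    return out.reshape(B, Hq, L, D)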
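By contrast, quant_sdpa keeps K and V in their quantized form and feeds the tuples straight into the fused kernel: q, *k, *v expands to (q, k_wq, k_scales, k_biases, v_wq, v_scales, v_biases). A self-contained usage sketch; note that mx.fast.quantized_scaled_dot_product_attention is the op this branch is introducing, so the signature below reflects this commit rather than a released API:

import mlx.core as mx

B, Hq, Hk, L, S, D = 1, 32, 8, 10, 4096, 128
q = mx.random.uniform(shape=(B, Hq, L, D))
k = mx.random.uniform(shape=(B, Hk, S, D))
v = mx.random.uniform(shape=(B, Hk, S, D))

k_q = mx.quantize(k)  # (w_q, scales, biases)
v_q = mx.quantize(v)
out = mx.fast.quantized_scaled_dot_product_attention(q, *k_q, *v_q, scale=1.0)
mx.eval(out)  # out: (B, Hq, L, D)

One loose end in this work-in-progress commit: the commented-out timing call time_self_attention_quant_sdpa(q, k_quant, v_quant) passes already-quantized tuples, while quant_sdpa quantizes its inputs internally, so the timing harness and the helper do not yet agree on which form the arguments take.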