4 bit working

This commit is contained in:
Alex Barron
2024-10-22 19:20:45 -07:00
parent 5824626c0b
commit ef14b1e9c3
3 changed files with 24 additions and 16 deletions

View File

@@ -2,7 +2,7 @@ import mlx.core as mx
import numpy as np
from time_utils import time_fn
L = 30000
L = 16
H = 32
H_k = 32 // 4
D = 128
@@ -29,13 +29,15 @@ def sdpa(q, k, v):
v = mx.quantize(v)
k = mx.dequantize(*k)
v = mx.dequantize(*v)
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
return mx.fast.scaled_dot_product_attention(q, k, v, scale=0.08, mask=None)
def quant_sdpa(q, k, v):
k = mx.quantize(k)
v = mx.quantize(v)
return mx.fast.quantized_scaled_dot_product_attention(q, *k, *v, scale=1.0)
return mx.fast.quantized_scaled_dot_product_attention(
q, *k, *v, scale=0.08, mask=None
)
def time_self_attention_primitives(q, k, v):
@@ -52,9 +54,14 @@ def time_self_attention_quant_sdpa(q, k, v):
if __name__ == "__main__":
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 10, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
# q = mx.random.uniform(shape=(1, H, 1, D))
# k = mx.random.uniform(shape=(1, H_k, L, D))
# v = mx.random.uniform(shape=(1, H_k, L, D))
q = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/queries.npy"))
k = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/keys.npy"))
v = mx.array(np.load("/Users/alexbarron/mlx-examples/llms/values.npy"))
print(q.dtype)
print(q.shape, k.shape, v.shape)
mx.eval(q, k, v)
k_quant = mx.quantize(k)
@@ -66,12 +73,12 @@ if __name__ == "__main__":
# time_self_attention_primitives(q, k, v)
q_sdpa = quant_sdpa(q, k, v)
print(q_sdpa)
o_attention = attention(q, k, v)
print(o_attention)
np.testing.assert_allclose(q_sdpa, o_attention, atol=1e-5)
# o_sdpa = sdpa(q, k, v)
# print(o_sdpa)
# np.testing.assert_allclose(q_sdpa, o_sdpa, atol=1e-5)
# o_attention = attention(q, k, v)
# print(o_attention)
# np.testing.assert_allclose(q_sdpa, o_attention, atol=1e-5)
o_sdpa = sdpa(q, k, v)
print(o_sdpa)
np.testing.assert_allclose(q_sdpa, o_sdpa, atol=1e-5)
# print(o_sdpa[..., :64])
# print()
# print(o_attention[..., :64])