8 bit working

This commit is contained in:
Alex Barron
2024-10-22 20:09:27 -07:00
parent ef14b1e9c3
commit 047a584e3d
6 changed files with 127 additions and 48 deletions

View File

@@ -25,18 +25,18 @@ def attention(q, k, v):
def sdpa(q, k, v):
    """Baseline: quantize k/v to 8 bits, immediately dequantize, then run
    the standard fused scaled-dot-product attention.

    Serves as the reference path to compare against quant_sdpa, which keeps
    k/v quantized inside the kernel.

    NOTE(review): this span was a diff rendering with old/new lines fused;
    reconstructed here as the post-commit (8-bit) version.
    """
    # mx.quantize returns a tuple (weights, scales, biases) — unpacked below.
    k = mx.quantize(k, bits=8)
    v = mx.quantize(v, bits=8)
    k = mx.dequantize(*k, bits=8)
    v = mx.dequantize(*v, bits=8)
    # scale=1.0 must match the scale used in quant_sdpa for a fair comparison.
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
def quant_sdpa(q, k, v):
    """Fused quantized attention: quantize k/v to 8 bits and run the
    quantized SDPA kernel directly on the quantized representation.

    The quantize outputs (tuples) are splatted into the kernel call, which
    consumes weights/scales/biases for both k and v; bits=8 must match the
    quantization width used above.

    NOTE(review): this span was a diff rendering with old/new lines fused;
    reconstructed here as the post-commit (8-bit) version.
    """
    k = mx.quantize(k, bits=8)
    v = mx.quantize(v, bits=8)
    return mx.fast.quantized_scaled_dot_product_attention(
        q, *k, *v, scale=1.0, mask=None, bits=8
    )