mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-01 16:58:08 +08:00
8-bit working
This commit is contained in:
@@ -25,18 +25,18 @@ def attention(q, k, v):
|
||||
|
||||
|
||||
def sdpa(q, k, v):
    """Baseline attention: quantize/dequantize K and V, then run full-precision SDPA.

    Serves as the reference path for the quantized kernel below — K and V go
    through an 8-bit quantize/dequantize round trip so the numerical error of
    quantization is included, but the attention itself runs on dense arrays.

    Args:
        q: Query array. Shape/dtype not constrained here — assumed to match
           what ``mx.fast.scaled_dot_product_attention`` expects (TODO confirm
           against the benchmark driver elsewhere in this file).
        k: Key array; quantized to 8 bits and immediately dequantized.
        v: Value array; quantized to 8 bits and immediately dequantized.

    Returns:
        The output of ``mx.fast.scaled_dot_product_attention`` with
        ``scale=1.0`` and no mask.
    """
    # mx.quantize returns a tuple (weights, scales, biases); the round trip
    # through dequantize reproduces the quantization error on K/V.
    k = mx.quantize(k, bits=8)
    v = mx.quantize(v, bits=8)
    k = mx.dequantize(*k, bits=8)
    v = mx.dequantize(*v, bits=8)
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
|
||||
|
||||
|
||||
def quant_sdpa(q, k, v):
    """Quantized attention: run SDPA directly on 8-bit quantized K and V.

    Counterpart to ``sdpa`` above — instead of dequantizing K/V back to dense
    arrays, the quantized tuples are passed straight into the fused
    ``quantized_scaled_dot_product_attention`` kernel.

    Args:
        q: Query array (kept in full precision).
        k: Key array; quantized to 8 bits before the call.
        v: Value array; quantized to 8 bits before the call.

    Returns:
        The output of ``mx.fast.quantized_scaled_dot_product_attention`` with
        ``scale=1.0``, no mask, and ``bits=8`` (must match the quantization
        width used on K/V).
    """
    # mx.quantize yields a tuple (weights, scales, biases); the fused kernel
    # takes the components unpacked via *k / *v.
    k = mx.quantize(k, bits=8)
    v = mx.quantize(v, bits=8)
    return mx.fast.quantized_scaled_dot_product_attention(
        q, *k, *v, scale=1.0, mask=None, bits=8
    )
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user