update bench

Alex Barron 2024-10-22 20:27:58 -07:00
parent 852336b8a2
commit b509c2ad76


@@ -25,13 +25,13 @@ def sdpa(q, k, v):
     return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None)
 
 
-def quant_sdpa(q, k, v):
+def quant_sdpa(q, k, v, bits=4):
     return mx.fast.quantized_scaled_dot_product_attention(
-        q, *k, *v, scale=1.0, mask=None, bits=8
+        q, *k, *v, scale=1.0, mask=None, bits=bits
     )
 
 
-def quant_attention(q, k, v):
+def quant_attention(q, k, v, bits=4):
     B, Hq, L, D = q.shape
     Hk = k[0].shape[1]
@@ -39,10 +39,10 @@ def quant_attention(q, k, v):
     k = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
     v = tree_map(lambda x: mx.expand_dims(x, axis=2), v)
 
-    scores = mx.quantized_matmul(q, *k, transpose=True)
+    scores = mx.quantized_matmul(q, *k, transpose=True, bits=bits)
     scores = mx.softmax(scores, axis=-1)
-    out = mx.quantized_matmul(scores, *v, transpose=False)
+    out = mx.quantized_matmul(scores, *v, transpose=False, bits=bits)
     out = out.reshape((B, Hq, L, D))
     return out
@@ -55,11 +55,11 @@ def time_self_attention_sdpa(q, k, v):
     time_fn(sdpa, q, k, v)
 
 
-def time_self_attention_quant_sdpa(q, k, v):
-    time_fn(quant_sdpa, q, k, v)
+def time_self_attention_quant_sdpa(q, k, v, bits=4):
+    time_fn(quant_sdpa, q, k, v, bits)
 
 
-def time_self_attention_quant_primitives(q, k, v):
+def time_self_attention_quant_primitives(q, k, v, bits=4):
     time_fn(quant_attention, q, k, v)
@@ -70,11 +70,12 @@ if __name__ == "__main__":
     v = mx.random.uniform(shape=(1, H_k, L, D))
     mx.eval(q, k, v)
 
-    k_quant = mx.quantize(k)
-    v_quant = mx.quantize(v)
+    bits = 4
+    k_quant = mx.quantize(k, bits=bits)
+    v_quant = mx.quantize(v, bits=bits)
     mx.eval(k_quant, v_quant)
 
     time_self_attention_sdpa(q, k, v)
-    time_self_attention_quant_sdpa(q, k_quant, v_quant)
+    time_self_attention_quant_sdpa(q, k_quant, v_quant, bits)
     time_self_attention_primitives(q, k, v)
-    time_self_attention_quant_primitives(q, k_quant, v_quant)
+    time_self_attention_quant_primitives(q, k_quant, v_quant, bits)
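For context on the change: mx.quantize returns a (weights, scales, biases) tuple, which is why the benchmark splats the quantized tensors with *k / *v, and the bits passed to mx.quantized_matmul must agree with the bits used at quantization time; this commit threads a single bits value through all of those call sites. Below is a minimal, self-contained sketch of the quantized-primitives path being timed, assuming MLX's mx.quantize / mx.quantized_matmul Python API; the shapes are illustrative, and the q reshape is reconstructed from context since it falls between the hunks shown above.

import mlx.core as mx
from mlx.utils import tree_map

def quant_attention(q, k, v, bits=4):
    # k and v are (weights, scales, biases) tuples produced by mx.quantize.
    B, Hq, L, D = q.shape
    Hk = k[0].shape[1]

    # Reconstructed from context: group the query heads so each KV head
    # serves Hq // Hk of them (grouped-query attention).
    q = q.reshape((B, Hk, Hq // Hk, L, D))
    k = tree_map(lambda x: mx.expand_dims(x, axis=2), k)
    v = tree_map(lambda x: mx.expand_dims(x, axis=2), v)

    # bits must match the bits used when k and v were quantized below.
    scores = mx.quantized_matmul(q, *k, transpose=True, bits=bits)
    scores = mx.softmax(scores, axis=-1)
    out = mx.quantized_matmul(scores, *v, transpose=False, bits=bits)
    return out.reshape((B, Hq, L, D))

# Illustrative sizes; D must be divisible by mx.quantize's group size (64 by default).
B, Hq, Hk, L, D = 1, 32, 8, 1024, 64
bits = 4
q = mx.random.uniform(shape=(B, Hq, L, D))
k = mx.quantize(mx.random.uniform(shape=(B, Hk, L, D)), bits=bits)
v = mx.quantize(mx.random.uniform(shape=(B, Hk, L, D)), bits=bits)
mx.eval(quant_attention(q, k, v, bits=bits))

The mx.fast.quantized_scaled_dot_product_attention kernel timed alongside this path is specific to this branch; the sketch above only exercises the stock quantized primitives that the benchmark compares it against.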