diff --git a/benchmarks/python/sdpa_vector_bench.py b/benchmarks/python/sdpa_vector_bench.py index 9fb6f36d2..5ce1c50a1 100644 --- a/benchmarks/python/sdpa_vector_bench.py +++ b/benchmarks/python/sdpa_vector_bench.py @@ -25,13 +25,13 @@ def sdpa(q, k, v): return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=None) -def quant_sdpa(q, k, v): +def quant_sdpa(q, k, v, bits=4): return mx.fast.quantized_scaled_dot_product_attention( - q, *k, *v, scale=1.0, mask=None, bits=8 + q, *k, *v, scale=1.0, mask=None, bits=bits ) -def quant_attention(q, k, v): +def quant_attention(q, k, v, bits=4): B, Hq, L, D = q.shape Hk = k[0].shape[1] @@ -39,10 +39,10 @@ def quant_attention(q, k, v): k = tree_map(lambda x: mx.expand_dims(x, axis=2), k) v = tree_map(lambda x: mx.expand_dims(x, axis=2), v) - scores = mx.quantized_matmul(q, *k, transpose=True) + scores = mx.quantized_matmul(q, *k, transpose=True, bits=bits) scores = mx.softmax(scores, axis=-1) - out = mx.quantized_matmul(scores, *v, transpose=False) + out = mx.quantized_matmul(scores, *v, transpose=False, bits=bits) out = out.reshape((B, Hq, L, D)) return out @@ -55,11 +55,11 @@ def time_self_attention_sdpa(q, k, v): time_fn(sdpa, q, k, v) -def time_self_attention_quant_sdpa(q, k, v): - time_fn(quant_sdpa, q, k, v) +def time_self_attention_quant_sdpa(q, k, v, bits=4): + time_fn(quant_sdpa, q, k, v, bits) -def time_self_attention_quant_primitives(q, k, v): +def time_self_attention_quant_primitives(q, k, v, bits=4): time_fn(quant_attention, q, k, v) @@ -70,11 +70,12 @@ if __name__ == "__main__": v = mx.random.uniform(shape=(1, H_k, L, D)) mx.eval(q, k, v) - k_quant = mx.quantize(k) - v_quant = mx.quantize(v) + bits = 4 + k_quant = mx.quantize(k, bits=bits) + v_quant = mx.quantize(v, bits=bits) mx.eval(k_quant, v_quant) time_self_attention_sdpa(q, k, v) - time_self_attention_quant_sdpa(q, k_quant, v_quant) + time_self_attention_quant_sdpa(q, k_quant, v_quant, bits) time_self_attention_primitives(q, k, v) - time_self_attention_quant_primitives(q, k_quant, v_quant) + time_self_attention_quant_primitives(q, k_quant, v_quant, bits)