# mlx/benchmarks/python/packed_qmm_bench.py
#
# Times chained quantized matmuls using the "affine" and "affine-packed"
# quantization types for several bit widths.
import argparse
import math
import mlx.core as mx
from time_utils import time_fn
# Problem size. Activations are (B, D); the first weight matrix is (M, D)
# and the second is (D, M).
B, D = 1024, 1024
M = 4 * D

# Quantization settings read at call time by every mx.quantize and
# mx.quantized_matmul call in this file. `bits` is rebound by the sweep in
# the __main__ section.
group_size, bits = 64, 4

# Element type of the activations and of the weights before quantization.
dtype = mx.float16

# Number of iterations in qmm_; each iteration performs two quantized matmuls.
loops = 10
def qmm_(x, wq1, wq2, q_type):
    """Apply ``loops`` rounds of two chained quantized matmuls to ``x``.

    Each round multiplies by ``wq1`` and then by ``wq2``, feeding each
    result into the next call.

    Args:
        x: Input activation array.
        wq1: Output of ``mx.quantize`` for the first weight; it is unpacked
            into ``mx.quantized_matmul``'s positional arguments.
        wq2: Same as ``wq1``, for the second weight.
        q_type: Forwarded as ``quantization_type``. Presumably it must match
            the type the weights were quantized with (verify against MLX).

    The module-level ``group_size`` and ``bits`` are read as they are when
    this function is called. No ``mx.eval`` is done here.

    Returns:
        The array produced by the last matmul.
    """
    for _ in range(loops):
        for weights in (wq1, wq2):
            x = mx.quantized_matmul(
                x,
                *weights,
                group_size=group_size,
                bits=bits,
                quantization_type=q_type,
            )
    return x
def affine_qmm(x, wq1, wq2):
    """Run :func:`qmm_` with the ``"affine"`` quantization type.

    ``wq1`` and ``wq2`` are expected to come from ``mx.quantize`` with
    ``quantization_type="affine"``.
    """
    q_type = "affine"
    return qmm_(x, wq1, wq2, q_type)
def affine_packed_qmm(x, wq1, wq2):
    """Run :func:`qmm_` with the ``"affine-packed"`` quantization type.

    ``wq1`` and ``wq2`` are expected to come from ``mx.quantize`` with
    ``quantization_type="affine-packed"``.
    """
    q_type = "affine-packed"
    return qmm_(x, wq1, wq2, q_type)
def _quantized_inputs(quantization_type):
    """Build seeded benchmark inputs with weights quantized as requested.

    Both timing functions use this helper, so the affine and packed variants
    start from the same seed and the same generation order. They therefore
    see the same unquantized random data.

    Args:
        quantization_type: Forwarded to ``mx.quantize``. In this file it is
            ``"affine"`` or ``"affine-packed"``.

    Returns:
        ``(x, wq1, wq2)``, all passed through ``mx.eval``. ``x`` has shape
        (B, D). ``wq1`` and ``wq2`` are the ``mx.quantize`` results for
        (M, D) and (D, M) weight matrices.
    """
    mx.random.seed(3)
    # Keep this call order: it determines which random draws each array gets.
    x = mx.random.normal(shape=(B, D)).astype(dtype)
    w1 = mx.random.normal(shape=(M, D)).astype(dtype)
    wq1 = mx.quantize(
        w1, group_size=group_size, bits=bits, quantization_type=quantization_type
    )
    w2 = mx.random.normal(shape=(D, M)).astype(dtype)
    wq2 = mx.quantize(
        w2, group_size=group_size, bits=bits, quantization_type=quantization_type
    )
    # MLX evaluates lazily. Materialize the inputs here so the timed call
    # does not also pay for generating and quantizing them.
    mx.eval(x, wq1, wq2)
    return x, wq1, wq2


def time_qmm():
    """Time :func:`affine_qmm` on weights quantized with ``"affine"``."""
    time_fn(affine_qmm, *_quantized_inputs("affine"))


def time_packed_qmm():
    """Time :func:`affine_packed_qmm` on weights quantized with ``"affine-packed"``."""
    time_fn(affine_packed_qmm, *_quantized_inputs("affine-packed"))
if __name__ == "__main__":
    # Command-line options. The defaults reproduce the original hard-coded
    # sweep (bits 2, 4 and 8 with the module's group_size).
    parser = argparse.ArgumentParser(
        description="Benchmark affine vs. affine-packed quantized matmul."
    )
    parser.add_argument(
        "--bits",
        type=int,
        nargs="+",
        default=[2, 4, 8],
        help="bit widths to sweep (default: %(default)s)",
    )
    parser.add_argument(
        "--group-size",
        type=int,
        default=group_size,
        help="quantization group size (default: %(default)s)",
    )
    args = parser.parse_args()

    # This block runs at module scope, so these assignments rebind the
    # globals that qmm_ and the time_* functions read when called.
    group_size = args.group_size
    for b in args.bits:
        bits = b
        print(f"Bits {bits}:")
        time_qmm()
        time_packed_qmm()