Improve the benchmark

This commit is contained in:
Angelos Katharopoulos 2024-12-14 23:04:29 -08:00
parent fd161aa31f
commit 17a1fa2f0b

View File

@ -5,18 +5,26 @@ from functools import partial
import mlx.core as mx import mlx.core as mx
from time_utils import time_fn from time_utils import time_fn
D = 8192 D = 1024
M = 4 * D
group_size = 64 group_size = 64
bits = 4 bits = 4
dtype = mx.float16 dtype = mx.float16
loops = 100 loops = 100
def qmv_(x, wq, q_type): def qmv_(x, wq1, wq2, q_type):
for i in range(loops): for i in range(loops):
x = mx.quantized_matmul( x = mx.quantized_matmul(
x, x,
*wq, *wq1,
group_size=group_size,
bits=bits,
type=q_type,
)
x = mx.quantized_matmul(
x,
*wq2,
group_size=group_size, group_size=group_size,
bits=bits, bits=bits,
type=q_type, type=q_type,
@ -24,32 +32,39 @@ def qmv_(x, wq, q_type):
return x return x
def affine_qmv(x, wq): def affine_qmv(x, wq1, wq2):
return qmv_(x, wq, "affine") return qmv_(x, wq1, wq2, "affine")
def affine_packed_qmv(x, wq): def affine_packed_qmv(x, wq1, wq2):
return qmv_(x, wq, "affine-packed") return qmv_(x, wq1, wq2, "affine-packed")
def time_qmv(): def time_qmv():
mx.random.seed(3) mx.random.seed(3)
x = mx.random.normal(shape=(1, D)).astype(dtype) x = mx.random.normal(shape=(1, D)).astype(dtype)
w = mx.random.normal(shape=(D, D)).astype(dtype) w1 = mx.random.normal(shape=(M, D)).astype(dtype)
wq = mx.quantize(w, group_size=group_size, bits=bits, type="affine") wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine")
mx.eval(x, wq) w2 = mx.random.normal(shape=(D, M)).astype(dtype)
time_fn(affine_qmv, x, wq) wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine")
mx.eval(x, wq1, wq2)
time_fn(affine_qmv, x, wq1, wq2)
def time_packed_qmv(): def time_packed_qmv():
mx.random.seed(3) mx.random.seed(3)
x = mx.random.normal(shape=(1, D)).astype(dtype) x = mx.random.normal(shape=(1, D)).astype(dtype)
w = mx.random.normal(shape=(D, D)).astype(dtype) w1 = mx.random.normal(shape=(M, D)).astype(dtype)
wq = mx.quantize(w, group_size=group_size, bits=bits, type="affine-packed") wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine-packed")
mx.eval(x, wq) w2 = mx.random.normal(shape=(D, M)).astype(dtype)
time_fn(affine_packed_qmv, x, wq) wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine-packed")
mx.eval(x, wq1, wq2)
time_fn(affine_packed_qmv, x, wq1, wq2)
if __name__ == "__main__": if __name__ == "__main__":
time_qmv() for b in [2, 3, 4, 6, 8]:
time_packed_qmv() bits = b
print(f"Bits {bits}:")
time_qmv()
time_packed_qmv()