2024-12-14 08:27:27 +08:00
|
|
|
import argparse
|
|
|
|
import math
|
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
from time_utils import time_fn
|
|
|
|
|
2024-12-17 13:49:14 +08:00
|
|
|
# Problem sizes: the benchmarks build an input of shape (B, D) and weights of
# shape (M, D) and (D, M), with M fixed at four times D.
B = 1024
D = 1024
M = D * 4

# Quantization parameters passed to mx.quantize / mx.quantized_matmul.
# `bits` is only a starting value: the __main__ sweep rebinds it per run.
group_size = 64
bits = 4

# Element type the random inputs and weights are cast to.
dtype = mx.float16

# Number of (wq1, wq2) matmul pairs chained inside a single qmm_ call.
loops = 10
|
2024-12-14 08:27:27 +08:00
|
|
|
|
|
|
|
|
2024-12-17 13:49:14 +08:00
|
|
|
def qmm_(x, wq1, wq2, q_type):
    """Chain ``loops`` pairs of quantized matmuls and return the last output.

    On every iteration ``x`` is multiplied by the first quantized weight and
    then by the second; each result is fed back in as the next input.

    Args:
        x: Array passed as the first operand of ``mx.quantized_matmul``.
        wq1: Sequence unpacked positionally after ``x`` for the first matmul
            (the callers in this file pass the result of ``mx.quantize``).
        wq2: Same as ``wq1``, used for the second matmul.
        q_type: Value forwarded as the ``quantization_type`` keyword.

    Returns:
        The output of the final ``mx.quantized_matmul`` call, or ``x``
        unchanged when ``loops`` is 0.

    Note:
        ``loops``, ``group_size`` and ``bits`` are read from module globals.
    """
    matmul_kwargs = dict(
        group_size=group_size,
        bits=bits,
        quantization_type=q_type,
    )
    for _ in range(loops):
        for quantized_weight in (wq1, wq2):
            x = mx.quantized_matmul(x, *quantized_weight, **matmul_kwargs)
    return x
|
|
|
|
|
|
|
|
|
2024-12-17 13:49:14 +08:00
|
|
|
def affine_qmm(x, wq1, wq2):
    """Run ``qmm_`` on the given operands with quantization type "affine"."""
    q_type = "affine"
    return qmm_(x, wq1, wq2, q_type)
|
2024-12-14 08:27:27 +08:00
|
|
|
|
|
|
|
|
2024-12-17 13:49:14 +08:00
|
|
|
def affine_packed_qmm(x, wq1, wq2):
    """Run ``qmm_`` on the given operands with quantization type "affine-packed"."""
    q_type = "affine-packed"
    return qmm_(x, wq1, wq2, q_type)
|
2024-12-14 08:27:27 +08:00
|
|
|
|
|
|
|
|
2024-12-17 13:49:14 +08:00
|
|
|
def time_qmm():
    """Time ``affine_qmm`` on random inputs quantized with the "affine" type.

    Seeds the MLX RNG with 3, draws a (B, D) input and (M, D) / (D, M)
    weights cast to ``dtype``, quantizes both weights using the module-level
    ``group_size`` and ``bits``, and hands everything to ``time_fn``.
    """
    mx.random.seed(3)
    quantize_kwargs = dict(
        group_size=group_size,
        bits=bits,
        quantization_type="affine",
    )

    x = mx.random.normal(shape=(B, D)).astype(dtype)

    # Draw and quantize the weights one after the other so the random draws
    # happen in the order x, first weight, second weight.
    quantized = []
    for weight_shape in ((M, D), (D, M)):
        weight = mx.random.normal(shape=weight_shape).astype(dtype)
        quantized.append(mx.quantize(weight, **quantize_kwargs))
    wq1, wq2 = quantized

    # Evaluate the operands up front -- presumably so that producing them is
    # not counted in the timed region (depends on time_fn; confirm there).
    mx.eval(x, wq1, wq2)
    time_fn(affine_qmm, x, wq1, wq2)
|
2024-12-14 08:27:27 +08:00
|
|
|
|
|
|
|
|
2024-12-17 13:49:14 +08:00
|
|
|
def time_packed_qmm():
    """Time ``affine_packed_qmm`` on inputs quantized as "affine-packed".

    Seeds the MLX RNG with 3, draws a (B, D) input and (M, D) / (D, M)
    weights cast to ``dtype``, quantizes both weights using the module-level
    ``group_size`` and ``bits``, and hands everything to ``time_fn``.
    """
    mx.random.seed(3)
    quantize_kwargs = dict(
        group_size=group_size,
        bits=bits,
        quantization_type="affine-packed",
    )

    x = mx.random.normal(shape=(B, D)).astype(dtype)

    # Draw and quantize the weights one after the other so the random draws
    # happen in the order x, first weight, second weight.
    quantized = []
    for weight_shape in ((M, D), (D, M)):
        weight = mx.random.normal(shape=weight_shape).astype(dtype)
        quantized.append(mx.quantize(weight, **quantize_kwargs))
    wq1, wq2 = quantized

    # Evaluate the operands up front -- presumably so that producing them is
    # not counted in the timed region (depends on time_fn; confirm there).
    mx.eval(x, wq1, wq2)
    time_fn(affine_packed_qmm, x, wq1, wq2)
|
2024-12-14 08:27:27 +08:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run both benchmarks at 2, 4 and 8 bits. The helpers read the
    # module-level `bits`, so it is rebound here before each pair of runs.
    benchmarks = (time_qmm, time_packed_qmm)
    for b in (2, 4, 8):
        bits = b
        print(f"Bits {bits}:")
        for run_benchmark in benchmarks:
            run_benchmark()
|