"""Time chained ``mx.quantized_matmul`` calls for two quantization types.

For each bit width in the ``__main__`` loop, the script times the "affine"
and the "affine-packed" quantization types on the same problem sizes.
"""

import argparse
import math

import mlx.core as mx
from time_utils import time_fn

# Problem sizes: the input is (B, D); the two weights are (M, D) and (D, M).
B = 1024
D = 1024
M = 4 * D

# Settings shared by every quantize / quantized_matmul call below. ``bits`` is
# rebound by the ``__main__`` loop, so the functions look it up at call time.
group_size = 64
bits = 4
dtype = mx.float16
loops = 10


def qmm_(x, wq1, wq2, q_type):
    """Apply ``loops`` rounds of two chained quantized matmuls to ``x``.

    Args:
        x: Input array fed to the first matmul.
        wq1: Result of ``mx.quantize`` for the first weight; it is unpacked
            positionally into ``mx.quantized_matmul``.
        wq2: Result of ``mx.quantize`` for the second weight, used the same way.
        q_type: Value forwarded as ``quantization_type``.

    Returns:
        The output of the last ``mx.quantized_matmul`` call.
    """
    settings = dict(group_size=group_size, bits=bits, quantization_type=q_type)
    for _ in range(loops):
        # Each round multiplies by the first weight, then by the second.
        for quantized in (wq1, wq2):
            x = mx.quantized_matmul(x, *quantized, **settings)
    return x


def affine_qmm(x, wq1, wq2):
    """Run :func:`qmm_` with the "affine" quantization type."""
    return qmm_(x, wq1, wq2, q_type="affine")


def affine_packed_qmm(x, wq1, wq2):
    """Run :func:`qmm_` with the "affine-packed" quantization type."""
    return qmm_(x, wq1, wq2, q_type="affine-packed")


def _time_quantization_type(bench_fn, q_type):
    """Build seeded inputs quantized with ``q_type`` and time ``bench_fn``."""
    mx.random.seed(3)
    x = mx.random.normal(shape=(B, D)).astype(dtype)

    # Weights are drawn in the order (M, D) then (D, M), right after ``x``.
    quantized = []
    for weight_shape in ((M, D), (D, M)):
        weight = mx.random.normal(shape=weight_shape).astype(dtype)
        quantized.append(
            mx.quantize(
                weight,
                group_size=group_size,
                bits=bits,
                quantization_type=q_type,
            )
        )
    wq1, wq2 = quantized

    # Evaluate the inputs before timing starts.
    mx.eval(x, wq1, wq2)
    time_fn(bench_fn, x, wq1, wq2)


def time_qmm():
    """Time :func:`affine_qmm` on weights quantized with "affine"."""
    _time_quantization_type(affine_qmm, "affine")


def time_packed_qmm():
    """Time :func:`affine_packed_qmm` on weights quantized with "affine-packed"."""
    _time_quantization_type(affine_packed_qmm, "affine-packed")


if __name__ == "__main__":
    # The loop variable is the module-level ``bits`` itself, so each pass
    # changes the bit width seen by every function above.
    for bits in (2, 4, 8):
        print(f"Bits {bits}:")
        time_qmm()
        time_packed_qmm()