Add packed_affine_qmm_t

This commit is contained in:
Angelos Katharopoulos
2024-12-16 21:49:14 -08:00
parent 410ccdbed5
commit fb7be036af
4 changed files with 455 additions and 21 deletions

View File

@@ -1,19 +1,19 @@
import argparse
import math
from functools import partial
import mlx.core as mx
from time_utils import time_fn
# Benchmark configuration shared by the timing functions below.
# NOTE(review): this region looks like a diff rendered without +/- markers, so
# an old and a new version of the same line may both be present -- confirm.
B = 1024  # row count of the activation built with shape=(B, D)
D = 1024  # dimension shared by x (.., D), w1 (M, D) and w2 (D, M)
M = 4 * D  # other weight dimension, four times D
group_size = 64  # forwarded as group_size= to mx.quantize
bits = 4  # forwarded as bits= to mx.quantize; rebound in the __main__ block
dtype = mx.float16  # every random tensor is cast to this dtype via astype
loops = 100
# The second assignment wins; presumably 100 is the pre-change value and 10
# the new one -- confirm against the actual file.
loops = 10  # bound of the range(loops) loop in the matmul helper
def qmv_(x, wq1, wq2, q_type):
def qmm_(x, wq1, wq2, q_type):
for i in range(loops):
x = mx.quantized_matmul(
x,
@@ -32,28 +32,28 @@ def qmv_(x, wq1, wq2, q_type):
return x
def affine_qmv(x, wq1, wq2):
    """Run the ``qmv_`` loop on ``x`` with the "affine" quantization type.

    ``x``, ``wq1`` and ``wq2`` are passed through unchanged; the result of
    ``qmv_`` is returned as-is.
    """
    return qmv_(x, wq1, wq2, q_type="affine")
def affine_qmm(x, wq1, wq2):
    """Run the ``qmm_`` loop on ``x`` with the "affine" quantization type.

    ``x``, ``wq1`` and ``wq2`` are passed through unchanged; the result of
    ``qmm_`` is returned as-is.
    """
    return qmm_(x, wq1, wq2, q_type="affine")
def affine_packed_qmv(x, wq1, wq2):
    """Run the ``qmv_`` loop on ``x`` with the "affine-packed" quantization type.

    ``x``, ``wq1`` and ``wq2`` are passed through unchanged; the result of
    ``qmv_`` is returned as-is.
    """
    return qmv_(x, wq1, wq2, q_type="affine-packed")
def affine_packed_qmm(x, wq1, wq2):
    """Run the ``qmm_`` loop on ``x`` with the "affine-packed" quantization type.

    ``x``, ``wq1`` and ``wq2`` are passed through unchanged; the result of
    ``qmm_`` is returned as-is.
    """
    return qmm_(x, wq1, wq2, q_type="affine-packed")
# Time the "affine" quantized path: draw random inputs, quantize both weight
# matrices with quantization_type="affine", then pass the loop function and
# its arguments to time_fn.
# NOTE(review): this span looks like a diff rendered without +/- markers, so
# the *_qmv lines are presumably the old versions of the *_qmm lines -- confirm.
def time_qmv():
def time_qmm():
# Fixed seed; presumably so every run draws the same random tensors.
mx.random.seed(3)
# x is assigned twice; the (B, D) assignment comes last and is the one used.
x = mx.random.normal(shape=(1, D)).astype(dtype)
x = mx.random.normal(shape=(B, D)).astype(dtype)
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
# Evaluate the inputs before timing; presumably so that input generation and
# quantization are not counted in the measurement -- confirm.
mx.eval(x, wq1, wq2)
time_fn(affine_qmv, x, wq1, wq2)
time_fn(affine_qmm, x, wq1, wq2)
def time_packed_qmv():
def time_packed_qmm():
mx.random.seed(3)
x = mx.random.normal(shape=(1, D)).astype(dtype)
x = mx.random.normal(shape=(B, D)).astype(dtype)
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
wq1 = mx.quantize(
w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
@@ -63,12 +63,12 @@ def time_packed_qmv():
w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
)
mx.eval(x, wq1, wq2)
time_fn(affine_packed_qmv, x, wq1, wq2)
time_fn(affine_packed_qmm, x, wq1, wq2)
# Script entry point: iterate over the bit widths 2, 4 and 8 and run the
# timing functions.
# NOTE(review): indentation is lost here; presumably the print and the time_*
# calls all belong to the loop body -- confirm.
if __name__ == "__main__":
for b in [2, 4, 8]:
# Rebinds the module-level `bits`, which the time_* functions pass to
# mx.quantize as bits=bits.
bits = b
print(f"Bits {bits}:")
# NOTE(review): the *_qmv calls are presumably the old versions of the *_qmm
# calls below (diff rendered without +/- markers) -- confirm.
time_qmv()
time_packed_qmv()
time_qmm()
time_packed_qmm()