Change the argument name to quantization_type

This commit is contained in:
Angelos Katharopoulos
2024-12-16 13:31:34 -08:00
parent f5da489a3c
commit 410ccdbed5
4 changed files with 78 additions and 50 deletions

View File

@@ -20,14 +20,14 @@ def qmv_(x, wq1, wq2, q_type):
*wq1,
group_size=group_size,
bits=bits,
-type=q_type,
+quantization_type=q_type,
)
x = mx.quantized_matmul(
x,
*wq2,
group_size=group_size,
bits=bits,
-type=q_type,
+quantization_type=q_type,
)
return x
@@ -44,9 +44,9 @@ def time_qmv():
mx.random.seed(3)
x = mx.random.normal(shape=(1, D)).astype(dtype)
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
-wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine")
+wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
-wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine")
+wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
mx.eval(x, wq1, wq2)
time_fn(affine_qmv, x, wq1, wq2)
@@ -55,15 +55,19 @@ def time_packed_qmv():
mx.random.seed(3)
x = mx.random.normal(shape=(1, D)).astype(dtype)
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
-wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine-packed")
+wq1 = mx.quantize(
+w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
+)
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
-wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine-packed")
+wq2 = mx.quantize(
+w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
+)
mx.eval(x, wq1, wq2)
time_fn(affine_packed_qmv, x, wq1, wq2)
if __name__ == "__main__":
-for b in [2, 3, 4, 6, 8]:
+for b in [2, 4, 8]:
bits = b
print(f"Bits {bits}:")
time_qmv()