mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-04 18:48:15 +08:00
Change the argument name to quantization_type
This commit is contained in:
@@ -20,14 +20,14 @@ def qmv_(x, wq1, wq2, q_type):
|
||||
*wq1,
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
type=q_type,
|
||||
quantization_type=q_type,
|
||||
)
|
||||
x = mx.quantized_matmul(
|
||||
x,
|
||||
*wq2,
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
type=q_type,
|
||||
quantization_type=q_type,
|
||||
)
|
||||
return x
|
||||
|
||||
@@ -44,9 +44,9 @@ def time_qmv():
|
||||
mx.random.seed(3)
|
||||
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
||||
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
|
||||
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine")
|
||||
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, quantization_type="affine")
|
||||
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
|
||||
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine")
|
||||
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, quantization_type="affine")
|
||||
mx.eval(x, wq1, wq2)
|
||||
time_fn(affine_qmv, x, wq1, wq2)
|
||||
|
||||
@@ -55,15 +55,19 @@ def time_packed_qmv():
|
||||
mx.random.seed(3)
|
||||
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
||||
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
|
||||
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine-packed")
|
||||
wq1 = mx.quantize(
|
||||
w1, group_size=group_size, bits=bits, quantization_type="affine-packed"
|
||||
)
|
||||
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
|
||||
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine-packed")
|
||||
wq2 = mx.quantize(
|
||||
w2, group_size=group_size, bits=bits, quantization_type="affine-packed"
|
||||
)
|
||||
mx.eval(x, wq1, wq2)
|
||||
time_fn(affine_packed_qmv, x, wq1, wq2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for b in [2, 3, 4, 6, 8]:
|
||||
for b in [2, 4, 8]:
|
||||
bits = b
|
||||
print(f"Bits {bits}:")
|
||||
time_qmv()
|
||||
|
||||
Reference in New Issue
Block a user