mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
Improve the benchmark
This commit is contained in:
parent
fd161aa31f
commit
17a1fa2f0b
@ -5,18 +5,26 @@ from functools import partial
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from time_utils import time_fn
|
from time_utils import time_fn
|
||||||
|
|
||||||
D = 8192
|
D = 1024
|
||||||
|
M = 4 * D
|
||||||
group_size = 64
|
group_size = 64
|
||||||
bits = 4
|
bits = 4
|
||||||
dtype = mx.float16
|
dtype = mx.float16
|
||||||
loops = 100
|
loops = 100
|
||||||
|
|
||||||
|
|
||||||
def qmv_(x, wq, q_type):
|
def qmv_(x, wq1, wq2, q_type):
|
||||||
for i in range(loops):
|
for i in range(loops):
|
||||||
x = mx.quantized_matmul(
|
x = mx.quantized_matmul(
|
||||||
x,
|
x,
|
||||||
*wq,
|
*wq1,
|
||||||
|
group_size=group_size,
|
||||||
|
bits=bits,
|
||||||
|
type=q_type,
|
||||||
|
)
|
||||||
|
x = mx.quantized_matmul(
|
||||||
|
x,
|
||||||
|
*wq2,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
bits=bits,
|
bits=bits,
|
||||||
type=q_type,
|
type=q_type,
|
||||||
@ -24,32 +32,39 @@ def qmv_(x, wq, q_type):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def affine_qmv(x, wq):
|
def affine_qmv(x, wq1, wq2):
|
||||||
return qmv_(x, wq, "affine")
|
return qmv_(x, wq1, wq2, "affine")
|
||||||
|
|
||||||
|
|
||||||
def affine_packed_qmv(x, wq):
|
def affine_packed_qmv(x, wq1, wq2):
|
||||||
return qmv_(x, wq, "affine-packed")
|
return qmv_(x, wq1, wq2, "affine-packed")
|
||||||
|
|
||||||
|
|
||||||
def time_qmv():
|
def time_qmv():
|
||||||
mx.random.seed(3)
|
mx.random.seed(3)
|
||||||
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
||||||
w = mx.random.normal(shape=(D, D)).astype(dtype)
|
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
|
||||||
wq = mx.quantize(w, group_size=group_size, bits=bits, type="affine")
|
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine")
|
||||||
mx.eval(x, wq)
|
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
|
||||||
time_fn(affine_qmv, x, wq)
|
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine")
|
||||||
|
mx.eval(x, wq1, wq2)
|
||||||
|
time_fn(affine_qmv, x, wq1, wq2)
|
||||||
|
|
||||||
|
|
||||||
def time_packed_qmv():
|
def time_packed_qmv():
|
||||||
mx.random.seed(3)
|
mx.random.seed(3)
|
||||||
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
||||||
w = mx.random.normal(shape=(D, D)).astype(dtype)
|
w1 = mx.random.normal(shape=(M, D)).astype(dtype)
|
||||||
wq = mx.quantize(w, group_size=group_size, bits=bits, type="affine-packed")
|
wq1 = mx.quantize(w1, group_size=group_size, bits=bits, type="affine-packed")
|
||||||
mx.eval(x, wq)
|
w2 = mx.random.normal(shape=(D, M)).astype(dtype)
|
||||||
time_fn(affine_packed_qmv, x, wq)
|
wq2 = mx.quantize(w2, group_size=group_size, bits=bits, type="affine-packed")
|
||||||
|
mx.eval(x, wq1, wq2)
|
||||||
|
time_fn(affine_packed_qmv, x, wq1, wq2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
time_qmv()
|
for b in [2, 3, 4, 6, 8]:
|
||||||
time_packed_qmv()
|
bits = b
|
||||||
|
print(f"Bits {bits}:")
|
||||||
|
time_qmv()
|
||||||
|
time_packed_qmv()
|
||||||
|
Loading…
Reference in New Issue
Block a user