mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 18:51:14 +08:00
56 lines
1.1 KiB
Python
56 lines
1.1 KiB
Python
![]() |
import argparse
|
||
|
import math
|
||
|
from functools import partial
|
||
|
|
||
|
import mlx.core as mx
|
||
|
from time_utils import time_fn
|
||
|
|
||
|
D = 16384
|
||
|
group_size = 64
|
||
|
bits = 3
|
||
|
dtype = mx.float16
|
||
|
loops = 10
|
||
|
|
||
|
|
||
|
def qmv_(x, wq, q_type):
|
||
|
for i in range(loops):
|
||
|
x = mx.quantized_matmul(
|
||
|
x,
|
||
|
*wq,
|
||
|
group_size=group_size,
|
||
|
bits=bits,
|
||
|
type=q_type,
|
||
|
)
|
||
|
return x
|
||
|
|
||
|
|
||
|
def affine_qmv(x, wq):
|
||
|
return qmv_(x, wq, "affine")
|
||
|
|
||
|
|
||
|
def affine_packed_qmv(x, wq):
|
||
|
return qmv_(x, wq, "affine-packed")
|
||
|
|
||
|
|
||
|
def time_qmv():
|
||
|
mx.random.seed(3)
|
||
|
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
||
|
w = mx.random.normal(shape=(D, D)).astype(dtype)
|
||
|
wq = mx.quantize(w, group_size=group_size, bits=bits, type="affine")
|
||
|
mx.eval(x, wq)
|
||
|
time_fn(affine_qmv, x, wq)
|
||
|
|
||
|
|
||
|
def time_packed_qmv():
|
||
|
mx.random.seed(3)
|
||
|
x = mx.random.normal(shape=(1, D)).astype(dtype)
|
||
|
w = mx.random.normal(shape=(D, D)).astype(dtype)
|
||
|
wq = mx.quantize(w, group_size=group_size, bits=bits, type="affine-packed")
|
||
|
mx.eval(x, wq)
|
||
|
time_fn(affine_packed_qmv, x, wq)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
time_qmv()
|
||
|
time_packed_qmv()
|