mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	Add a small benchmark
This commit is contained in:
		
							
								
								
									
										55
									
								
								benchmarks/python/packed_qmv_bench.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								benchmarks/python/packed_qmv_bench.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,55 @@
 | 
			
		||||
import argparse
 | 
			
		||||
import math
 | 
			
		||||
from functools import partial
 | 
			
		||||
 | 
			
		||||
import mlx.core as mx
 | 
			
		||||
from time_utils import time_fn
 | 
			
		||||
 | 
			
		||||
D = 16384
 | 
			
		||||
group_size = 64
 | 
			
		||||
bits = 3
 | 
			
		||||
dtype = mx.float16
 | 
			
		||||
loops = 10
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def qmv_(x, wq, q_type):
 | 
			
		||||
    for i in range(loops):
 | 
			
		||||
        x = mx.quantized_matmul(
 | 
			
		||||
            x,
 | 
			
		||||
            *wq,
 | 
			
		||||
            group_size=group_size,
 | 
			
		||||
            bits=bits,
 | 
			
		||||
            type=q_type,
 | 
			
		||||
        )
 | 
			
		||||
    return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def affine_qmv(x, wq):
 | 
			
		||||
    return qmv_(x, wq, "affine")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def affine_packed_qmv(x, wq):
 | 
			
		||||
    return qmv_(x, wq, "affine-packed")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def time_qmv():
 | 
			
		||||
    mx.random.seed(3)
 | 
			
		||||
    x = mx.random.normal(shape=(1, D)).astype(dtype)
 | 
			
		||||
    w = mx.random.normal(shape=(D, D)).astype(dtype)
 | 
			
		||||
    wq = mx.quantize(w, group_size=group_size, bits=bits, type="affine")
 | 
			
		||||
    mx.eval(x, wq)
 | 
			
		||||
    time_fn(affine_qmv, x, wq)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def time_packed_qmv():
 | 
			
		||||
    mx.random.seed(3)
 | 
			
		||||
    x = mx.random.normal(shape=(1, D)).astype(dtype)
 | 
			
		||||
    w = mx.random.normal(shape=(D, D)).astype(dtype)
 | 
			
		||||
    wq = mx.quantize(w, group_size=group_size, bits=bits, type="affine-packed")
 | 
			
		||||
    mx.eval(x, wq)
 | 
			
		||||
    time_fn(affine_packed_qmv, x, wq)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    time_qmv()
 | 
			
		||||
    time_packed_qmv()
 | 
			
		||||
		Reference in New Issue
	
	Block a user