mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 02:28:13 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			133 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			133 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright © 2023 Apple Inc.
 | 
						|
 | 
						|
import argparse
 | 
						|
 | 
						|
import mlx.core as mx
 | 
						|
from time_utils import time_fn
 | 
						|
 | 
						|
 | 
						|
def time_add():
    """Time ``mx.add`` across contiguous, transposed and broadcast operands.

    Four measurements are taken, in this order:

    1. two random arrays of shape (32, 1024, 1024);
    2. the first operand transposed over its last two axes;
    3. the second operand replaced by a (1024,) vector, which broadcasts
       along the trailing axis;
    4. that vector reshaped to (1, 1024, 1), which broadcasts along the
       middle axis instead.
    """
    full_shape = (32, 1024, 1024)
    lhs = mx.random.uniform(shape=full_shape)
    rhs = mx.random.uniform(shape=full_shape)
    mx.eval(lhs, rhs)
    time_fn(mx.add, lhs, rhs)

    # The wrappers below all just call mx.add. Presumably they exist so that
    # time_fn can label each measurement by the wrapper's function name.
    def transpose_add(x, y):
        return mx.add(x, y)

    lhs_t = mx.transpose(lhs, [0, 2, 1])
    mx.eval(lhs_t)
    time_fn(transpose_add, lhs_t, rhs)

    def slice_add(x, y):
        return mx.add(x, y)

    rhs = mx.random.uniform(shape=(1024,))
    mx.eval(rhs)
    time_fn(slice_add, lhs, rhs)

    def mid_slice_add(x, y):
        return mx.add(x, y)

    rhs = mx.reshape(rhs, (1, 1024, 1))
    mx.eval(rhs)
    time_fn(mid_slice_add, lhs, rhs)
 | 
						|
 | 
						|
 | 
						|
def time_matmul():
    """Time ``mx.matmul`` on two random 1024 x 1024 matrices."""
    side = 1024
    # The generator draws the left operand first, then the right one.
    lhs, rhs = (mx.random.uniform(shape=(side, side)) for _ in range(2))
    mx.eval(lhs, rhs)
    time_fn(mx.matmul, lhs, rhs)
 | 
						|
 | 
						|
 | 
						|
def time_maximum():
    """Time elementwise ``mx.maximum`` of two random (32, 1024, 1024) arrays."""
    dims = (32, 1024, 1024)
    # The generator draws the first operand, then the second.
    first, second = (mx.random.uniform(shape=dims) for _ in range(2))
    mx.eval(first, second)
    time_fn(mx.maximum, first, second)
 | 
						|
 | 
						|
 | 
						|
def time_max():
    """Time ``mx.max`` reducing axis 0 of a (32, 1024, 1024) random array.

    Before timing, ``data[1, 1]`` (a full row along the last axis) is set to
    NaN. Presumably this exercises the reduction's NaN handling; confirm the
    intent with the benchmark owner.
    """
    reduce_axis = 0
    data = mx.random.uniform(shape=(32, 1024, 1024))
    data[1, 1] = mx.nan
    mx.eval(data)
    time_fn(mx.max, data, reduce_axis)
 | 
						|
 | 
						|
 | 
						|
def time_min():
    """Time ``mx.min`` reducing axis 0 of a (32, 1024, 1024) random array.

    Before timing, ``data[1, 1]`` (a full row along the last axis) is set to
    NaN. Presumably this exercises the reduction's NaN handling; confirm the
    intent with the benchmark owner.
    """
    reduce_axis = 0
    data = mx.random.uniform(shape=(32, 1024, 1024))
    data[1, 1] = mx.nan
    mx.eval(data)
    time_fn(mx.min, data, reduce_axis)
 | 
						|
 | 
						|
 | 
						|
def time_negative():
    """Time unary negation of a random (10000, 1000) array."""
    a = mx.random.uniform(shape=(10000, 1000))
    # Materialize the input once, up front, so that generating `a` is not
    # part of the work measured by time_fn. (The original evaluated `a` a
    # second time just before timing, which was a redundant no-op.)
    mx.eval(a)

    # Named wrapper; presumably time_fn reports it by this function name.
    def negative(a):
        return -a

    time_fn(negative, a)
 | 
						|
 | 
						|
 | 
						|
def time_exp():
    """Time elementwise ``mx.exp`` on a random (1000, 100) array."""
    values = mx.random.uniform(shape=(1000, 100))
    mx.eval(values)
    time_fn(mx.exp, values)
 | 
						|
 | 
						|
 | 
						|
def time_logsumexp():
    """Time ``mx.logsumexp`` over the last axis of a (64, 10, 10000) array."""
    logits = mx.random.uniform(shape=(64, 10, 10000))
    mx.eval(logits)
    # `axis` is forwarded by time_fn as a keyword argument to mx.logsumexp.
    time_fn(mx.logsumexp, logits, axis=-1)
 | 
						|
 | 
						|
 | 
						|
def time_take():
    """Time 20 separate ``mx.take`` row gathers from a random table.

    The table has shape (10000, 500). Each of the 20 gathers selects 10
    random row indices along axis 0.
    """
    num_rows, num_cols = 10000, 500
    a = mx.random.uniform(shape=(num_rows, num_cols))
    # The index bound is tied to the table's row count so the two cannot
    # drift apart.
    ids = mx.random.randint(low=0, high=num_rows, shape=(20, 10))
    # Iterate the (20, 10) index matrix row by row into 20 flat index vectors.
    ids = [mx.reshape(idx, (-1,)) for idx in ids]
    # Evaluate the table together with the indices. Previously only `ids`
    # was evaluated, so the lazy generation of `a` happened inside time_fn.
    # That was inconsistent with the other benchmarks, which evaluate every
    # input up front.
    mx.eval(a, ids)

    def random_take():
        return [mx.take(a, idx, 0) for idx in ids]

    time_fn(random_take)
 | 
						|
 | 
						|
 | 
						|
def time_reshape_transposed():
    """Time flattening a (256, 256, 128) array after swapping axes 0 and 1.

    The timed closure builds the transpose and the 1-D reshape on every call,
    so both operations are part of each measurement.
    """
    src = mx.random.uniform(shape=(256, 256, 128))
    mx.eval(src)

    def reshape_transposed():
        swapped = mx.transpose(src, (1, 0, 2))
        return mx.reshape(swapped, (-1,))

    time_fn(reshape_transposed)
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # The text is passed as `description`. As the first positional argument,
    # it was being used as `prog`, which replaced the program name in the
    # usage line and left --help without a description.
    parser = argparse.ArgumentParser(description="MLX benchmarks.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    args = parser.parse_args()

    # Run every benchmark on the selected device (CPU unless --gpu is given).
    mx.set_default_device(mx.gpu if args.gpu else mx.cpu)

    # Benchmarks run sequentially, in this fixed order.
    for benchmark in (
        time_add,
        time_matmul,
        time_min,
        time_max,
        time_maximum,
        time_exp,
        time_negative,
        time_logsumexp,
        time_take,
        time_reshape_transposed,
    ):
        benchmark()
 |