mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 02:28:13 +08:00 
			
		
		
		
	* einsum initial * fix comma break * sum axis was wrong * small cleanups * python binding * changed bindings to resemble numpy * remove todo comment * comment changes * add count of operands/inputs * fail fast if operands list is empty * ignore comma if no output * einsum path matching numpy * getting somewhere with path * remove print * it passes the first test * moved einsum tests to seperate file * seperated einsum path * moved einsum naive * remove space from equation * fast fail if no operands passed * update tests and remove printf * small cleanup * some more cleanups * removed python helper file * ack * utilize std for finding min in vector * duplicate def * remove the tuple as it was unreadable * moved einsum_naive back to ops * remaining isn't needed * avoid creating another set * cleanup * greedy path, start of naive einsum * more einsum * fix some bugs * some more fixes, tests pass * benchmark * some simplify * fix einsum and test Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com> * add a bunch more tests and fix a bunch more bugs * some docs nits --------- Co-authored-by: dc-dc-dc <dgcruz983@gmail.com> Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
		
			
				
	
	
		
			85 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			85 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright © 2024 Apple Inc.
 | 
						|
 | 
						|
import time
 | 
						|
 | 
						|
import mlx.core as mx
 | 
						|
import numpy as np
 | 
						|
 | 
						|
 | 
						|
def timeit(fn, its=100, args=()):
    """Return the average wall-clock time of ``fn(*args)`` in milliseconds.

    ``fn`` is first called five times without being timed, then called
    ``its`` times between two ``time.perf_counter()`` readings.

    Args:
        fn: Callable to benchmark.
        its: Number of timed calls. The elapsed time is divided by this
            value, so it must be non-zero.
        args: Positional arguments unpacked into every call. Defaults to an
            empty tuple so that no mutable object is shared between calls.

    Returns:
        The elapsed time of the timed loop, converted from seconds to
        milliseconds and divided by ``its``.
    """
    # Untimed warm-up calls.
    for _ in range(5):
        fn(*args)

    tic = time.perf_counter()
    for _ in range(its):
        fn(*args)
    toc = time.perf_counter()

    # perf_counter() reports seconds; 1e3 converts to milliseconds.
    return 1e3 * (toc - tic) / its
 | 
						|
 | 
						|
 | 
						|
def time_little_einsum_path():
    """Time ``einsum_path`` in MLX and NumPy for a two-operand contraction.

    Uses the equation ``"ik,kj->ij"`` with two 32x32 arrays of ones and
    prints the average time per call reported by ``timeit`` for each library.
    """
    equation = "ik,kj->ij"
    lhs = mx.ones((32, 32))
    rhs = mx.ones((32, 32))
    mx_time = timeit(mx.einsum_path, args=(equation, lhs, rhs))

    # Convert the same operands to NumPy arrays for the NumPy measurement.
    lhs_np = np.array(lhs)
    rhs_np = np.array(rhs)
    np_time = timeit(np.einsum_path, args=(equation, lhs_np, rhs_np))

    print("Timing little einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")
 | 
						|
 | 
						|
 | 
						|
def time_big_einsum_path():
    """Time ``einsum_path`` in NumPy and MLX for a ten-operand expression.

    Each operand gets five distinct letters drawn at random from
    ``"abcdefgh"``; the extent of every axis is the index of its letter in
    that string. NumPy is timed first, then the same operands are converted
    to MLX arrays and timed, and both averages are printed.
    """
    letters = list("abcdefgh")
    # NOTE(review): 'a' maps to extent 0, so any operand that draws 'a' is
    # empty -- confirm this is intended for the benchmark.
    extent_of = {letter: index for index, letter in enumerate(letters)}

    operand_count = 10
    # One random draw per operand, in order.
    terms = [
        np.random.choice(letters, size=5, replace=False).tolist()
        for _ in range(operand_count)
    ]
    np_operands = [
        np.ones([extent_of[letter] for letter in term]) for term in terms
    ]
    equation = ",".join("".join(term) for term in terms)

    np_time = timeit(np.einsum_path, args=(equation, *np_operands))

    mx_operands = [mx.array(operand) for operand in np_operands]
    mx_time = timeit(mx.einsum_path, args=(equation, *mx_operands))

    print("Timing big einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")
 | 
						|
 | 
						|
 | 
						|
def time_attention():
    """Time attention written with transposes and matmuls against einsum.

    Both variants receive one array that serves as queries, keys and values,
    and each forces evaluation of its result with ``mx.eval``. The average
    time per call from ``timeit`` is printed for each variant.
    """

    def regular_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        q = k = v = x
        q_heads = q.transpose(0, 2, 1, 3)
        k_heads_t = k.transpose(0, 2, 3, 1)
        weights = mx.softmax(q_heads @ k_heads_t, axis=-1)
        v_heads = v.transpose(0, 2, 1, 3)
        attended = (weights @ v_heads).swapaxes(1, 2)
        mx.eval(attended)

    def einsum_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        q = k = v = x
        weights = mx.softmax(mx.einsum("itjk,iujk->ijtu", q, k), axis=-1)
        attended = mx.einsum("ijtu,iujk->itjk", weights, v)
        mx.eval(attended)

    data = mx.random.uniform(shape=(8, 512, 32, 128))

    regular_time = timeit(regular_attention, args=(data,))
    ein_time = timeit(einsum_attention, args=(data,))

    print("Timing einsum attention...")
    print(f"Regular ... {regular_time:.3f} ms")
    print(f"Einsum ... {ein_time:.3f} ms")
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Run the benchmarks one after another when executed as a script.
    for benchmark in (
        time_little_einsum_path,
        time_big_einsum_path,
        time_attention,
    ):
        benchmark()
 |