# Copyright © 2024 Apple Inc.

import time

import mlx.core as mx
import numpy as np
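
# Benchmark mx.einsum_path against np.einsum_path on a small and a large
# contraction, and compare an einsum-based attention against the same
# computation written with explicit transposes and matmuls.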


def timeit(fn, its=100, args=[]):
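    # Warm up with five untimed calls, then return the average
    # wall-clock time per call in milliseconds.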
    for _ in range(5):
        fn(*args)
    tic = time.perf_counter()
    for _ in range(its):
        fn(*args)
    toc = time.perf_counter()
    return 1e3 * (toc - tic) / its


def time_little_einsum_path():
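    # A single 32x32 matrix product: the contraction order is trivial,
    # so this mostly measures the fixed per-call overhead of einsum_path.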
    subscripts = "ik,kj->ij"
    x = mx.ones((32, 32))
    y = mx.ones((32, 32))
    mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
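
    # Repeat with NumPy arrays of the same shapes for a fair comparison.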
    x = np.array(x)
    y = np.array(y)
    np_time = timeit(np.einsum_path, args=(subscripts, x, y))
    print("Timing little einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")


def time_big_einsum_path():
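    # Build ten operands with random five-letter subscripts drawn from a
    # shared alphabet; each letter maps to a fixed dimension, so repeated
    # letters are consistent across operands.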
    chars = list("abcdefgh")
    char_to_dim = {c: v for v, c in enumerate(chars)}

    num_inputs = 10
    inputs = []
    subscripts = []
    for _ in range(num_inputs):
        subscript = np.random.choice(chars, size=5, replace=False).tolist()
        subscripts.append("".join(subscript))
        inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
    subscripts = ",".join(subscripts)

    np_time = timeit(np.einsum_path, args=(subscripts, *inputs))

    inputs = [mx.array(x) for x in inputs]
    mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
    print("Timing big einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")


def time_attention():
    def regular_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
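        # Transpose to [batch, heads, sequence, head_dim] (keys to
        # [batch, heads, head_dim, sequence]) so the matmuls contract
        # over head_dim.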
        scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
        scores = mx.softmax(scores, axis=-1)
        output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
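        # MLX is lazy: mx.eval forces the computation so the timing
        # covers the actual work, not just graph construction.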
        mx.eval(output)

    def einsum_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
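        # Index letters: i = batch, t/u = sequence, j = heads, k = head_dim.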
        scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
        scores = mx.softmax(scores, axis=-1)
        output = mx.einsum("ijtu,iujk->itjk", scores, values)
        mx.eval(output)
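
    # Inputs: batch 8, sequence length 512, 32 heads of head_dim 128.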
    x = mx.random.uniform(shape=(8, 512, 32, 128))

    regular_time = timeit(regular_attention, args=(x,))
    ein_time = timeit(einsum_attention, args=(x,))
    print("Timing einsum attention...")
    print(f"Regular ... {regular_time:.3f} ms")
    print(f"Einsum ... {ein_time:.3f} ms")


if __name__ == "__main__":
    time_little_einsum_path()
    time_big_einsum_path()
    time_attention()