mirror of https://github.com/ml-explore/mlx.git
awni's commit files
benchmarks/python/llama_torch_bench.py (new file, 197 lines)
@@ -0,0 +1,197 @@
import math
import time

import torch
import torch.nn as nn
import torch.mps


def sync_if_needed(x):
    # Wait for pending MPS work so host-side timing reflects device execution.
    if x.device != torch.device("cpu"):
        torch.mps.synchronize()


class RoPE(nn.Module):
    def __init__(self, dims: int, traditional: bool = False):
        super().__init__()
        self.dims = dims
        self.traditional = traditional

    def _compute_rope(self, costheta, sintheta, x):
        # Rotate the first `dims` features, split into two contiguous halves.
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
        else:
            rx = torch.cat([rx1, rx2], dim=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        # Traditional RoPE rotates interleaved (even, odd) feature pairs.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)

        return rx

    def forward(self, x, offset: int = 0):
        shape = x.shape
        x = x.view(-1, shape[-2], shape[-1])
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, device=x.device, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return rx.view(*shape)

    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        device="cpu",
        dtype=torch.float32,
    ):
        D = D // 2
        positions = torch.arange(offset, N, dtype=dtype, device=device)
        freqs = torch.exp(
            -torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
        )
        theta = positions.view(-1, 1) * freqs.view(1, -1)
        costheta = torch.cos(theta)
        sintheta = torch.sin(theta)

        return costheta, sintheta


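# Illustrative sketch (not part of the original benchmark): applying RoPE to a
# dummy [batch, heads, sequence, head_dim] tensor, once for a prefix and once
# for a single next token with a position offset. The names and shapes here are
# assumptions chosen only to mirror how LlamaAttention below calls the module.
def _example_rope_usage():
    rope = RoPE(dims=128, traditional=True)
    prefix = rope(torch.randn(1, 32, 16, 128))  # rotate a 16-token prefix
    step = rope(torch.randn(1, 32, 1, 128), offset=16)  # rotate the next token
    return prefix.shape, step.shape

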
class RMSNorm(nn.Module):
    def __init__(self, dims: int, epsilon: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones((dims,)))
        self.epsilon = epsilon

    def forward(self, x):
        # Normalize by the root-mean-square of the features, then scale by gamma.
        n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon)
        return self.gamma * x * n


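# Illustrative sketch (not part of the original benchmark): with gamma at its
# initial value of ones, RMSNorm output has approximately unit RMS per row.
def _example_rmsnorm_check():
    norm = RMSNorm(8)
    y = norm(torch.randn(4, 8))
    return y.square().mean(dim=-1)  # each entry should be close to 1.0

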
class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = RoPE(dims // num_heads, True)
        self.query_proj = nn.Linear(dims, dims, bias=False)
        self.key_proj = nn.Linear(dims, dims, bias=False)
        self.value_proj = nn.Linear(dims, dims, bias=False)
        self.out_proj = nn.Linear(dims, dims, bias=False)

    def forward(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        # Split the projections into heads: [batch, num heads, sequence, head dim]
        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
        keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
        values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)

        if cache is not None:
            # Offset positions by the cache length and append to the KV cache.
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = torch.cat([key_cache, keys], dim=2)
            values = torch.cat([value_cache, values], dim=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask
        scores = torch.softmax(scores, dim=-1)
        values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(values_hat), (keys, values)


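# Illustrative sketch (not part of the original benchmark): single-token
# decoding with a key/value cache, mirroring the shapes used in __main__ below.
# The sizes here are assumptions chosen only for the example.
def _example_attention_decode_step():
    attn = LlamaAttention(dims=256, num_heads=8)
    x = torch.randn(1, 1, 256)  # one new token
    cache = (
        torch.randn(1, 8, 10, 32),  # cached keys for 10 previous positions
        torch.randn(1, 8, 10, 32),  # cached values for 10 previous positions
    )
    y, (k, v) = attn(x, x, x, mask=None, cache=cache)
    return y.shape, k.shape  # (1, 1, 256) and (1, 8, 11, 32)

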
class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = RMSNorm(dims)
        self.norm2 = RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear3 = nn.Linear(mlp_dims, dims, bias=False)

    def forward(self, x, mask=None, cache=None):
        # Pre-norm attention block with a residual connection.
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        # Pre-norm SwiGLU MLP block with a residual connection.
        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = torch.nn.functional.silu(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


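# Illustrative sketch (not part of the original benchmark): the final report
# multiplies the per-layer time by 32 because a 7B-class Llama stacks 32 of
# these layers; a full decode step threads one token through every layer's
# cache in turn. The sizes below are assumptions chosen only for the example.
def _example_decode_through_stack(num_layers=4, dims=256, mlp_dims=512, num_heads=8):
    layers = [LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)]
    caches = [None] * num_layers
    x = torch.randn(1, 1, dims)
    for i, layer in enumerate(layers):
        x, caches[i] = layer(x, mask=None, cache=caches[i])
    return x.shape

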
@torch.no_grad()
def measure(model, x, cache):
    # Warmup runs to amortize one-time compilation and allocation costs.
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    sync_if_needed(x)

    # Timed runs; return the average time per forward pass in milliseconds.
    start = time.time()
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    sync_if_needed(x)
    end = time.time()
    return (end - start) * 1000 / 5


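# Illustrative sketch (not part of the original benchmark): the same measure()
# can be pointed at the CPU for a baseline, since sync_if_needed() only calls
# torch.mps.synchronize() for non-CPU tensors. The sizes here are assumptions.
def _example_cpu_baseline(dims=512, mlp_dims=1024, num_heads=8, context=64):
    layer = LlamaEncoderLayer(dims, mlp_dims, num_heads)
    x = torch.randn(1, 1, dims)
    cache = [
        torch.randn(1, num_heads, context, dims // num_heads),
        torch.randn(1, num_heads, context, dims // num_heads),
    ]
    return measure(layer, x, cache)  # milliseconds per forward pass

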
if __name__ == "__main__":
    # Dimensions roughly matching one 7B Llama layer: 32 heads, 4096 model
    # dims, 11008 MLP dims, and a 1000-token key/value cache.
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    device = torch.device("mps")
    dtype = torch.float16

    layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
    x = torch.randn(1, 1, D).to(device).to(dtype)
    cache = [
        torch.randn(1, H, C, D // H).to(device).to(dtype),
        torch.randn(1, H, C, D // H).to(device).to(dtype),
    ]

    T = measure(layer, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")
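# A minimal way to run this benchmark (assuming PyTorch with MPS support on an
# Apple-silicon Mac; the exact invocation is an assumption, not from the commit):
#
#   python benchmarks/python/llama_torch_bench.py
#
# It prints the average per-layer decode latency and a 32-layer lower bound.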