From 86e0c79467177eb0251118e268a4a73ede6d7e7d Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 22 Jan 2024 22:17:58 -0800
Subject: [PATCH] remove stale benchmarks (#527)

---
 benchmarks/python/llama_jax_bench.py   | 198 ------------------------
 benchmarks/python/llama_mlx_bench.py   | 118 ---------------
 benchmarks/python/llama_torch_bench.py | 199 -------------------------
 3 files changed, 515 deletions(-)
 delete mode 100644 benchmarks/python/llama_jax_bench.py
 delete mode 100644 benchmarks/python/llama_mlx_bench.py
 delete mode 100644 benchmarks/python/llama_torch_bench.py

diff --git a/benchmarks/python/llama_jax_bench.py b/benchmarks/python/llama_jax_bench.py
deleted file mode 100644
index 9f647426ed..0000000000
--- a/benchmarks/python/llama_jax_bench.py
+++ /dev/null
@@ -1,198 +0,0 @@
-# Copyright © 2023 Apple Inc.
-
-import math
-import time
-
-import jax
-import jax.numpy as jnp
-from flax import linen as nn
-
-
-class RoPE(nn.Module):
-    dims: int
-    traditional: bool = False
-
-    def _compute_rope(self, costheta, sintheta, x):
-        x1 = x[..., : self.dims // 2]
-        x2 = x[..., self.dims // 2 : self.dims]
-        rx1 = x1 * costheta - x2 * sintheta
-        rx2 = x1 * sintheta + x2 * costheta
-
-        if self.dims < x.shape[-1]:
-            rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
-        else:
-            rx = jnp.concatenate([rx1, rx2], axis=-1)
-
-        return rx
-
-    def _compute_traditional_rope(self, costheta, sintheta, x):
-        x1 = x[..., ::2]
-        x2 = x[..., 1::2]
-        rx1 = x1 * costheta - x2 * sintheta
-        rx2 = x1 * sintheta + x2 * costheta
-
-        if self.dims < x.shape[-1]:
-            raise NotImplementedError(
-                "RoPE doesn't implement partial traditional application"
-            )
-
-        rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
-
-        return rx
-
-    @staticmethod
-    def create_cos_sin_theta(
-        N: int,
-        D: int,
-        offset: int = 0,
-        base: float = 10000,
-        dtype=jnp.float32,
-    ):
-        D = D // 2
-        positions = jnp.arange(offset, N, dtype=dtype)
-        freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
-        theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
-        costheta = jnp.cos(theta)
-        sintheta = jnp.sin(theta)
-
-        return costheta, sintheta
-
-    @nn.compact
-    def __call__(self, x, offset: int = 0):
-        shape = x.shape
-        x = x.reshape((-1, shape[-2], shape[-1]))
-        N = x.shape[1] + offset
-        costheta, sintheta = RoPE.create_cos_sin_theta(
-            N, self.dims, offset=offset, dtype=x.dtype
-        )
-
-        rope = (
-            self._compute_traditional_rope if self.traditional else self._compute_rope
-        )
-        rx = rope(costheta, sintheta, x)
-
-        return rx.reshape(shape)
-
-
-class LlamaAttention(nn.Module):
-    dims: int
-    num_heads: int
-    dtype: jnp.dtype
-
-    def setup(self):
-        num_heads = self.num_heads
-        dims = self.dims
-
-        self.rope = RoPE(dims // num_heads, True)
-        self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
-        self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
-        self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
-        self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
-
-    def __call__(self, queries, keys, values, mask=None, cache=None):
-        queries = self.query_proj(queries)
-        keys = self.key_proj(keys)
-        values = self.value_proj(values)
-
-        num_heads = self.num_heads
-        B, L, D = queries.shape
-        queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
-        keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
-        values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
-
-        if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = jnp.concatenate([key_cache, keys], axis=2)
-            values = jnp.concatenate([value_cache, values], axis=2)
-        else:
-            queries = self.rope(queries)
-            keys = self.rope(keys)
-
-        # Dimensions are [batch x num heads x sequence x hidden dim]
-        scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
-        if mask is not None:
-            scores = scores + mask
-        scores = jax.nn.softmax(scores, axis=-1)
-        values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))
-
-        return self.out_proj(values_hat), (keys, values)
-
-
-class LlamaEncoderLayer(nn.Module):
-    dims: int
-    mlp_dims: int
-    num_heads: int
-    dtype: jnp.dtype
-
-    def setup(self):
-        dims = self.dims
-        mlp_dims = self.mlp_dims
-        num_heads = self.num_heads
-
-        self.attention = LlamaAttention(dims, num_heads, dtype)
-
-        self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
-        self.norm2 = nn.RMSNorm(param_dtype=self.dtype)
-
-        self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
-        self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
-        self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
-
-    def __call__(self, x, mask=None, cache=None):
-        y = self.norm1(x)
-        y, cache = self.attention(y, y, y, mask, cache)
-        x = x + y
-
-        y = self.norm2(x)
-        a = self.linear1(y)
-        b = self.linear2(y)
-        y = jax.nn.silu(a) * b
-        y = self.linear3(y)
-        x = x + y
-
-        return x, cache
-
-
-def measure(model, x, cache):
-    for i in range(5):
-        y, c = model(x, mask=None, cache=cache)
-        jax.block_until_ready((y, c))
-
-    start = time.time()
-    for i in range(5):
-        y, c = model(x, mask=None, cache=cache)
-        jax.block_until_ready((y, c))
-
-    end = time.time()
-    return (end - start) * 1000 / 5
-
-
-if __name__ == "__main__":
-    H = 32
-    D = 4096
-    F = 43 * 256
-    C = 1000
-    dtype = jnp.float16
-
-    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
-
-    x = jax.random.normal(k1, (1, 1, D), dtype)
-    cache = [
-        jax.random.normal(k2, [1, H, C, D // H], dtype),
-        jax.random.normal(k3, [1, H, C, D // H], dtype),
-    ]
-
-    layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
-    params = layer.init(k4, x, mask=None, cache=cache)["params"]
-
-    @jax.jit
-    def model_fn(x, mask, cache):
-        return layer.apply({"params": params}, x, mask=mask, cache=cache)
-
-    T = measure(model_fn, x, cache)
-
-    print("Time per layer per token:", T, "ms")
-    print("Lower bound total time per token:", T * 32, "ms")
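The removed JAX benchmark above times a single decode step (sequence length 1) against a 1000-entry key/value cache: five untimed warmup iterations, then five timed iterations, with jax.block_until_ready called before the clock is read so that asynchronous dispatch is not mistaken for finished compute. Below is a minimal, self-contained sketch of that warmup-then-block timing pattern; the jitted matmul, shapes, and names are illustrative stand-ins for the full encoder layer, not part of the removed file.

import time

import jax
import jax.numpy as jnp


@jax.jit
def step(a, b):
    # Illustrative stand-in for one encoder-layer call.
    return a @ b


key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (4096, 4096), jnp.float16)

# Warmup: triggers compilation and device placement before timing.
jax.block_until_ready(step(a, a))

start = time.time()
for _ in range(5):
    out = step(a, a)
jax.block_until_ready(out)  # wait for asynchronous dispatch to finish
print("ms per iteration:", (time.time() - start) * 1000 / 5)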
diff --git a/benchmarks/python/llama_mlx_bench.py b/benchmarks/python/llama_mlx_bench.py
deleted file mode 100644
index b26ec24bd6..0000000000
--- a/benchmarks/python/llama_mlx_bench.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Copyright © 2023 Apple Inc.
-
-import math
-import time
-
-import mlx.core as mx
-import mlx.nn as nn
-import mlx.utils
-
-
-class LlamaAttention(nn.Module):
-    def __init__(self, dims: int, num_heads: int):
-        super().__init__()
-        self.num_heads = num_heads
-        self.rope = nn.RoPE(dims // num_heads, True)
-        self.query_proj = nn.Linear(dims, dims, False)
-        self.key_proj = nn.Linear(dims, dims, False)
-        self.value_proj = nn.Linear(dims, dims, False)
-        self.out_proj = nn.Linear(dims, dims, False)
-
-    def __call__(self, queries, keys, values, mask=None, cache=None):
-        queries = self.query_proj(queries)
-        keys = self.key_proj(keys)
-        values = self.value_proj(values)
-
-        num_heads = self.num_heads
-        B, L, D = queries.shape
-        queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
-        keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
-        values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))
-
-        if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
-        else:
-            queries = self.rope(queries)
-            keys = self.rope(keys)
-
-        # Dimensions are [batch x num heads x sequence x hidden dim]
-        scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
-        scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
-        if mask is not None:
-            scores = scores + mask
-        scores = mx.softmax(scores, axis=-1)
-        values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))
-
-        return self.out_proj(values_hat), (keys, values)
-
-
-class LlamaEncoderLayer(nn.Module):
-    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
-        super().__init__()
-
-        self.attention = LlamaAttention(dims, num_heads)
-
-        self.norm1 = nn.RMSNorm(dims)
-        self.norm2 = nn.RMSNorm(dims)
-
-        self.linear1 = nn.Linear(dims, mlp_dims, False)
-        self.linear2 = nn.Linear(dims, mlp_dims, False)
-        self.linear3 = nn.Linear(mlp_dims, dims, False)
-
-    def __call__(self, x, mask=None, cache=None):
-        y = self.norm1(x)
-        y, cache = self.attention(y, y, y, mask, cache)
-        x = x + y
-
-        y = self.norm2(x)
-        a = self.linear1(y)
-        b = self.linear2(y)
-        y = a * mx.sigmoid(a) * b
-        y = self.linear3(y)
-        x = x + y
-
-        return x, cache
-
-
-def measure(model, x, cache):
-    for i in range(5):
-        y, c = model(x, mask=None, cache=cache)
-        mx.eval(y, c)
-
-    start = time.time()
-    rs = []
-    for i in range(5):
-        y, c = model(x, mask=None, cache=cache)
-        rs.append((y, c))
-    mx.eval(rs)
-    end = time.time()
-
-    return (end - start) * 1000 / 5
-
-
-if __name__ == "__main__":
-    H = 32
-    D = 4096
-    F = 43 * 256
-    C = 1000
-    mx.set_default_device(mx.gpu)
-    dtype = mx.float16
-
-    layer = LlamaEncoderLayer(D, F, H)
-    layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
-    k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
-    x = mx.random.normal([1, 1, D], dtype=dtype)
-    cache = [
-        mx.random.normal([1, H, C, D // H], dtype=dtype),
-        mx.random.normal([1, H, C, D // H], dtype=dtype),
-    ]
-    mx.eval(x, cache)
-
-    T = measure(layer, x, cache)
-
-    print("Time per layer per token:", T, "ms")
-    print("Lower bound total time per token:", T * 32, "ms")
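The removed MLX benchmark mirrors the same layer, with one difference worth noting: MLX arrays are lazy, so its timed loop collects the per-iteration outputs in a list and forces them with a single mx.eval call rather than evaluating each iteration separately. Below is a minimal sketch of that lazy-evaluation timing pattern; the matmul and shape are illustrative stand-ins, not part of the removed file.

import time

import mlx.core as mx

mx.set_default_device(mx.gpu)

a = mx.random.normal([4096, 4096], dtype=mx.float16)
mx.eval(a)  # materialize the input before timing

# Warmup: force one evaluation so compilation/allocation is excluded.
mx.eval(a @ a)

start = time.time()
outs = []
for _ in range(5):
    outs.append(a @ a)  # only builds lazy computation graphs
mx.eval(outs)  # evaluate the whole batch; this is where the work happens
print("ms per iteration:", (time.time() - start) * 1000 / 5)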
diff --git a/benchmarks/python/llama_torch_bench.py b/benchmarks/python/llama_torch_bench.py
deleted file mode 100644
index 5ee4ddd9cd..0000000000
--- a/benchmarks/python/llama_torch_bench.py
+++ /dev/null
@@ -1,199 +0,0 @@
-# Copyright © 2023 Apple Inc.
-
-import math
-import time
-
-import torch
-import torch.mps
-import torch.nn as nn
-
-
-def sync_if_needed(x):
-    if x.device != torch.device("cpu"):
-        torch.mps.synchronize()
-
-
-class RoPE(nn.Module):
-    def __init__(self, dims: int, traditional: bool = False):
-        super().__init__()
-        self.dims = dims
-        self.traditional = traditional
-
-    def _compute_rope(self, costheta, sintheta, x):
-        x1 = x[..., : self.dims // 2]
-        x2 = x[..., self.dims // 2 : self.dims]
-        rx1 = x1 * costheta - x2 * sintheta
-        rx2 = x1 * sintheta + x2 * costheta
-
-        if self.dims < x.shape[-1]:
-            rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
-        else:
-            rx = torch.cat([rx1, rx2], dim=-1)
-
-        return rx
-
-    def _compute_traditional_rope(self, costheta, sintheta, x):
-        x1 = x[..., ::2]
-        x2 = x[..., 1::2]
-        rx1 = x1 * costheta - x2 * sintheta
-        rx2 = x1 * sintheta + x2 * costheta
-
-        if self.dims < x.shape[-1]:
-            raise NotImplementedError(
-                "RoPE doesn't implement partial traditional application"
-            )
-
-        rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)
-
-        return rx
-
-    def forward(self, x, offset: int = 0):
-        shape = x.shape
-        x = x.view(-1, shape[-2], shape[-1])
-        N = x.shape[1] + offset
-        costheta, sintheta = RoPE.create_cos_sin_theta(
-            N, self.dims, offset=offset, device=x.device, dtype=x.dtype
-        )
-
-        rope = (
-            self._compute_traditional_rope if self.traditional else self._compute_rope
-        )
-        rx = rope(costheta, sintheta, x)
-
-        return rx.view(*shape)
-
-    @staticmethod
-    def create_cos_sin_theta(
-        N: int,
-        D: int,
-        offset: int = 0,
-        base: float = 10000,
-        device="cpu",
-        dtype=torch.float32,
-    ):
-        D = D // 2
-        positions = torch.arange(offset, N, dtype=dtype, device=device)
-        freqs = torch.exp(
-            -torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
-        )
-        theta = positions.view(-1, 1) * freqs.view(1, -1)
-        costheta = torch.cos(theta)
-        sintheta = torch.sin(theta)
-
-        return costheta, sintheta
-
-
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, epsilon: float = 1e-6):
-        super().__init__()
-        self.gamma = nn.Parameter(torch.ones((dims,)))
-        self.epsilon = epsilon
-
-    def forward(self, x):
-        n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon)
-        return self.gamma * x * n
-
-
-class LlamaAttention(nn.Module):
-    def __init__(self, dims: int, num_heads: int):
-        super().__init__()
-        self.num_heads = num_heads
-        self.rope = RoPE(dims // num_heads, True)
-        self.query_proj = nn.Linear(dims, dims, bias=False)
-        self.key_proj = nn.Linear(dims, dims, bias=False)
-        self.value_proj = nn.Linear(dims, dims, bias=False)
-        self.out_proj = nn.Linear(dims, dims, bias=False)
-
-    def forward(self, queries, keys, values, mask=None, cache=None):
-        queries = self.query_proj(queries)
-        keys = self.key_proj(keys)
-        values = self.value_proj(values)
-
-        num_heads = self.num_heads
-        B, L, D = queries.shape
-        queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
-        keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
-        values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
-
-        if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = torch.cat([key_cache, keys], dim=2)
-            values = torch.cat([value_cache, values], dim=2)
-        else:
-            queries = self.rope(queries)
-            keys = self.rope(keys)
-
-        # Dimensions are [batch x num heads x sequence x hidden dim]
-        scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
-        if mask is not None:
-            scores = scores + mask
-        scores = torch.softmax(scores, dim=-1)
-        values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.out_proj(values_hat), (keys, values)
-
-
-class LlamaEncoderLayer(nn.Module):
-    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
-        super().__init__()
-
-        self.attention = LlamaAttention(dims, num_heads)
-
-        self.norm1 = RMSNorm(dims)
-        self.norm2 = RMSNorm(dims)
-
-        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
-        self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
-        self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
-
-    def forward(self, x, mask=None, cache=None):
-        y = self.norm1(x)
-        y, cache = self.attention(y, y, y, mask, cache)
-        x = x + y
-
-        y = self.norm2(x)
-        a = self.linear1(y)
-        b = self.linear2(y)
-        y = torch.nn.functional.silu(a) * b
-        y = self.linear3(y)
-        x = x + y
-
-        return x, cache
-
-
-@torch.no_grad()
-def measure(model, x, cache):
-    for i in range(5):
-        y, c = model(x, mask=None, cache=cache)
-        sync_if_needed(x)
-
-    start = time.time()
-    for i in range(5):
-        y, c = model(x, mask=None, cache=cache)
-        sync_if_needed(x)
-    end = time.time()
-    return (end - start) * 1000 / 5
-
-
-if __name__ == "__main__":
-    H = 32
-    D = 4096
-    F = 43 * 256
-    C = 1000
-    device = torch.device("mps")
-    dtype = torch.float16
-
-    layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
-    x = torch.randn(1, 1, D).to(device).to(dtype)
-    cache = [
-        torch.randn(1, H, C, D // H).to(device).to(dtype),
-        torch.randn(1, H, C, D // H).to(device).to(dtype),
-    ]
-
-    T = measure(layer, x, cache)
-
-    print("Time per layer per token:", T, "ms")
-    print("Lower bound total time per token:", T * 32, "ms")
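The removed PyTorch benchmark runs the same layer on the MPS backend under torch.no_grad(). MPS kernels are dispatched asynchronously, which is why its sync_if_needed helper calls torch.mps.synchronize() before the timer is read. Below is a minimal sketch of that timing pattern, assuming an Apple-silicon machine with the MPS backend available; the matmul and shape are illustrative stand-ins for the encoder layer, not part of the removed file.

import time

import torch
import torch.mps

device = torch.device("mps")
a = torch.randn(4096, 4096, device=device, dtype=torch.float16)

with torch.no_grad():
    for _ in range(5):  # warmup
        a @ a
    torch.mps.synchronize()

    start = time.time()
    for _ in range(5):
        out = a @ a
    torch.mps.synchronize()  # drain the asynchronous MPS queue before stopping the clock

print("ms per iteration:", (time.time() - start) * 1000 / 5)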