Mirror of https://github.com/ml-explore/mlx.git
	remove stale benchmarks (#527)
@@ -1,198 +0,0 @@
# Copyright © 2023 Apple Inc.

import math
import time

import jax
import jax.numpy as jnp
from flax import linen as nn


class RoPE(nn.Module):
    dims: int
    traditional: bool = False

    def _compute_rope(self, costheta, sintheta, x):
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
        else:
            rx = jnp.concatenate([rx1, rx2], axis=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)

        return rx

    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        dtype=jnp.float32,
    ):
        D = D // 2
        positions = jnp.arange(offset, N, dtype=dtype)
        freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
        theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
        costheta = jnp.cos(theta)
        sintheta = jnp.sin(theta)

        return costheta, sintheta

    @nn.compact
    def __call__(self, x, offset: int = 0):
        shape = x.shape
        x = x.reshape((-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return rx.reshape(shape)


class LlamaAttention(nn.Module):
    dims: int
    num_heads: int
    dtype: jnp.dtype

    def setup(self):
        num_heads = self.num_heads
        dims = self.dims

        self.rope = RoPE(dims // num_heads, True)
        self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
        self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
        self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
        self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
        keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
        values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = jnp.concatenate([key_cache, keys], axis=2)
            values = jnp.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
        if mask is not None:
            scores = scores + mask
        scores = jax.nn.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))

        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    dims: int
    mlp_dims: int
    num_heads: int
    dtype: jnp.dtype

    def setup(self):
        dims = self.dims
        mlp_dims = self.mlp_dims
        num_heads = self.num_heads

        self.attention = LlamaAttention(dims, num_heads, self.dtype)

        self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
        self.norm2 = nn.RMSNorm(param_dtype=self.dtype)

        self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
        self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
        self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = jax.nn.silu(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


def measure(model, x, cache):
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        jax.block_until_ready((y, c))

    start = time.time()
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        jax.block_until_ready((y, c))

    end = time.time()
    return (end - start) * 1000 / 5


if __name__ == "__main__":
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    dtype = jnp.float16

    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)

    x = jax.random.normal(k1, (1, 1, D), dtype)
    cache = [
        jax.random.normal(k2, [1, H, C, D // H], dtype),
        jax.random.normal(k3, [1, H, C, D // H], dtype),
    ]

    layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
    params = layer.init(k4, x, mask=None, cache=cache)["params"]

    @jax.jit
    def model_fn(x, mask, cache):
        return layer.apply({"params": params}, x, mask=mask, cache=cache)

    T = measure(model_fn, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")
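Note (not part of the deleted file): all three benchmarks in this commit share the same constants. A small sketch of the arithmetic behind them, assuming the intent is a single-token Llama-7B decode step (the scripts do not state this explicitly):

# Hypothetical annotation of the constants used above.
H, D, F, C = 32, 4096, 43 * 256, 1000

print(D // H)  # 128: per-head dimension fed to RoPE
print(F)       # 11008: gated-MLP width
print(C)       # 1000 tokens already sitting in the KV cache
# The last print scales one layer by 32 (a full 32-layer model), which is why
# it is only a lower bound on the per-token time.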
@@ -1,118 +0,0 @@
# Copyright © 2023 Apple Inc.

import math
import time

import mlx.core as mx
import mlx.nn as nn
import mlx.utils


class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = nn.RoPE(dims // num_heads, True)
        self.query_proj = nn.Linear(dims, dims, False)
        self.key_proj = nn.Linear(dims, dims, False)
        self.value_proj = nn.Linear(dims, dims, False)
        self.out_proj = nn.Linear(dims, dims, False)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
        keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
        values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
        scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))

        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = nn.RMSNorm(dims)
        self.norm2 = nn.RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, False)
        self.linear2 = nn.Linear(dims, mlp_dims, False)
        self.linear3 = nn.Linear(mlp_dims, dims, False)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = a * mx.sigmoid(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


def measure(model, x, cache):
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        mx.eval(y, c)

    start = time.time()
    rs = []
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        rs.append((y, c))
    mx.eval(rs)
    end = time.time()

    return (end - start) * 1000 / 5


if __name__ == "__main__":
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    mx.set_default_device(mx.gpu)
    dtype = mx.float16

    layer = LlamaEncoderLayer(D, F, H)
    layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
    k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
    x = mx.random.normal([1, 1, D], dtype=dtype)
    cache = [
        mx.random.normal([1, H, C, D // H], dtype=dtype),
        mx.random.normal([1, H, C, D // H], dtype=dtype),
    ]
    mx.eval(x, cache)

    T = measure(layer, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")
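Note (not part of the deleted file): the MLX layer writes the gated MLP as a * mx.sigmoid(a) * b, while the JAX and PyTorch versions call silu(a) * b. The two spellings are the same function; a minimal check, assuming mlx.nn.silu as in current MLX:

import mlx.core as mx
import mlx.nn as nn

a = mx.random.normal((8,))
b = mx.random.normal((8,))

# SiLU(a) = a * sigmoid(a), so both spellings of the SwiGLU gate agree.
print(mx.allclose(a * mx.sigmoid(a) * b, nn.silu(a) * b))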
@@ -1,199 +0,0 @@
# Copyright © 2023 Apple Inc.

import math
import time

import torch
import torch.mps
import torch.nn as nn


def sync_if_needed(x):
    if x.device != torch.device("cpu"):
        torch.mps.synchronize()


class RoPE(nn.Module):
    def __init__(self, dims: int, traditional: bool = False):
        super().__init__()
        self.dims = dims
        self.traditional = traditional

    def _compute_rope(self, costheta, sintheta, x):
        x1 = x[..., : self.dims // 2]
        x2 = x[..., self.dims // 2 : self.dims]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
        else:
            rx = torch.cat([rx1, rx2], dim=-1)

        return rx

    def _compute_traditional_rope(self, costheta, sintheta, x):
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        rx1 = x1 * costheta - x2 * sintheta
        rx2 = x1 * sintheta + x2 * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)

        return rx

    def forward(self, x, offset: int = 0):
        shape = x.shape
        x = x.view(-1, shape[-2], shape[-1])
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, device=x.device, dtype=x.dtype
        )

        rope = (
            self._compute_traditional_rope if self.traditional else self._compute_rope
        )
        rx = rope(costheta, sintheta, x)

        return rx.view(*shape)

    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        device="cpu",
        dtype=torch.float32,
    ):
        D = D // 2
        positions = torch.arange(offset, N, dtype=dtype, device=device)
        freqs = torch.exp(
            -torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
        )
        theta = positions.view(-1, 1) * freqs.view(1, -1)
        costheta = torch.cos(theta)
        sintheta = torch.sin(theta)

        return costheta, sintheta


class RMSNorm(nn.Module):
    def __init__(self, dims: int, epsilon: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones((dims,)))
        self.epsilon = epsilon

    def forward(self, x):
        n = torch.rsqrt(x.square().mean(dim=-1, keepdim=True) + self.epsilon)
        return self.gamma * x * n


class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = RoPE(dims // num_heads, True)
        self.query_proj = nn.Linear(dims, dims, bias=False)
        self.key_proj = nn.Linear(dims, dims, bias=False)
        self.value_proj = nn.Linear(dims, dims, bias=False)
        self.out_proj = nn.Linear(dims, dims, bias=False)

    def forward(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
        keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
        values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = torch.cat([key_cache, keys], dim=2)
            values = torch.cat([value_cache, values], dim=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask
        scores = torch.softmax(scores, dim=-1)
        values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = RMSNorm(dims)
        self.norm2 = RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
        self.linear3 = nn.Linear(mlp_dims, dims, bias=False)

    def forward(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = torch.nn.functional.silu(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


@torch.no_grad()
def measure(model, x, cache):
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    sync_if_needed(x)

    start = time.time()
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
    sync_if_needed(x)
    end = time.time()
    return (end - start) * 1000 / 5


if __name__ == "__main__":
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    device = torch.device("mps")
    dtype = torch.float16

    layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
    x = torch.randn(1, 1, D).to(device).to(dtype)
    cache = [
        torch.randn(1, H, C, D // H).to(device).to(dtype),
        torch.randn(1, H, C, D // H).to(device).to(dtype),
    ]

    T = measure(layer, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")
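Note (not part of the deleted file): the PyTorch script hand-rolls RMSNorm, which predates torch.nn.RMSNorm. A quick sketch comparing the two, assuming PyTorch >= 2.4 (which ships the built-in) and reusing the RMSNorm class defined in this file:

import torch

x = torch.randn(2, 16)

ours = RMSNorm(16)                    # class defined above; gamma starts at ones
ref = torch.nn.RMSNorm(16, eps=1e-6)  # built-in; weight also defaults to ones

print(torch.allclose(ours(x), ref(x), atol=1e-5))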
Awni Hannun