mirror of https://github.com/ml-explore/mlx.git, synced 2025-11-01 00:28:11 +08:00

119 lines · 3.6 KiB · Python

# Copyright © 2023 Apple Inc.

import math
import time

import mlx.core as mx
import mlx.nn as nn
import mlx.utils

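# Multi-head self-attention with rotary position embeddings (RoPE) and an
# optional key/value cache, as used in Llama-style transformer blocks.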
class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = nn.RoPE(dims // num_heads, True)
        self.query_proj = nn.Linear(dims, dims, False)
        self.key_proj = nn.Linear(dims, dims, False)
        self.value_proj = nn.Linear(dims, dims, False)
        self.out_proj = nn.Linear(dims, dims, False)

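    # Takes separate query/key/value inputs plus an optional additive mask and
    # an optional (keys, values) cache; returns the attention output and the
    # updated (keys, values) pair for reuse on the next token.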
    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        # Split into heads: (B, L, D) -> (B, num_heads, L, D // num_heads)
        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
        keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
        values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))

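        # When a cache is present, RoPE positions are offset by the number of
        # cached timesteps and the new keys/values are appended along the
        # sequence axis (axis 2).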
| 
 | |
|         if cache is not None:
 | |
|             key_cache, value_cache = cache
 | |
|             queries = self.rope(queries, offset=key_cache.shape[2])
 | |
|             keys = self.rope(keys, offset=key_cache.shape[2])
 | |
|             keys = mx.concatenate([key_cache, keys], axis=2)
 | |
|             values = mx.concatenate([value_cache, values], axis=2)
 | |
|         else:
 | |
|             queries = self.rope(queries)
 | |
|             keys = self.rope(keys)
 | |
| 
 | |
        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
        scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))

        return self.out_proj(values_hat), (keys, values)


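# One full transformer block: pre-norm attention with a residual connection,
# followed by a pre-norm SwiGLU feed-forward with a residual connection.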
class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = nn.RMSNorm(dims)
        self.norm2 = nn.RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, False)
        self.linear2 = nn.Linear(dims, mlp_dims, False)
        self.linear3 = nn.Linear(mlp_dims, dims, False)

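    # x: (B, L, dims). Returns the transformed activations and the updated
    # attention cache.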
    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = a * mx.sigmoid(a) * b  # SwiGLU: silu(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


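# Benchmark helper: run the layer a few times to warm up, then report the
# average wall-clock time (in milliseconds) over 5 evaluated iterations.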
def measure(model, x, cache):
    # Warmup: build and evaluate the graph so one-time compilation and
    # allocation costs are not counted in the timed loop.
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        mx.eval(y, c)

    start = time.time()
    rs = []
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        rs.append((y, c))
    # MLX is lazy; mx.eval forces all 5 iterations to actually compute.
    mx.eval(rs)
    end = time.time()

    return (end - start) * 1000 / 5


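# Llama 7B-scale settings: 32 heads, model width 4096, MLP width 11008, and a
# 1000-token KV cache to mimic decoding deep into a sequence.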
if __name__ == "__main__":
    H = 32  # attention heads
    D = 4096  # model dimension
    F = 43 * 256  # MLP dimension (11008)
    C = 1000  # cached sequence length
    mx.set_default_device(mx.gpu)
    dtype = mx.float16

    layer = LlamaEncoderLayer(D, F, H)
    layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
    k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
    x = mx.random.normal([1, 1, D], dtype=dtype, key=k1)
    cache = [
        mx.random.normal([1, H, C, D // H], dtype=dtype, key=k2),
        mx.random.normal([1, H, C, D // H], dtype=dtype, key=k3),
    ]
    mx.eval(x, cache)

    T = measure(layer, x, cache)

    print("Time per layer per token:", T, "ms")
    # Llama 7B has 32 transformer layers, hence the factor of 32.
    print("Lower bound total time per token:", T * 32, "ms")
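
A possible extension (not part of the original file): the same measure helper
can also time prompt processing by passing a multi-token input with no cache.
A minimal sketch, assuming it is appended inside the __main__ block above; the
512-token prompt length is an arbitrary choice, not from the original script:

    # Benchmark prefill: process a whole prompt at once with no KV cache.
    # No attention mask is applied, matching the decode benchmark above.
    x_prompt = mx.random.normal([1, 512, D], dtype=dtype)
    mx.eval(x_prompt)
    T_prefill = measure(layer, x_prompt, None)  # cache=None: RoPE starts at 0
    print("Time per layer per 512-token prompt:", T_prefill, "ms")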