# Copyright © 2023 Apple Inc.

import math
import time

import mlx.core as mx
import mlx.nn as nn
import mlx.utils


class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = nn.RoPE(dims // num_heads, True)
        self.query_proj = nn.Linear(dims, dims, False)
        self.key_proj = nn.Linear(dims, dims, False)
        self.value_proj = nn.Linear(dims, dims, False)
        self.out_proj = nn.Linear(dims, dims, False)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)
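
        # Split the projections into heads:
        # [batch, seq, dims] -> [batch, num heads, seq, dims // num heads]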
        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
        keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
        values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))

        if cache is not None:
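            # Apply RoPE with an offset equal to the cached sequence length,
            # then append the new keys and values to the cache along axis 2.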
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
        scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))
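
        # Return the output projection along with the updated (keys, values) cache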
        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = nn.RMSNorm(dims)
        self.norm2 = nn.RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, False)
        self.linear2 = nn.Linear(dims, mlp_dims, False)
        self.linear3 = nn.Linear(mlp_dims, dims, False)

    def __call__(self, x, mask=None, cache=None):
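        # Attention block: RMSNorm, self-attention, residual connection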
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y
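
        # MLP block: RMSNorm, SwiGLU-style gated feed-forward, residual connection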
        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = a * mx.sigmoid(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


def measure(model, x, cache):
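    # Warm up: run and evaluate a few forward passes so one-time setup costs
    # are not included in the timed runs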
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        mx.eval(y, c)
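
    # Timed runs: build 5 forward passes lazily and evaluate them together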
    start = time.time()
    rs = []
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        rs.append((y, c))
    mx.eval(rs)
    end = time.time()
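
    # Average milliseconds per forward pass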
    return (end - start) * 1000 / 5


if __name__ == "__main__":
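    # LLaMA 7B-like layer: 32 heads, 4096 model dims, 43 * 256 = 11008 MLP dims,
    # and a key/value cache of 1000 past tokens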
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    mx.set_default_device(mx.gpu)
    dtype = mx.float16
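
    # Build the layer and cast its parameters to float16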
    layer = LlamaEncoderLayer(D, F, H)
    layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
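    # Random single-token input and a pre-filled key/value cache as benchmark inputs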
    k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
    x = mx.random.normal([1, 1, D], dtype=dtype, key=k1)
    cache = [
        mx.random.normal([1, H, C, D // H], dtype=dtype, key=k2),
        mx.random.normal([1, H, C, D // H], dtype=dtype, key=k3),
    ]
    mx.eval(x, cache)

    T = measure(layer, x, cache)

    print("Time per layer per token:", T, "ms")
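    # 32 layers, as in the 7B model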
    print("Lower bound total time per token:", T * 32, "ms")