mlx-examples/llms/phi2/phi2.py

234 lines
6.9 KiB
Python
Raw Normal View History

2023-12-15 00:08:28 +08:00
import argparse
import math
from dataclasses import dataclass
from pathlib import Path
2023-12-14 11:22:56 +08:00
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
2023-12-14 11:22:56 +08:00
2023-12-15 00:08:28 +08:00
2023-12-14 11:22:56 +08:00
@dataclass
class ModelArgs:
    """Hyperparameters of the Phi-2 network assembled in this file.

    Every field carries a default, so ``ModelArgs()`` can be constructed
    without arguments.
    """

    # Declared context length; no code in the visible part of this file reads it.
    max_sequence_length: int = 2048
    # Size of the token embedding table and of the output logits.
    num_vocab: int = 51200
    # Hidden width shared by the embedding, attention and MLP layers.
    model_dim: int = 2560
    # Number of attention heads the hidden width is split across.
    num_heads: int = 32
    # Number of decoder blocks stacked in the transformer.
    num_layers: int = 32
    # Dimension count handed to ``nn.RoPE`` for the rotary embedding.
    rotary_dim: int = 32
2023-12-15 00:08:28 +08:00
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that always performs the normalization in float32.

    The input is upcast to float32 before the parent implementation runs and
    the result is cast back to the dtype the caller supplied.
    """

    def __call__(self, x: mx.array) -> mx.array:
        caller_dtype = x.dtype
        normalized = super().__call__(x.astype(mx.float32))
        return normalized.astype(caller_dtype)
2023-12-14 11:22:56 +08:00
class RoPEAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings and a KV cache."""

    def __init__(self, dims: int, num_heads: int, rotary_dim: int):
        super().__init__()

        self.num_heads = num_heads

        self.rope = nn.RoPE(rotary_dim, traditional=False)
        # A single fused projection produces queries, keys and values at once.
        self.Wqkv = nn.Linear(dims, 3 * dims)
        self.out_proj = nn.Linear(dims, dims)

    def __call__(self, x, mask=None, cache=None):
        """Attend over ``x`` and return ``(output, (keys, values))``.

        Args:
            x: Input activations; the code unpacks the projected queries as a
                3-D ``(B, L, D)`` array.
            mask: Optional additive mask added to the attention scores.
            cache: Optional ``(key_cache, value_cache)`` pair from an earlier
                call; new keys/values are appended along axis 2.

        Returns:
            The projected attention output and the updated ``(keys, values)``
            pair to pass back in as ``cache`` on the next call.
        """
        queries, keys, values = mx.split(self.Wqkv(x), 3, axis=-1)

        # Unpacking also enforces that the projection is three-dimensional.
        B, L, _ = queries.shape
        heads = self.num_heads

        def split_heads(t):
            # (B, L, heads * head_dim) -> (B, heads, L, head_dim)
            return t.reshape(B, L, heads, -1).transpose(0, 2, 1, 3)

        queries = split_heads(queries)
        keys = split_heads(keys)
        values = split_heads(values)

        # New positions start after whatever is already cached; with no cache
        # the offset is nn.RoPE's default of 0.
        if cache is None:
            offset = 0
        else:
            key_cache, value_cache = cache
            offset = key_cache.shape[2]

        queries = self.rope(queries, offset=offset)
        keys = self.rope(keys, offset=offset)

        if cache is not None:
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)

        # Scores are computed in float32 regardless of the input dtype.
        queries = queries.astype(mx.float32)
        keys = keys.astype(mx.float32)

        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask

        # Cast the attention weights back so the value product uses values' dtype.
        weights = mx.softmax(scores, axis=-1).astype(values.dtype)
        merged = (weights @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(merged), (keys, values)
class ParallelBlock(nn.Module):
    """Decoder block whose attention and MLP branches share one LayerNorm.

    Both branches read the same normalized input and their outputs are summed
    together with the residual ``x``.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        width = config.model_dim
        hidden_width = width * 4
        self.mixer = RoPEAttention(width, config.num_heads, config.rotary_dim)
        self.ln = LayerNorm(width)
        self.fc1 = nn.Linear(width, hidden_width)
        self.fc2 = nn.Linear(hidden_width, width)
        self.act = nn.GELU(approx="precise")

    def __call__(self, x, mask, cache):
        """Return ``(block_output, updated_cache)`` for input ``x``."""
        normed = self.ln(x)
        attn_out, cache = self.mixer(normed, mask, cache)
        mlp_out = self.fc2(self.act(self.fc1(normed)))
        return attn_out + mlp_out + x, cache
2023-12-14 11:22:56 +08:00
class TransformerDecoder(nn.Module):
    """Stack of ``config.num_layers`` ``ParallelBlock`` layers applied in order."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.h = [ParallelBlock(config) for _ in range(config.num_layers)]

    def __call__(self, x, mask, cache):
        """Run every block over ``x`` and return ``(output, cache)``.

        ``cache`` is a list with one entry per block. When it is ``None`` a
        fresh list is created; otherwise the list passed in is updated in
        place and returned.
        """
        if cache is None:
            cache = [None] * len(self.h)

        for idx, block in enumerate(self.h):
            x, block_cache = block(x, mask, cache[idx])
            cache[idx] = block_cache

        return x, cache
class OutputHead(nn.Module):
    """Final LayerNorm followed by the projection onto the vocabulary."""

    def __init__(self, config: ModelArgs) -> None:
        # Initialize nn.Module's own state first; MLX documents that
        # subclasses of Module should call the base initializer.
        super().__init__()
        self.ln = LayerNorm(config.model_dim)
        self.linear = nn.Linear(config.model_dim, config.num_vocab)

    def __call__(self, inputs):
        """Return vocabulary logits for the hidden states ``inputs``."""
        return self.linear(self.ln(inputs))
2023-12-14 11:22:56 +08:00
class Phi2(nn.Module):
    """Phi-2 language model: token embedding, decoder stack and output head."""

    def __init__(self, config: ModelArgs):
        # Initialize nn.Module's own state first; MLX documents that
        # subclasses of Module should call the base initializer.
        super().__init__()
        self.wte = nn.Embedding(config.num_vocab, config.model_dim)
        self.transformer = TransformerDecoder(config)
        self.lm_head = OutputHead(config)

    def __call__(
        self,
        inputs: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[list] = None,
    ) -> tuple[mx.array, list]:
        """Compute logits for ``inputs`` and return them with the KV cache.

        Args:
            inputs: Token ids; they are embedded and the sequence length is
                read from axis 1 of the embedded result.
            mask: Optional additive attention mask. When omitted, a causal
                mask is built for inputs longer than one token. A mask passed
                by the caller is used as given (it is not cast).
            cache: Optional per-layer cache list returned by a previous call.

        Returns:
            ``(logits, cache)`` where ``cache`` is the per-layer list produced
            by the decoder.

        NOTE(review): the automatically built causal mask only covers the new
        tokens, so feeding more than one token together with a non-empty cache
        needs a caller-supplied mask -- confirm against intended usage.
        """
        x = self.wte(inputs)

        # Only synthesize a mask when the caller did not supply one; a single
        # new token may attend to everything and needs no mask.
        if mask is None and x.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
            mask = mask.astype(x.dtype)

        y, cache = self.transformer(x, mask, cache)
        return self.lm_head(y), cache
2023-12-14 11:22:56 +08:00
def generate(prompt: mx.array, model: Phi2, temp: Optional[float] = 0.0):
    """Yield sampled tokens from ``model`` one at a time, without end.

    Args:
        prompt: Token ids fed to the model on the first step.
        model: Callable returning ``(logits, cache)``; invoked once on the
            prompt and then once per generated token with the cache.
        temp: Sampling temperature. ``0`` or ``None`` selects the arg-max
            token; any other value samples from the logits scaled by
            ``1 / temp``.

    Yields:
        The next token array after each model call. The generator never
        terminates on its own; the caller decides when to stop.
    """
    # ``temp`` is Optional, so treat None like 0 instead of dividing by it.
    greedy = temp is None or temp == 0

    def sample(logits):
        if greedy:
            return mx.argmax(logits, axis=-1)
        return mx.random.categorical(logits * (1 / temp))

    # First step: run the whole prompt and sample from the last position.
    logits, cache = model(prompt)
    y = sample(logits[:, -1, :])
    yield y

    # Later steps: feed only the newest token and reuse the cache.
    while True:
        logits, cache = model(y[:, None], cache=cache)
        y = sample(logits.squeeze(1))
        yield y
2023-12-14 11:22:56 +08:00
def load_model(model_path: str):
    """Build a Phi-2 model, load its weights and fetch the tokenizer.

    Args:
        model_path: Directory that contains ``weights.npz``.

    Returns:
        ``(model, tokenizer)``: the model with the loaded weights applied and
        the ``microsoft/phi-2`` tokenizer.
    """
    model = Phi2(ModelArgs())

    weights_file = Path(model_path) / "weights.npz"
    params = mx.load(str(weights_file))
    model.update(tree_unflatten(list(params.items())))

    # NOTE(review): trust_remote_code=True allows code shipped with the
    # tokenizer repository to run locally -- confirm this is still required.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

    return model, tokenizer
if __name__ == "__main__":
    # Command-line entry point: parse options, load the model, then stream
    # the generated text to stdout in small chunks.
    parser = argparse.ArgumentParser(description="Phi-2 inference script")
    parser.add_argument(
        "--model-path",
        type=str,
        default="phi-2",
        help="The path to the model weights",
    )
    parser.add_argument(
        "--prompt",
        help="The message to be processed by the model",
        default="Write a detailed analogy between mathematics and a lighthouse.",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=100,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp",
        help="The sampling temperature.",
        type=float,
        default=0.0,
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
    args = parser.parse_args()

    mx.random.seed(args.seed)

    model, tokenizer = load_model(args.model_path)

    prompt = tokenizer(
        args.prompt,
        return_tensors="np",
        return_attention_mask=False,
    )["input_ids"]
    prompt = mx.array(prompt)

    print("[INFO] Generating with Phi-2...", flush=True)
    print(args.prompt, end="", flush=True)

    # Tokens are decoded and printed in chunks of this many.
    chunk_size = 10

    def flush_tokens(pending):
        """Print the text for ``pending``; return True if it contained EOS.

        Anything from the first EOS token onwards is dropped before decoding.
        """
        mx.eval(pending)
        ids = [t.item() for t in pending]
        hit_eos = tokenizer.eos_token_id in ids
        if hit_eos:
            ids = ids[: ids.index(tokenizer.eos_token_id)]
        print(tokenizer.decode(ids), end="", flush=True)
        return hit_eos

    tokens = []
    # range() goes first so zip stops before asking the generator for a token
    # that would never be used (which would cost an extra model call).
    for _, token in zip(range(args.max_tokens), generate(prompt, model, args.temp)):
        tokens.append(token)
        if len(tokens) == chunk_size:
            reached_eos = flush_tokens(tokens)
            tokens = []
            if reached_eos:
                break

    # Flush the final partial chunk through the same EOS-aware path so no text
    # after an end-of-sequence token is printed.
    if tokens:
        flush_tokens(tokens)
    print(flush=True)