Add qwen model draft

This commit is contained in:
Juni May 2023-12-18 15:04:21 +08:00
parent 08e862336a
commit ec94fcf430
2 changed files with 272 additions and 0 deletions

23
qwen/convert.py Normal file
View File

@ -0,0 +1,23 @@
from transformers import AutoModelForCausalLM
import numpy as np
def replace_key(key: str) -> str:
if key.startswith("transformer."):
# remove transformer prefix
key = key.replace("transformer.", "")
return key
def convert():
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen-1_8B", trust_remote_code=True
)
state_dict = model.state_dict()
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
np.savez("weights.npz", **weights)
if __name__ == "__main__":
convert()

249
qwen/qwen.py Normal file
View File

@ -0,0 +1,249 @@
# The architecture of Qwen is similar to Llama.
import argparse
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from dataclasses import dataclass
from transformers import AutoTokenizer
@dataclass
class ModelArgs:
hidden_size: int = 2048
num_attention_heads: int = 16
num_hidden_layers: int = 24
kv_channels: int = 128
max_position_embeddings: int = 8192
layer_norm_epsilon: float = 1e-6
intermediate_size: int = 11008
no_bias: bool = True
vocab_size: int = 151936
class QWenAttntion(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.num_attention_heads = args.num_attention_heads
self.hidden_size_per_attention_head = (
self.hidden_size // self.num_attention_heads
)
self.rotary_emb = nn.RoPE(
self.hidden_size_per_attention_head, traditional=False
)
self.proj_size = args.kv_channels * self.num_attention_heads
self.c_attn = nn.Linear(self.hidden_size, self.proj_size * 3, bias=True)
self.c_proj = nn.Linear(self.hidden_size, self.proj_size, bias=not args.no_bias)
self.scale = self.hidden_size_per_attention_head**-0.5
def __call__(self, x, mask=None, cache=None):
qkv = self.c_attn(x)
q, k, v = mx.split(qkv, 3, axis=-1)
B, L, D = q.shape
q = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
k = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
v = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
k_cache, v_cache = cache
q = self.rotary_emb(q, offset=k_cache.shape[2])
k = self.rotary_emb(k, offset=k_cache.shape[2])
k = mx.concatenate([k_cache, k], axis=2)
v = mx.concatenate([v_cache, v], axis=2)
else:
q = self.rotary_emb(q)
k = self.rotary_emb(k)
q = q.astype(mx.float32)
k = k.astype(mx.float32)
scores = (q * self.scale) @ k.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1).astype(v.dtype)
v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(v_hat), (k, v)
class QWenMlp(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.w1 = nn.Linear(
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
)
self.w2 = nn.Linear(
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
)
self.c_proj = nn.Linear(
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
)
def __call__(self, x) -> Any:
a1 = self.w1(x)
a2 = self.w2(x)
intermediate_parallel = a1 * nn.silu(a2)
out = self.c_proj(intermediate_parallel)
return out
class QWenBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.attn = QWenAttntion(args)
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.mlp = QWenMlp(args)
def __call__(self, x, mask=None, cache=None):
residual = x
x = self.ln_1(x)
x, cache = self.attn(x, mask=mask, cache=cache)
residual = x + residual
x = self.ln_2(residual)
x = self.mlp(x)
x = x + residual
return x, cache
class QWen(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.embed_dim = args.hidden_size
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
self.h = [QWenBlock(args) for _ in range(args.num_hidden_layers)]
self.ln_f = nn.RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon)
self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False)
def __call__(self, inputs, mask=None, cache=None):
x = self.wte(inputs)
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
if cache is None:
cache = [None] * len(self.h)
for e, layer in enumerate(self.h):
x, cache[e] = layer(x, mask, cache[e])
x = self.ln_f(x)
return self.lm_head(x), cache
def generate(prompt: mx.array, model: QWen, temp: 0.0):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
logits, cache = model(prompt)
y = sample(logits[:, -1, :])
yield y
while True:
logits, cache = model(y[:, None], cache=cache)
y = sample(logits.squeeze(1))
yield y
def load_model():
model = QWen(ModelArgs())
weights = mx.load("weights.npz")
model.update(tree_unflatten(list(weights.items())))
# print([x for x, _ in tree_flatten(model.parameters())])
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-1_8B", trust_remote_code=True)
return model, tokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Phi-2 inference script")
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="Write a detailed analogy between mathematics and a lighthouse.",
)
parser.add_argument(
"--max_tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = load_model()
prompt = tokenizer(
args.prompt,
return_tensors="np",
return_attention_mask=False,
)["input_ids"]
prompt = mx.array(prompt)
print("[INFO] Generating with QWen...", flush=True)
print(args.prompt, end="", flush=True)
tokens = []
for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
tokens.append(token)
if (len(tokens) % 10) == 0:
mx.eval(tokens)
eos_index = next(
(i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
None,
)
if eos_index is not None:
tokens = tokens[:eos_index]
s = tokenizer.decode([t.item() for t in tokens])
print(s, end="", flush=True)
tokens = []
if eos_index is not None:
break
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s, flush=True)