From ec94fcf430930c2d087477a3501b3738345734b8 Mon Sep 17 00:00:00 2001
From: Juni May
Date: Mon, 18 Dec 2023 15:04:21 +0800
Subject: [PATCH] Add qwen model draft

---
 qwen/convert.py |  23 +++
 qwen/qwen.py    | 249 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 272 insertions(+)
 create mode 100644 qwen/convert.py
 create mode 100644 qwen/qwen.py
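
For reviewers, a minimal end-to-end sketch of how the two new files fit together. This is hypothetical driver code, not part of the patch; it assumes both scripts sit in qwen/ and are run from that directory, so that weights.npz is written and then found in the working directory:

    import mlx.core as mx
    import convert
    import qwen

    convert.convert()                     # downloads Qwen/Qwen-1_8B and writes weights.npz
    model, tokenizer = qwen.load_model()  # rebuilds the MLX model from weights.npz
    prompt = mx.array(tokenizer("Hello", return_tensors="np")["input_ids"])
    for token, _ in zip(qwen.generate(prompt, model, 0.0), range(8)):
        print(tokenizer.decode([token.item()]), end="", flush=True)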
diff --git a/qwen/convert.py b/qwen/convert.py
new file mode 100644
index 00000000..5ddbad1d
--- /dev/null
+++ b/qwen/convert.py
@@ -0,0 +1,23 @@
+from transformers import AutoModelForCausalLM
+import numpy as np
+
+
+def replace_key(key: str) -> str:
+    if key.startswith("transformer."):
+        # strip the "transformer." prefix so keys match the MLX module tree
+        key = key.replace("transformer.", "")
+
+    return key
+
+
+def convert():
+    model = AutoModelForCausalLM.from_pretrained(
+        "Qwen/Qwen-1_8B", trust_remote_code=True
+    )
+    state_dict = model.state_dict()
+    weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
+    np.savez("weights.npz", **weights)
+
+
+if __name__ == "__main__":
+    convert()
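
If the conversion step needs a quick sanity check before wiring up the model, something along these lines should do. This is a sketch; the key names are assumptions that follow from replace_key stripping the "transformer." prefix from the Hugging Face state dict:

    import numpy as np

    weights = np.load("weights.npz")
    print(len(weights.files))           # one entry per parameter tensor
    print(sorted(weights.files)[:3])    # e.g. "h.0.attn.c_attn.bias", ...
    print(weights["wte.weight"].shape)  # embedding table: (vocab_size, hidden_size)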
diff --git a/qwen/qwen.py b/qwen/qwen.py
new file mode 100644
index 00000000..877c46c3
--- /dev/null
+++ b/qwen/qwen.py
@@ -0,0 +1,249 @@
+# The architecture of Qwen is similar to Llama.
+
+import argparse
+
+from typing import Any
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_unflatten
+
+from dataclasses import dataclass
+
+from transformers import AutoTokenizer
+
+
+@dataclass
+class ModelArgs:
+    hidden_size: int = 2048
+    num_attention_heads: int = 16
+    num_hidden_layers: int = 24
+    kv_channels: int = 128
+
+    max_position_embeddings: int = 8192
+    layer_norm_epsilon: float = 1e-6
+    intermediate_size: int = 11008
+
+    no_bias: bool = True
+
+    vocab_size: int = 151936
+
+
+class QWenAttention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        self.hidden_size = args.hidden_size
+        self.num_attention_heads = args.num_attention_heads
+
+        self.hidden_size_per_attention_head = (
+            self.hidden_size // self.num_attention_heads
+        )
+
+        self.rotary_emb = nn.RoPE(
+            self.hidden_size_per_attention_head, traditional=False
+        )
+
+        self.proj_size = args.kv_channels * self.num_attention_heads
+
+        self.c_attn = nn.Linear(self.hidden_size, self.proj_size * 3, bias=True)
+        self.c_proj = nn.Linear(self.proj_size, self.hidden_size, bias=not args.no_bias)
+
+        self.scale = self.hidden_size_per_attention_head**-0.5
+
+    def __call__(self, x, mask=None, cache=None):
+        qkv = self.c_attn(x)
+
+        q, k, v = mx.split(qkv, 3, axis=-1)
+
+        B, L, D = q.shape
+
+        q = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
+        k = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
+        v = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            k_cache, v_cache = cache
+            q = self.rotary_emb(q, offset=k_cache.shape[2])  # rotate past cached positions
+            k = self.rotary_emb(k, offset=k_cache.shape[2])
+            k = mx.concatenate([k_cache, k], axis=2)
+            v = mx.concatenate([v_cache, v], axis=2)
+
+        else:
+            q = self.rotary_emb(q)
+            k = self.rotary_emb(k)
+
+        q = q.astype(mx.float32)
+        k = k.astype(mx.float32)
+
+        scores = (q * self.scale) @ k.transpose(0, 1, 3, 2)
+
+        if mask is not None:
+            scores = scores + mask
+
+        scores = mx.softmax(scores, axis=-1).astype(v.dtype)
+        v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.c_proj(v_hat), (k, v)
+
+
+class QWenMlp(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        self.w1 = nn.Linear(
+            args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
+        )
+        self.w2 = nn.Linear(
+            args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
+        )
+        self.c_proj = nn.Linear(
+            args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
+        )
+
+    def __call__(self, x) -> Any:
+        a1 = self.w1(x)
+        a2 = self.w2(x)
+        intermediate_parallel = a1 * nn.silu(a2)  # SwiGLU-style gating
+        out = self.c_proj(intermediate_parallel)
+        return out
+
+
+class QWenBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.attn = QWenAttention(args)
+        self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.mlp = QWenMlp(args)
+
+    def __call__(self, x, mask=None, cache=None):
+        residual = x
+        x = self.ln_1(x)
+        x, cache = self.attn(x, mask=mask, cache=cache)
+        residual = x + residual
+        x = self.ln_2(residual)
+        x = self.mlp(x)
+        x = x + residual
+
+        return x, cache
+
+
+class QWen(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        self.embed_dim = args.hidden_size
+
+        self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.h = [QWenBlock(args) for _ in range(args.num_hidden_layers)]
+        self.ln_f = nn.RMSNorm(self.embed_dim, eps=args.layer_norm_epsilon)
+
+        self.lm_head = nn.Linear(self.embed_dim, args.vocab_size, bias=False)
+
+    def __call__(self, inputs, mask=None, cache=None):
+        x = self.wte(inputs)
+
+        mask = None
+        if x.shape[1] > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
+            mask = mask.astype(x.dtype)
+
+        if cache is None:
+            cache = [None] * len(self.h)
+
+        for e, layer in enumerate(self.h):
+            x, cache[e] = layer(x, mask, cache[e])
+
+        x = self.ln_f(x)
+
+        return self.lm_head(x), cache
+
+
+def generate(prompt: mx.array, model: QWen, temp: float = 0.0):
+    def sample(logits):
+        if temp == 0:
+            return mx.argmax(logits, axis=-1)
+        else:
+            return mx.random.categorical(logits * (1 / temp))
+
+    logits, cache = model(prompt)
+    y = sample(logits[:, -1, :])
+    yield y
+
+    while True:
+        logits, cache = model(y[:, None], cache=cache)
+        y = sample(logits.squeeze(1))
+        yield y
+
+
+def load_model():
+    model = QWen(ModelArgs())
+    weights = mx.load("weights.npz")
+    model.update(tree_unflatten(list(weights.items())))
+    # weights.npz is produced by qwen/convert.py
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-1_8B", trust_remote_code=True)
+    return model, tokenizer
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Qwen inference script")
+    parser.add_argument(
+        "--prompt",
+        help="The message to be processed by the model",
+        default="Write a detailed analogy between mathematics and a lighthouse.",
+    )
+    parser.add_argument(
+        "--max_tokens",
+        "-m",
+        type=int,
+        default=100,
+        help="Maximum number of tokens to generate",
+    )
+    parser.add_argument(
+        "--temp",
+        help="The sampling temperature.",
+        type=float,
+        default=0.0,
+    )
+    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+    args = parser.parse_args()
+
+    mx.random.seed(args.seed)
+
+    model, tokenizer = load_model()
+
+    prompt = tokenizer(
+        args.prompt,
+        return_tensors="np",
+        return_attention_mask=False,
+    )["input_ids"]
+
+    prompt = mx.array(prompt)
+
+    print("[INFO] Generating with QWen...", flush=True)
+    print(args.prompt, end="", flush=True)
+
+    tokens = []
+    for token, _ in zip(generate(prompt, model, args.temp), range(args.max_tokens)):
+        tokens.append(token)
+
+        if (len(tokens) % 10) == 0:
+            mx.eval(tokens)
+            eos_index = next(
+                (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
+                None,
+            )
+
+            if eos_index is not None:
+                tokens = tokens[:eos_index]
+
+            s = tokenizer.decode([t.item() for t in tokens])
+            print(s, end="", flush=True)
+            tokens = []
+            if eos_index is not None:
+                break
+
+    mx.eval(tokens)
+    s = tokenizer.decode([t.item() for t in tokens])
+    print(s, flush=True)
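
One detail worth double-checking in review is the cache plumbing: each forward pass returns per-layer (keys, values) of shape (batch, heads, positions, head_dim), and feeding the cache back in lets the next call process only the newest token. A toy shape check, as a sketch: the import path assumes it is run from inside qwen/, and a randomly initialized model is fine here since only shapes are inspected:

    import mlx.core as mx
    from qwen import QWen, ModelArgs  # hypothetical import, i.e. qwen/qwen.py on the path

    model = QWen(ModelArgs(num_hidden_layers=2))  # shrink the stack; weights stay random
    logits, cache = model(mx.zeros((1, 5), dtype=mx.int32))
    print(cache[0][0].shape)  # (1, 16, 5, 128): five prompt positions cached
    logits, cache = model(mx.zeros((1, 1), dtype=mx.int32), cache=cache)
    print(cache[0][0].shape)  # (1, 16, 6, 128): one decode step appended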