diff --git a/qwen/convert.py b/qwen/convert.py
index 5ddbad1d..2a154814 100644
--- a/qwen/convert.py
+++ b/qwen/convert.py
@@ -1,3 +1,4 @@
+import argparse
 from transformers import AutoModelForCausalLM
 import numpy as np
 
@@ -10,9 +11,9 @@ def replace_key(key: str) -> str:
     return key
 
 
-def convert():
+def convert(model_path: str = "Qwen/Qwen-1_8B"):
     model = AutoModelForCausalLM.from_pretrained(
-        "Qwen/Qwen-1_8B", trust_remote_code=True
+        model_path, trust_remote_code=True
     )
     state_dict = model.state_dict()
     weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
@@ -20,4 +21,14 @@ def convert():
 
 
 if __name__ == "__main__":
-    convert()
+    parser = argparse.ArgumentParser(description="Convert Qwen model to npz")
+
+    parser.add_argument(
+        "--model",
+        help="The Hugging Face model to be converted",
+        default="Qwen/Qwen-1_8B",
+    )
+
+    args = parser.parse_args()
+
+    convert(args.model)
diff --git a/qwen/qwen.py b/qwen/qwen.py
index 877c46c3..e64ece38 100644
--- a/qwen/qwen.py
+++ b/qwen/qwen.py
@@ -1,14 +1,12 @@
 # The architecture of Qwen is similar to Llama.
+# This inference script is mainly for compatibility with the Hugging Face model of Qwen.
 import argparse
-
-from typing import Any
 
 import mlx.core as mx
 import mlx.nn as nn
+
 from mlx.utils import tree_unflatten
-
 from dataclasses import dataclass
-
 from transformers import AutoTokenizer
 
 
@@ -45,8 +43,10 @@ class QWenAttntion(nn.Module):
 
         self.proj_size = args.kv_channels * self.num_attention_heads
 
-        self.c_attn = nn.Linear(self.hidden_size, self.proj_size * 3, bias=True)
-        self.c_proj = nn.Linear(self.hidden_size, self.proj_size, bias=not args.no_bias)
+        self.c_attn = nn.Linear(
+            self.hidden_size, self.proj_size * 3, bias=True)
+        self.c_proj = nn.Linear(
+            self.hidden_size, self.proj_size, bias=not args.no_bias)
 
         self.scale = self.hidden_size_per_attention_head**-0.5
 
@@ -55,7 +55,7 @@ class QWenAttntion(nn.Module):
 
         q, k, v = mx.split(qkv, 3, axis=-1)
 
-        B, L, D = q.shape
+        B, L, _ = q.shape
 
         q = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
         k = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
@@ -100,7 +100,7 @@ class QWenMlp(nn.Module):
             args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
         )
 
-    def __call__(self, x) -> Any:
+    def __call__(self, x):
         a1 = self.w1(x)
         a2 = self.w2(x)
         intermediate_parallel = a1 * nn.silu(a2)
@@ -146,7 +146,8 @@ class QWen(nn.Module):
 
         mask = None
         if x.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(
+                x.shape[1])
             mask = mask.astype(x.dtype)
 
         if cache is None:
@@ -177,21 +178,29 @@ def generate(prompt: mx.array, model: QWen, temp: 0.0):
         yield y
 
 
-def load_model():
+def load_model(tokenizer_path: str = "Qwen/Qwen-1_8B"):
     model = QWen(ModelArgs())
     weights = mx.load("weights.npz")
     model.update(tree_unflatten(list(weights.items())))
-    # print([x for x, _ in tree_flatten(model.parameters())])
-    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-1_8B", trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer_path, trust_remote_code=True)
     return model, tokenizer
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Phi-2 inference script")
+    # The inference code and arguments were mainly derived from the phi-2 example.
+
+    parser = argparse.ArgumentParser(description="Qwen inference script")
+    parser.add_argument(
+        "--tokenizer",
+        help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B",
+        default="Qwen/Qwen-1_8B",
+    )
     parser.add_argument(
         "--prompt",
         help="The message to be processed by the model",
-        default="Write a detailed analogy between mathematics and a lighthouse.",
+        # The example from the official Hugging Face repo of Qwen
+        default="蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是",
     )
     parser.add_argument(
         "--max_tokens",
@@ -211,7 +220,7 @@ if __name__ == "__main__":
 
     mx.random.seed(args.seed)
 
-    model, tokenizer = load_model()
+    model, tokenizer = load_model(args.tokenizer)
 
     prompt = tokenizer(
         args.prompt,
@@ -221,7 +230,6 @@ if __name__ == "__main__":
 
     prompt = mx.array(prompt)
 
-    print("[INFO] Generating with QWen...", flush=True)
     print(args.prompt, end="", flush=True)
 
     tokens = []
@@ -231,7 +239,8 @@ if __name__ == "__main__":
 
         if (len(tokens) % 10) == 0:
             mx.eval(tokens)
             eos_index = next(
-                (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
+                (i for i, t in enumerate(tokens)
+                 if t.item() == tokenizer.eos_token_id),
                 None,
             )
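Example usage (a sketch for reviewers: the --model, --tokenizer, and --max_tokens flags come from the argparse definitions in this diff, while the exact invocations and token count are illustrative; load_model reads weights.npz from the current working directory, so both commands assume you are running inside qwen/):

    python convert.py --model Qwen/Qwen-1_8B
    python qwen.py --tokenizer Qwen/Qwen-1_8B --max_tokens 100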