Add model and tokenizer options

This commit is contained in:
Juni May 2023-12-18 15:30:36 +08:00
parent a8ef549546
commit 702ecbb671
2 changed files with 41 additions and 21 deletions

View File

@ -1,3 +1,4 @@
import argparse
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
import numpy as np import numpy as np
@ -10,9 +11,9 @@ def replace_key(key: str) -> str:
return key return key
def convert(): def convert(model_path: str = "Qwen/Qwen-1_8B"):
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen-1_8B", trust_remote_code=True model_path, trust_remote_code=True
) )
state_dict = model.state_dict() state_dict = model.state_dict()
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
@ -20,4 +21,14 @@ def convert():
if __name__ == "__main__": if __name__ == "__main__":
convert() parser = argparse.ArgumentParser(description="Convert Qwen model to npz")
parser.add_argument(
"--model",
help="The huggingface model to be converted",
default="Qwen/Qwen-1_8B",
)
args = parser.parse_args()
convert(args.model)

View File

@ -1,14 +1,12 @@
# The architecture of Qwen is similar to Llama. # The architecture of qwen is similar to Llama.
# This inference script is mainly for compatibility with the huggingface model of qwen.
import argparse import argparse
from typing import Any
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx.utils import tree_unflatten from mlx.utils import tree_unflatten
from dataclasses import dataclass from dataclasses import dataclass
from transformers import AutoTokenizer from transformers import AutoTokenizer
@ -45,8 +43,10 @@ class QWenAttntion(nn.Module):
self.proj_size = args.kv_channels * self.num_attention_heads self.proj_size = args.kv_channels * self.num_attention_heads
self.c_attn = nn.Linear(self.hidden_size, self.proj_size * 3, bias=True) self.c_attn = nn.Linear(
self.c_proj = nn.Linear(self.hidden_size, self.proj_size, bias=not args.no_bias) self.hidden_size, self.proj_size * 3, bias=True)
self.c_proj = nn.Linear(
self.hidden_size, self.proj_size, bias=not args.no_bias)
self.scale = self.hidden_size_per_attention_head**-0.5 self.scale = self.hidden_size_per_attention_head**-0.5
@ -55,7 +55,7 @@ class QWenAttntion(nn.Module):
q, k, v = mx.split(qkv, 3, axis=-1) q, k, v = mx.split(qkv, 3, axis=-1)
B, L, D = q.shape B, L, _ = q.shape
q = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) q = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
k = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) k = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
@ -100,7 +100,7 @@ class QWenMlp(nn.Module):
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
) )
def __call__(self, x) -> Any: def __call__(self, x):
a1 = self.w1(x) a1 = self.w1(x)
a2 = self.w2(x) a2 = self.w2(x)
intermediate_parallel = a1 * nn.silu(a2) intermediate_parallel = a1 * nn.silu(a2)
@ -146,7 +146,8 @@ class QWen(nn.Module):
mask = None mask = None
if x.shape[1] > 1: if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) mask = nn.MultiHeadAttention.create_additive_causal_mask(
x.shape[1])
mask = mask.astype(x.dtype) mask = mask.astype(x.dtype)
if cache is None: if cache is None:
@ -177,21 +178,29 @@ def generate(prompt: mx.array, model: QWen, temp: 0.0):
yield y yield y
def load_model(): def load_model(tokenizer_path: str = "Qwen/Qwen-1_8B"):
model = QWen(ModelArgs()) model = QWen(ModelArgs())
weights = mx.load("weights.npz") weights = mx.load("weights.npz")
model.update(tree_unflatten(list(weights.items()))) model.update(tree_unflatten(list(weights.items())))
# print([x for x, _ in tree_flatten(model.parameters())]) tokenizer = AutoTokenizer.from_pretrained(
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-1_8B", trust_remote_code=True) tokenizer_path, trust_remote_code=True)
return model, tokenizer return model, tokenizer
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Phi-2 inference script") # The inference code and arguments were mainly derived from the phi-2 example.
parser = argparse.ArgumentParser(description="Qwen inference script")
parser.add_argument(
"--tokenizer",
help="The tokenizer to be used, defaults to Qwen/Qwen-1_8B",
default="Qwen/Qwen-1_8B",
)
parser.add_argument( parser.add_argument(
"--prompt", "--prompt",
help="The message to be processed by the model", help="The message to be processed by the model",
default="Write a detailed analogy between mathematics and a lighthouse.", # The example from the official huggingface repo of Qwen
default="蒙古国的首都是乌兰巴托Ulaanbaatar\n冰岛的首都是雷克雅未克Reykjavik\n埃塞俄比亚的首都是",
) )
parser.add_argument( parser.add_argument(
"--max_tokens", "--max_tokens",
@ -211,7 +220,7 @@ if __name__ == "__main__":
mx.random.seed(args.seed) mx.random.seed(args.seed)
model, tokenizer = load_model() model, tokenizer = load_model(args.tokenizer)
prompt = tokenizer( prompt = tokenizer(
args.prompt, args.prompt,
@ -221,7 +230,6 @@ if __name__ == "__main__":
prompt = mx.array(prompt) prompt = mx.array(prompt)
print("[INFO] Generating with QWen...", flush=True)
print(args.prompt, end="", flush=True) print(args.prompt, end="", flush=True)
tokens = [] tokens = []
@ -231,7 +239,8 @@ if __name__ == "__main__":
if (len(tokens) % 10) == 0: if (len(tokens) % 10) == 0:
mx.eval(tokens) mx.eval(tokens)
eos_index = next( eos_index = next(
(i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), (i for i, t in enumerate(tokens)
if t.item() == tokenizer.eos_token_id),
None, None,
) )