Fix convert and tokenizer

This commit is contained in:
Juni May 2023-12-18 16:59:51 +08:00
parent 702ecbb671
commit 2a9c5e8a8c
3 changed files with 37 additions and 15 deletions

2
qwen/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
weights.npz
config.json

View File

@@ -1,6 +1,8 @@
import argparse import argparse
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
import numpy as np import numpy as np
import torch
import json
def replace_key(key: str) -> str: def replace_key(key: str) -> str:
@@ -13,12 +15,18 @@ def replace_key(key: str) -> str:
def convert(model_path: str = "Qwen/Qwen-1_8B"): def convert(model_path: str = "Qwen/Qwen-1_8B"):
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_path, trust_remote_code=True model_path, trust_remote_code=True, torch_dtype=torch.float16
) )
state_dict = model.state_dict() state_dict = model.state_dict()
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
np.savez("weights.npz", **weights) np.savez("weights.npz", **weights)
# write config
config = model.config
config_dict = config.to_dict()
with open("config.json", "w") as f:
json.dump(config_dict, f)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Qwen model to npz") parser = argparse.ArgumentParser(description="Convert Qwen model to npz")

View File

@@ -2,6 +2,7 @@
# This inference script is mainly for compatibility with the huggingface model of qwen. # This inference script is mainly for compatibility with the huggingface model of qwen.
import argparse import argparse
import json
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@@ -43,10 +44,8 @@ class QWenAttntion(nn.Module):
self.proj_size = args.kv_channels * self.num_attention_heads self.proj_size = args.kv_channels * self.num_attention_heads
self.c_attn = nn.Linear( self.c_attn = nn.Linear(self.hidden_size, self.proj_size * 3, bias=True)
self.hidden_size, self.proj_size * 3, bias=True) self.c_proj = nn.Linear(self.hidden_size, self.proj_size, bias=not args.no_bias)
self.c_proj = nn.Linear(
self.hidden_size, self.proj_size, bias=not args.no_bias)
self.scale = self.hidden_size_per_attention_head**-0.5 self.scale = self.hidden_size_per_attention_head**-0.5
@@ -72,9 +71,6 @@ class QWenAttntion(nn.Module):
q = self.rotary_emb(q) q = self.rotary_emb(q)
k = self.rotary_emb(k) k = self.rotary_emb(k)
q = q.astype(mx.float32)
k = k.astype(mx.float32)
scores = (q * self.scale) @ k.transpose(0, 1, 3, 2) scores = (q * self.scale) @ k.transpose(0, 1, 3, 2)
if mask is not None: if mask is not None:
@@ -146,8 +142,7 @@ class QWen(nn.Module):
mask = None mask = None
if x.shape[1] > 1: if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask( mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
x.shape[1])
mask = mask.astype(x.dtype) mask = mask.astype(x.dtype)
if cache is None: if cache is None:
@@ -178,12 +173,30 @@ def generate(prompt: mx.array, model: QWen, temp: 0.0):
yield y yield y
def load_model(tokenizer_path: str = "Qwen/Qwen-1_8B"): def load_model(
model = QWen(ModelArgs()) tokenizer_path: str = "Qwen/Qwen-1_8B", config_path: str = "config.json"
):
model_args = ModelArgs()
with open(config_path, "r") as f:
config = json.load(f)
model_args.vocab_size = config["vocab_size"]
model_args.hidden_size = config["hidden_size"]
model_args.num_attention_heads = config["num_attention_heads"]
model_args.num_hidden_layers = config["num_hidden_layers"]
model_args.kv_channels = config["kv_channels"]
model_args.max_position_embeddings = config["max_position_embeddings"]
model_args.layer_norm_epsilon = config["layer_norm_epsilon"]
model_args.intermediate_size = config["intermediate_size"]
model_args.no_bias = config["no_bias"]
model = QWen(model_args)
weights = mx.load("weights.npz") weights = mx.load("weights.npz")
model.update(tree_unflatten(list(weights.items()))) model.update(tree_unflatten(list(weights.items())))
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, trust_remote_code=True) tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>"
)
return model, tokenizer return model, tokenizer
@@ -239,8 +252,7 @@ if __name__ == "__main__":
if (len(tokens) % 10) == 0: if (len(tokens) % 10) == 0:
mx.eval(tokens) mx.eval(tokens)
eos_index = next( eos_index = next(
(i for i, t in enumerate(tokens) (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
if t.item() == tokenizer.eos_token_id),
None, None,
) )