From 2a9c5e8a8c21018eb61fafbee3452b27ee55c4fb Mon Sep 17 00:00:00 2001 From: Juni May Date: Mon, 18 Dec 2023 16:59:51 +0800 Subject: [PATCH] Fix convert and tokenizer --- qwen/.gitignore | 2 ++ qwen/convert.py | 10 +++++++++- qwen/qwen.py | 40 ++++++++++++++++++++++++++-------------- 3 files changed, 37 insertions(+), 15 deletions(-) create mode 100644 qwen/.gitignore diff --git a/qwen/.gitignore b/qwen/.gitignore new file mode 100644 index 00000000..0c68f15d --- /dev/null +++ b/qwen/.gitignore @@ -0,0 +1,2 @@ +weights.npz +config.json diff --git a/qwen/convert.py b/qwen/convert.py index 2a154814..3406dde8 100644 --- a/qwen/convert.py +++ b/qwen/convert.py @@ -1,6 +1,8 @@ import argparse from transformers import AutoModelForCausalLM import numpy as np +import torch +import json def replace_key(key: str) -> str: @@ -13,12 +15,18 @@ def replace_key(key: str) -> str: def convert(model_path: str = "Qwen/Qwen-1_8B"): model = AutoModelForCausalLM.from_pretrained( - model_path, trust_remote_code=True + model_path, trust_remote_code=True, torch_dtype=torch.float16 ) state_dict = model.state_dict() weights = {replace_key(k): v.numpy() for k, v in state_dict.items()} np.savez("weights.npz", **weights) + # write config + config = model.config + config_dict = config.to_dict() + with open("config.json", "w") as f: + json.dump(config_dict, f) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Qwen model to npz") diff --git a/qwen/qwen.py b/qwen/qwen.py index e64ece38..fa5cc834 100644 --- a/qwen/qwen.py +++ b/qwen/qwen.py @@ -2,6 +2,7 @@ # This inference script is mainly for compatibility with the huggingface model of qwen. import argparse +import json import mlx.core as mx import mlx.nn as nn @@ -43,10 +44,8 @@ class QWenAttntion(nn.Module): self.proj_size = args.kv_channels * self.num_attention_heads - self.c_attn = nn.Linear( - self.hidden_size, self.proj_size * 3, bias=True) - self.c_proj = nn.Linear( - self.hidden_size, self.proj_size, bias=not args.no_bias) + self.c_attn = nn.Linear(self.hidden_size, self.proj_size * 3, bias=True) + self.c_proj = nn.Linear(self.hidden_size, self.proj_size, bias=not args.no_bias) self.scale = self.hidden_size_per_attention_head**-0.5 @@ -72,9 +71,6 @@ class QWenAttntion(nn.Module): q = self.rotary_emb(q) k = self.rotary_emb(k) - q = q.astype(mx.float32) - k = k.astype(mx.float32) - scores = (q * self.scale) @ k.transpose(0, 1, 3, 2) if mask is not None: @@ -146,8 +142,7 @@ class QWen(nn.Module): mask = None if x.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask( - x.shape[1]) + mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) mask = mask.astype(x.dtype) if cache is None: @@ -178,12 +173,30 @@ def generate(prompt: mx.array, model: QWen, temp: 0.0): yield y -def load_model(tokenizer_path: str = "Qwen/Qwen-1_8B"): - model = QWen(ModelArgs()) +def load_model( + tokenizer_path: str = "Qwen/Qwen-1_8B", config_path: str = "config.json" +): + model_args = ModelArgs() + + with open(config_path, "r") as f: + config = json.load(f) + model_args.vocab_size = config["vocab_size"] + model_args.hidden_size = config["hidden_size"] + model_args.num_attention_heads = config["num_attention_heads"] + model_args.num_hidden_layers = config["num_hidden_layers"] + model_args.kv_channels = config["kv_channels"] + model_args.max_position_embeddings = config["max_position_embeddings"] + model_args.layer_norm_epsilon = config["layer_norm_epsilon"] + model_args.intermediate_size = config["intermediate_size"] + model_args.no_bias = config["no_bias"] + + model = QWen(model_args) + weights = mx.load("weights.npz") model.update(tree_unflatten(list(weights.items()))) tokenizer = AutoTokenizer.from_pretrained( - tokenizer_path, trust_remote_code=True) + tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>" + ) return model, tokenizer @@ -239,8 +252,7 @@ if __name__ == "__main__": if (len(tokens) % 10) == 0: mx.eval(tokens) eos_index = next( - (i for i, t in enumerate(tokens) - if t.item() == tokenizer.eos_token_id), + (i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id), None, )