mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Fix convert and tokenizer
This commit is contained in:
parent
702ecbb671
commit
2a9c5e8a8c
2
qwen/.gitignore
vendored
Normal file
2
qwen/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
weights.npz
|
||||
config.json
|
@ -1,6 +1,8 @@
|
||||
import argparse
|
||||
from transformers import AutoModelForCausalLM
|
||||
import numpy as np
|
||||
import torch
|
||||
import json
|
||||
|
||||
|
||||
def replace_key(key: str) -> str:
|
||||
@ -13,12 +15,18 @@ def replace_key(key: str) -> str:
|
||||
|
||||
def convert(model_path: str = "Qwen/Qwen-1_8B"):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, trust_remote_code=True
|
||||
model_path, trust_remote_code=True, torch_dtype=torch.float16
|
||||
)
|
||||
state_dict = model.state_dict()
|
||||
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
|
||||
np.savez("weights.npz", **weights)
|
||||
|
||||
# write config
|
||||
config = model.config
|
||||
config_dict = config.to_dict()
|
||||
with open("config.json", "w") as f:
|
||||
json.dump(config_dict, f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert Qwen model to npz")
|
||||
|
40
qwen/qwen.py
40
qwen/qwen.py
@ -2,6 +2,7 @@
|
||||
# This inference script is mainly for compatibility with the huggingface model of qwen.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
@ -43,10 +44,8 @@ class QWenAttntion(nn.Module):
|
||||
|
||||
self.proj_size = args.kv_channels * self.num_attention_heads
|
||||
|
||||
self.c_attn = nn.Linear(
|
||||
self.hidden_size, self.proj_size * 3, bias=True)
|
||||
self.c_proj = nn.Linear(
|
||||
self.hidden_size, self.proj_size, bias=not args.no_bias)
|
||||
self.c_attn = nn.Linear(self.hidden_size, self.proj_size * 3, bias=True)
|
||||
self.c_proj = nn.Linear(self.hidden_size, self.proj_size, bias=not args.no_bias)
|
||||
|
||||
self.scale = self.hidden_size_per_attention_head**-0.5
|
||||
|
||||
@ -72,9 +71,6 @@ class QWenAttntion(nn.Module):
|
||||
q = self.rotary_emb(q)
|
||||
k = self.rotary_emb(k)
|
||||
|
||||
q = q.astype(mx.float32)
|
||||
k = k.astype(mx.float32)
|
||||
|
||||
scores = (q * self.scale) @ k.transpose(0, 1, 3, 2)
|
||||
|
||||
if mask is not None:
|
||||
@ -146,8 +142,7 @@ class QWen(nn.Module):
|
||||
|
||||
mask = None
|
||||
if x.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(
|
||||
x.shape[1])
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
if cache is None:
|
||||
@ -178,12 +173,30 @@ def generate(prompt: mx.array, model: QWen, temp: 0.0):
|
||||
yield y
|
||||
|
||||
|
||||
def load_model(tokenizer_path: str = "Qwen/Qwen-1_8B"):
|
||||
model = QWen(ModelArgs())
|
||||
def load_model(
|
||||
tokenizer_path: str = "Qwen/Qwen-1_8B", config_path: str = "config.json"
|
||||
):
|
||||
model_args = ModelArgs()
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
model_args.vocab_size = config["vocab_size"]
|
||||
model_args.hidden_size = config["hidden_size"]
|
||||
model_args.num_attention_heads = config["num_attention_heads"]
|
||||
model_args.num_hidden_layers = config["num_hidden_layers"]
|
||||
model_args.kv_channels = config["kv_channels"]
|
||||
model_args.max_position_embeddings = config["max_position_embeddings"]
|
||||
model_args.layer_norm_epsilon = config["layer_norm_epsilon"]
|
||||
model_args.intermediate_size = config["intermediate_size"]
|
||||
model_args.no_bias = config["no_bias"]
|
||||
|
||||
model = QWen(model_args)
|
||||
|
||||
weights = mx.load("weights.npz")
|
||||
model.update(tree_unflatten(list(weights.items())))
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_path, trust_remote_code=True)
|
||||
tokenizer_path, trust_remote_code=True, eos_token="<|endoftext|>"
|
||||
)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
@ -239,8 +252,7 @@ if __name__ == "__main__":
|
||||
if (len(tokens) % 10) == 0:
|
||||
mx.eval(tokens)
|
||||
eos_index = next(
|
||||
(i for i, t in enumerate(tokens)
|
||||
if t.item() == tokenizer.eos_token_id),
|
||||
(i for i, t in enumerate(tokens) if t.item() == tokenizer.eos_token_id),
|
||||
None,
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user