import argparse
import json

import numpy as np
import torch
from transformers import AutoModelForCausalLM
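

# Qwen checkpoints store their modules under a top-level "transformer."
# prefix; replace_key() strips it so the saved weights use bare module names.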
def replace_key(key: str) -> str:
    if key.startswith("transformer."):
        # remove transformer prefix
        key = key.replace("transformer.", "")

    return key
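

# Load the Hugging Face checkpoint in float16, rename the state-dict keys,
# save the weights to weights.npz, and write the model config to config.json.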
def convert(model_path: str = "Qwen/Qwen-1_8B"):
    model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True, torch_dtype=torch.float16
    )
    state_dict = model.state_dict()
    weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
    np.savez("weights.npz", **weights)

    # write config
    config = model.config
    config_dict = config.to_dict()
    with open("config.json", "w") as f:
        json.dump(config_dict, f, indent=4)
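

# Command-line entry point: choose which Hugging Face model to convert via --model.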
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Qwen model to npz")
    parser.add_argument(
        "--model",
        help="The huggingface model to be converted",
        default="Qwen/Qwen-1_8B",
    )

    args = parser.parse_args()
    convert(args.model)
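
# Example invocation (assuming this script is saved as convert.py):
#   python convert.py --model Qwen/Qwen-1_8B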