mlx-examples/llms/qwen/convert.py

43 lines
1.0 KiB
Python
Raw Normal View History

2023-12-18 15:30:36 +08:00
import argparse
2023-12-18 15:04:21 +08:00
from transformers import AutoModelForCausalLM
import numpy as np
2023-12-18 16:59:51 +08:00
import torch
import json
2023-12-18 15:04:21 +08:00
def replace_key(key: str) -> str:
    """Map a Hugging Face Qwen parameter name to its converted name.

    Only a *leading* ``transformer.`` prefix is stripped. Any later occurrence
    of that substring inside the key is preserved. Keys without the prefix
    are returned unchanged.

    >>> replace_key("transformer.wte.weight")
    'wte.weight'
    >>> replace_key("lm_head.weight")
    'lm_head.weight'
    """
    prefix = "transformer."
    if key.startswith(prefix):
        # Slice off just the prefix: str.replace would also rewrite any
        # later "transformer." occurrence elsewhere in the key.
        key = key[len(prefix):]
    return key
2023-12-18 15:30:36 +08:00
def convert(model_path: str = "Qwen/Qwen-1_8B"):
    """Convert a Hugging Face causal-LM checkpoint into npz weights plus JSON config.

    Loads ``model_path`` with ``trust_remote_code=True`` and float16 weights.
    It then writes two files to the current working directory:

    * ``weights.npz`` holds every entry of the model's ``state_dict()`` as a
      NumPy array, keyed by ``replace_key(name)``.
    * ``config.json`` holds ``model.config.to_dict()``, dumped with a 4-space
      indent.
    """
    hf_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )

    # Build the name -> array mapping in state_dict order. If two names map
    # to the same converted key, the later one wins (plain dict assignment).
    arrays = {}
    for param_name, tensor in hf_model.state_dict().items():
        arrays[replace_key(param_name)] = tensor.numpy()
    np.savez("weights.npz", **arrays)

    # Persist the model configuration alongside the weights.
    with open("config.json", "w") as config_file:
        json.dump(hf_model.config.to_dict(), config_file, indent=4)
2023-12-18 16:59:51 +08:00
2023-12-18 15:04:21 +08:00
if __name__ == "__main__":
    # Command-line entry point: `python convert.py [--model HF_REPO_ID]`.
    cli = argparse.ArgumentParser(description="Convert Qwen model to npz")
    cli.add_argument(
        "--model",
        default="Qwen/Qwen-1_8B",
        help="The huggingface model to be converted",
    )
    convert(cli.parse_args().model)