mlx-examples/phi2/convert.py

25 lines
566 B
Python
Raw Normal View History

2023-12-14 11:22:56 +08:00
from transformers import AutoModelForCausalLM
import numpy as np
2023-12-14 11:22:56 +08:00
2023-12-15 08:56:50 +08:00
2023-12-14 11:22:56 +08:00
def replace_key(key: str) -> str:
    """Rename a checkpoint parameter key to the name stored in the output archive.

    Any key that contains ``"wte.weight"`` collapses to exactly ``"wte.weight"``.
    Every other key has all occurrences of ``".mlp"`` removed.

    Args:
        key: Parameter name taken from the model's ``state_dict``.

    Returns:
        The rewritten key.
    """
    # The embedding table is saved under one flat name, whatever its prefix.
    if "wte.weight" in key:
        return "wte.weight"
    # str.replace returns an equal string when ".mlp" is absent, so no
    # separate membership test is required.
    return key.replace(".mlp", "")
def convert(
    model_path: str = "microsoft/phi-2",
    output_path: str = "weights.npz",
) -> None:
    """Load a Hugging Face causal-LM checkpoint and save its weights with NumPy.

    Every tensor in the model's ``state_dict`` is converted to a NumPy array.
    It is stored in an uncompressed ``.npz`` archive under the key produced
    by ``replace_key``.

    Args:
        model_path: Hub repo id or local directory passed to
            ``AutoModelForCausalLM.from_pretrained``. Defaults to the
            previously hard-coded ``"microsoft/phi-2"``.
        output_path: Destination passed to ``np.savez``. NumPy appends
            ``.npz`` if the name does not already end with it. Defaults to
            the previously hard-coded ``"weights.npz"``.
    """
    # trust_remote_code=True executes Python code shipped with the checkpoint
    # repository; only point model_path at repositories you trust.
    # torch_dtype="auto" asks transformers to keep the checkpoint's own dtype
    # rather than loading in the default float32.
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype="auto", trust_remote_code=True
    )
    state_dict = model.state_dict()
    # NOTE(review): if two source keys map to the same name via replace_key,
    # the later one silently overwrites the earlier one in this dict.
    # NOTE(review): Tensor.numpy() raises for dtypes NumPy lacks (e.g.
    # bfloat16); presumably the checkpoint's dtype is NumPy-compatible.
    weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
    np.savez(output_path, **weights)
2023-12-14 11:22:56 +08:00
if __name__ == "__main__":
    # Executed directly as a script (not imported): run the conversion with
    # its default arguments.
    convert()