2023-12-13 22:22:56 -05:00
|
|
|
from transformers import AutoModelForCausalLM
|
2023-12-14 09:19:44 -08:00
|
|
|
import numpy as np
|
2023-12-13 22:22:56 -05:00
|
|
|
|
2023-12-14 16:56:50 -08:00
|
|
|
|
2023-12-13 22:22:56 -05:00
|
|
|
def replace_key(key: str) -> str:
    """Rename one checkpoint parameter name to the name written to the output archive.

    Any name containing ``"wte.weight"`` collapses to exactly ``"wte.weight"``.
    Every ``".mlp"`` substring is then removed from the name.
    """
    renamed = "wte.weight" if "wte.weight" in key else key
    # str.replace is a no-op when ".mlp" is absent, so no membership test is needed.
    return renamed.replace(".mlp", "")
|
|
|
|
|
|
|
|
|
|
|
|
def convert(
    model_path: str = "microsoft/phi-2", output_path: str = "weights.npz"
) -> None:
    """Load a causal-LM checkpoint and save its weights as a NumPy ``.npz`` archive.

    Each parameter name is rewritten with ``replace_key`` before saving.

    Args:
        model_path: Model id or local directory passed to
            ``AutoModelForCausalLM.from_pretrained``. Defaults to the
            original hard-coded ``"microsoft/phi-2"``.
        output_path: Destination file passed to ``np.savez``. Defaults to
            the original ``"weights.npz"``.
    """
    import torch  # local import: only needed for the bfloat16 check below

    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype="auto", trust_remote_code=True
    )
    weights = {}
    for name, tensor in model.state_dict().items():
        # NumPy has no bfloat16 dtype, so Tensor.numpy() raises on bf16
        # tensors; float32 represents every bfloat16 value exactly.
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        # .cpu() is a no-op for CPU tensors and avoids a failure otherwise.
        weights[replace_key(name)] = tensor.cpu().numpy()
    np.savez(output_path, **weights)
|
2023-12-13 22:22:56 -05:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the conversion only when executed as a script, not on import.
    convert()
|