mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-20 03:48:07 +08:00
Add llms subdir + update README (#145)
* add llms subdir + update README * nits * use same pre-commit as mlx * update readmes a bit * format
This commit is contained in:
24
llms/phi2/convert.py
Normal file
24
llms/phi2/convert.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import numpy as np
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
def replace_key(key: str) -> str:
|
||||
if "wte.weight" in key:
|
||||
key = "wte.weight"
|
||||
|
||||
if ".mlp" in key:
|
||||
key = key.replace(".mlp", "")
|
||||
return key
|
||||
|
||||
|
||||
def convert():
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
|
||||
)
|
||||
state_dict = model.state_dict()
|
||||
weights = {replace_key(k): v.numpy() for k, v in state_dict.items()}
|
||||
np.savez("weights.npz", **weights)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert()
|
Reference in New Issue
Block a user