mlx-examples/t5/convert.py

69 lines
1.9 KiB
Python
Raw Normal View History

2023-12-15 04:21:36 +08:00
from transformers import T5ForConditionalGeneration
import numpy as np
2023-12-15 23:50:04 +08:00
# Substring rewrites from Hugging Face T5 parameter names to the names used by
# the MLX model. Each pair is applied with ``str.replace`` in list order, so an
# earlier rewrite feeds later ones: ".block." -> ".layers." must run before the
# relative-attention-bias pattern, which matches the already-rewritten
# "layers.0." prefix.
SHARED_REPLACEMENT_PATTERNS = [
    (".block.", ".layers."),
    (".k.", ".key_proj."),
    (".o.", ".out_proj."),
    (".q.", ".query_proj."),
    (".v.", ".value_proj."),
    ("shared.", "wte."),
    ("lm_head.", "lm_head.linear."),
    (".layer.0.layer_norm.", ".ln1."),
    (".layer.1.layer_norm.", ".ln2."),
    (".layer.2.layer_norm.", ".ln3."),
    (".final_layer_norm.", ".ln."),
    (
        "layers.0.layer.0.SelfAttention.relative_attention_bias.",
        "relative_attention_bias.embeddings.",
    ),
]
# Additional rewrites for keys that start with "encoder." (run after the shared ones).
ENCODER_REPLACEMENT_PATTERNS = [
    (".layer.0.SelfAttention.", ".attention."),
    (".layer.1.DenseReluDense.", ".dense."),
]
# Additional rewrites for keys that start with "decoder." (run after the shared ones).
DECODER_REPLACEMENT_PATTERNS = [
    (".layer.0.SelfAttention.", ".self_attention."),
    (".layer.1.EncDecAttention.", ".cross_attention."),
    (".layer.2.DenseReluDense.", ".dense."),
]


def _apply_patterns(key: str, patterns) -> str:
    """Return *key* with every ``(old, new)`` pair in *patterns* substituted, in order."""
    for before, after in patterns:
        key = key.replace(before, after)
    return key


def replace_key(key: str) -> str:
    """Translate one Hugging Face T5 ``state_dict`` key into its MLX-side name.

    The shared rewrites are applied first. Then, depending on whether the
    (already rewritten) key begins with "encoder." or "decoder.", the matching
    encoder- or decoder-specific rewrites are applied. Keys matching no pattern
    are returned unchanged.
    """
    renamed = _apply_patterns(key, SHARED_REPLACEMENT_PATTERNS)
    if renamed.startswith("encoder."):
        return _apply_patterns(renamed, ENCODER_REPLACEMENT_PATTERNS)
    if renamed.startswith("decoder."):
        return _apply_patterns(renamed, DECODER_REPLACEMENT_PATTERNS)
    return renamed
2023-12-15 23:50:04 +08:00
2023-12-19 05:15:02 +08:00
def convert(model_name):
    """Load a Hugging Face T5 checkpoint and save its weights as ``<name>.npz``.

    Every ``state_dict`` entry is renamed with ``replace_key`` and stored as a
    float16 NumPy array. The output file is written to the current working
    directory. Its name is *model_name* with "/" replaced by "-", so a Hub id
    like "google/flan-t5-small" becomes a flat file name.

    Args:
        model_name: Identifier passed to ``T5ForConditionalGeneration.from_pretrained``.
    """
    model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
    # ``torch_dtype="auto"`` keeps the checkpoint's stored dtype, which can be
    # bfloat16. NumPy has no bfloat16, so calling ``Tensor.numpy()`` on such a
    # tensor raises. Casting to float16 on the torch side first works for any
    # floating-point source dtype, and for fp32 checkpoints it gives the same
    # values as the previous ``.numpy().astype(np.float16)``.
    weights = {
        replace_key(k): v.half().numpy()
        for k, v in model.state_dict().items()
    }
    file_name = model_name.replace("/", "-")
    np.savez(f"{file_name}.npz", **weights)
2023-12-15 04:21:36 +08:00
if __name__ == "__main__":
    # Command-line entry point: choose a T5 checkpoint with --model and convert it.
    import argparse

    cli = argparse.ArgumentParser(description="Convert T5 weights to MLX")
    cli.add_argument(
        "--model",
        default="t5-small",
        type=str,
        help="Name of the T5 model.",
    )
    convert(cli.parse_args().model)