mlx-examples/t5/convert.py
Juarez Bochi 10a7b99e83
Add T5 and Flan-T5 example (#113)
* Add skeleton

* Load all encoder weights

* Pass config to all modules, fix ln

* Load position bias embeddings

* Load decoder weights

* Move position biases to attention module

* translate pytorch to mx

* Fix default prompt

* Fix relative_attention_max_distance config

* No scaling, no encoder mask

* LM head

* Decode (broken after 1st token)

* Use position bias in all layers

* Utils to compare encoder output

* Fix layer norm

* Fix decoder mask

* Use position bias in decoder

* Concatenate tokens

* Remove prints

* Stop on eos

* Measure tokens/s

* with cache

* bug fix with bidirectional only for encoder, add offset to position bias

* format

* Fix T5.__call__

* Stream output

* Add argument to generate float16 npz

* Load config from HF to support any model

* Uncomment bidirectional param

* Add gitignore

* Add readme.md for t5

* Fix relative position scale

* Fix --encode-only

* Run hf_t5 with any model

* Add hf generation for comparison

* Fix type for attention mask

* Increase hf max_length

* Rescale output before projecting on vocab

* readme updates

* nits

* Pass ln2 to cross attention

* Fix example

* Fix attention for 3b model

* fp16, abstract tokenizer a bit, format

* clamp for low precision

* higher clipping, remove non-helpful casts

* default to fp32 for now

* Adds support for flan-t5

* Update t5 docs on variant support

* readme flan

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-18 20:25:34 -08:00

69 lines
1.9 KiB
Python

from transformers import T5ForConditionalGeneration
import numpy as np
# Key-renaming tables consumed by `replace_key`. Every entry is an
# `(old, new)` pair handed to `str.replace`, and entries are applied in list
# order, so the order of each table is significant.

# Applied to every key, whichever part of the model it belongs to.
SHARED_REPLACEMENT_PATTERNS = [
    # Layer container.
    (".block.", ".layers."),
    # Attention projections.
    (".k.", ".key_proj."),
    (".o.", ".out_proj."),
    (".q.", ".query_proj."),
    (".v.", ".value_proj."),
    # Embedding table and output head.
    ("shared.", "wte."),
    ("lm_head.", "lm_head.linear."),
    # Layer norms, numbered by the sub-layer index they sit in.
    (".layer.0.layer_norm.", ".ln1."),
    (".layer.1.layer_norm.", ".ln2."),
    (".layer.2.layer_norm.", ".ln3."),
    (".final_layer_norm.", ".ln."),
    # This pattern is spelled with `layers.`, so for keys that use `.block.`
    # it only matches once the first entry of this table has been applied.
    (
        "layers.0.layer.0.SelfAttention.relative_attention_bias.",
        "relative_attention_bias.embeddings.",
    ),
]

# Applied afterwards, only to keys that start with "encoder.".
ENCODER_REPLACEMENT_PATTERNS = [
    (".layer.0.SelfAttention.", ".attention."),
    (".layer.1.DenseReluDense.", ".dense."),
]

# Applied afterwards, only to keys that start with "decoder.".
DECODER_REPLACEMENT_PATTERNS = [
    (".layer.0.SelfAttention.", ".self_attention."),
    (".layer.1.EncDecAttention.", ".cross_attention."),
    (".layer.2.DenseReluDense.", ".dense."),
]
def replace_key(key: str) -> str:
    """Rename one checkpoint parameter key using the module's pattern tables.

    The shared patterns are applied first, in order. The prefix of the
    resulting key then decides which extra table runs: the encoder table for
    keys starting with ``"encoder."``, the decoder table for keys starting
    with ``"decoder."``, and nothing more for any other key.

    For example, ``"encoder.block.0.layer.0.SelfAttention.q.weight"`` becomes
    ``"encoder.layers.0.attention.query_proj.weight"`` and ``"shared.weight"``
    becomes ``"wte.weight"``.

    Args:
        key: Parameter name as it appears in the source state dict.

    Returns:
        The renamed key.
    """
    for source, target in SHARED_REPLACEMENT_PATTERNS:
        key = key.replace(source, target)

    # The prefix is checked on the key *after* the shared rewrites.
    if key.startswith("encoder."):
        stack_patterns = ENCODER_REPLACEMENT_PATTERNS
    elif key.startswith("decoder."):
        stack_patterns = DECODER_REPLACEMENT_PATTERNS
    else:
        return key

    for source, target in stack_patterns:
        key = key.replace(source, target)
    return key
def convert(model_name, dtype=np.float16):
    """Export a pretrained T5 checkpoint as a NumPy ``.npz`` archive.

    The model is loaded with ``T5ForConditionalGeneration.from_pretrained``,
    every entry of its state dict is renamed with ``replace_key`` and cast to
    ``dtype``, and the result is written with ``np.savez`` to
    ``"<model_name with '/' replaced by '-'>.npz"`` (a relative path, so it
    lands in the current working directory).

    Args:
        model_name: Model identifier handed to ``from_pretrained``.
        dtype: NumPy dtype each tensor is cast to before saving. Defaults to
            ``np.float16``, which is what this function wrote before the
            parameter existed. Tensors pass through float32 on the way, so
            requesting a wider dtype does not add precision.

    Returns:
        None. The only result is the ``.npz`` file on disk.
    """
    model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
    weights = {
        # `torch_dtype="auto"` keeps the dtype stored in the checkpoint, and
        # `Tensor.numpy()` cannot export bfloat16 (NumPy has no such dtype).
        # Widening to float32 first is lossless for fp16/bf16/fp32 weights and
        # makes the export work for all of them.
        replace_key(name): tensor.float().numpy().astype(dtype)
        for name, tensor in model.state_dict().items()
    }
    # Names such as "org/model" contain "/", which would otherwise be treated
    # as a directory separator in the output path.
    file_name = model_name.replace("/", "-")
    np.savez(f"{file_name}.npz", **weights)
if __name__ == "__main__":
    import argparse

    # Script entry point: read the model name from the command line
    # (falling back to "t5-small") and run the conversion for it.
    cli = argparse.ArgumentParser(description="Convert T5 weights to MLX")
    cli.add_argument(
        "--model",
        default="t5-small",
        type=str,
        help="Name of the T5 model.",
    )
    convert(cli.parse_args().model)