mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)

* Add skeleton
* Load all encoder weights
* Pass config to all modules, fix ln
* Load position bias embeddings
* Load decoder weights
* Move position biases to attention module
* translate pytorch to mx
* Fix default prompt
* Fix relative_attention_max_distance config
* No scaling, no encoder mask
* LM head
* Decode (broken after 1st token)
* Use position bias in all layers
* Utils to compare encoder output
* Fix layer norm
* Fix decoder mask
* Use position bias in decoder
* Concatenate tokens
* Remove prints
* Stop on eos
* Measure tokens/s
* with cache
* bug fix with bidirectional only for encoder, add offset to position bias
* format
* Fix T5.__call__
* Stream output
* Add argument to generate float16 npz
* Load config from HF to support any model
* Uncomment bidirectional param
* Add gitignore
* Add readme.md for t5
* Fix relative position scale
* Fix --encode-only
* Run hf_t5 with any model
* Add hf generation for comparison
* Fix type for attention mask
* Increase hf max_length
* Rescale output before projecting on vocab
* readme updates
* nits
* Pass ln2 to cross attention
* Fix example
* Fix attention for 3b model
* fp16, abstract tokenizer a bit, format
* clamp for low precision
* higher clipping, remove non-helpful casts
* default to fp32 for now
* Adds support for flan-t5
* Update t5 docs on variant support
* readme flan
* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
69 lines · 1.9 KiB · Python

from transformers import T5ForConditionalGeneration

import numpy as np

SHARED_REPLACEMENT_PATTERNS = [
    (".block.", ".layers."),
    (".k.", ".key_proj."),
    (".o.", ".out_proj."),
    (".q.", ".query_proj."),
    (".v.", ".value_proj."),
    ("shared.", "wte."),
    ("lm_head.", "lm_head.linear."),
    (".layer.0.layer_norm.", ".ln1."),
    (".layer.1.layer_norm.", ".ln2."),
    (".layer.2.layer_norm.", ".ln3."),
    (".final_layer_norm.", ".ln."),
    (
        "layers.0.layer.0.SelfAttention.relative_attention_bias.",
        "relative_attention_bias.embeddings.",
    ),
]

ENCODER_REPLACEMENT_PATTERNS = [
    (".layer.0.SelfAttention.", ".attention."),
    (".layer.1.DenseReluDense.", ".dense."),
]

DECODER_REPLACEMENT_PATTERNS = [
    (".layer.0.SelfAttention.", ".self_attention."),
    (".layer.1.EncDecAttention.", ".cross_attention."),
    (".layer.2.DenseReluDense.", ".dense."),
]


def replace_key(key: str) -> str:
    for old, new in SHARED_REPLACEMENT_PATTERNS:
        key = key.replace(old, new)
    if key.startswith("encoder."):
        for old, new in ENCODER_REPLACEMENT_PATTERNS:
            key = key.replace(old, new)
    elif key.startswith("decoder."):
        for old, new in DECODER_REPLACEMENT_PATTERNS:
            key = key.replace(old, new)
    return key
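
# Example of the renaming this produces (illustrative only; the input key follows
# the standard Hugging Face T5 state-dict naming and is not taken from this file):
#   replace_key("encoder.block.0.layer.0.SelfAttention.q.weight")
#   -> "encoder.layers.0.attention.query_proj.weight"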


def convert(model_name):
    model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
    weights = {
        replace_key(k): v.numpy().astype(np.float16)
        for k, v in model.state_dict().items()
    }
    file_name = model_name.replace("/", "-")
    np.savez(f"{file_name}.npz", **weights)
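
# Note: every tensor is cast to float16 before saving, regardless of the dtype the
# checkpoint was loaded with. Calling convert("t5-small") directly is equivalent to
# the CLI entry point below (the model name here is only an example).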


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
    parser.add_argument(
        "--model",
        type=str,
        help="Name of the T5 model.",
        default="t5-small",
    )
    args = parser.parse_args()
    convert(args.model)
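
# Usage sketch (assuming this script is saved as convert.py; the non-default model
# name is an assumption about what is available on the Hugging Face Hub, e.g. the
# flan-t5 variants mentioned in the commit history above):
#   python convert.py --model t5-small              # writes t5-small.npz
#   python convert.py --model google/flan-t5-base   # writes google-flan-t5-base.npz
# The resulting .npz can be sanity-checked with NumPy:
#   >>> import numpy as np
#   >>> weights = np.load("t5-small.npz")
#   >>> weights.files[:3]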