Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-03 23:14:34 +08:00)
Add T5 and Flan-T5 example (#113)
* Add skeleton
* Load all encoder weights
* Pass config to all modules, fix ln
* Load position bias embeddings
* Load decoder weights
* Move position biases to attention module
* translate pytorch to mx
* Fix default prompt
* Fix relative_attention_max_distance config
* No scaling, no encoder mask
* LM head
* Decode (broken after 1st token)
* Use position bias in all layers
* Utils to compare encoder output
* Fix layer norm
* Fix decoder mask
* Use position bias in decoder
* Concatenate tokens
* Remove prints
* Stop on eos
* Measure tokens/s
* with cache
* bug fix with bidirectional only for encoder, add offset to position bias
* format
* Fix T5.__call__
* Stream output
* Add argument to generate float16 npz
* Load config from HF to support any model
* Uncomment bidirectional param
* Add gitignore
* Add readme.md for t5
* Fix relative position scale
* Fix --encode-only
* Run hf_t5 with any model
* Add hf generation for comparison
* Fix type for attention mask
* Increase hf max_length
* Rescale output before projecting on vocab
* readme updates
* nits
* Pass ln2 to cross attention
* Fix example
* Fix attention for 3b model
* fp16, abstract tokenizer a bit, format
* clamp for low precision
* higher clipping, remove non-helpful casts
* default to fp32 for now
* Adds support for flan-t5
* Update t5 docs on variant support
* readme flan
* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
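Many of the bullets above revolve around T5's relative position bias: rather than absolute position embeddings, T5 buckets the signed offset between key and query positions and learns one bias per bucket and head, which is what the "bidirectional only for encoder" fix and the relative_attention_max_distance config touch. As background, here is a minimal NumPy sketch of that bucketing, mirroring the scheme from the T5 paper and the Hugging Face implementation (the function name and defaults are illustrative, not taken from this commit):

import numpy as np


def relative_position_bucket(relative_position, bidirectional=True,
                             num_buckets=32, max_distance=128):
    # relative_position is key_pos - query_pos as an integer array.
    relative_buckets = np.zeros_like(relative_position)
    if bidirectional:
        # Encoder: half the buckets for looking back, half for looking ahead.
        num_buckets //= 2
        relative_buckets += (relative_position > 0).astype(np.int64) * num_buckets
        relative_position = np.abs(relative_position)
    else:
        # Decoder: causal attention, so future (positive) offsets collapse to 0.
        relative_position = -np.minimum(relative_position, 0)
    # Half of the remaining buckets hold exact small distances ...
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # ... and the other half cover [max_exact, max_distance) logarithmically.
    large = max_exact + (
        np.log(np.maximum(relative_position, 1) / max_exact)
        / np.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).astype(np.int64)
    large = np.minimum(large, num_buckets - 1)
    return relative_buckets + np.where(is_small, relative_position, large)

Distances beyond max_distance all share the last bucket, which is why the learned bias can be reused at any sequence length and, with the offset added for the KV cache, at any decoding step.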
t5/hf_t5.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import argparse

from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration


def embed(t5_model: str):
    """Run only the T5 encoder and print the last hidden states."""
    batch = [
        "translate English to German: That is good.",
        "This is an example of T5 working on MLX.",
    ]

    tokenizer = AutoTokenizer.from_pretrained(t5_model)
    torch_model = T5EncoderModel.from_pretrained(t5_model)
    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
    torch_forward = torch_model(**torch_tokens, output_hidden_states=True)
    torch_output = torch_forward.last_hidden_state.detach().numpy()

    print("\n HF T5:")
    for input_str, embedding in zip(batch, torch_output):
        print("Input:", input_str)
        print(embedding)
        print()


def generate(t5_model: str):
    """Greedily decode a translation with the full encoder-decoder model."""
    prompt = (
        "translate English to German: As much as six inches of rain could fall "
        "in the New York City region through Monday morning, and officials "
        "warned of flooding along the coast."
    )
    tokenizer = AutoTokenizer.from_pretrained(t5_model)
    torch_model = T5ForConditionalGeneration.from_pretrained(t5_model)
    torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids
    outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the T5 model using Hugging Face Transformers."
    )
    parser.add_argument(
        "--encode-only",
        action="store_true",
        help="Only run the encoder and print the embeddings.",
        default=False,
    )
    parser.add_argument(
        "--model",
        default="t5-small",
        help="The Hugging Face name of the T5 model to run.",
    )
    args = parser.parse_args()

    if args.encode_only:
        embed(args.model)
    else:
        generate(args.model)
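This script is the reference side of the comparison: it produces stock Transformers output to sanity-check the MLX port against. Assuming transformers and torch are installed, typical invocations look like:

python hf_t5.py --model t5-small
python hf_t5.py --model google/flan-t5-base --encode-only

The --encode-only path prints raw encoder embeddings for a fixed batch so they can be compared against the MLX encoder output; the default path runs greedy generation on the sample prompt.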