mlx-examples/t5/hf_t5.py
Juarez Bochi 10a7b99e83
Add T5 and Flan-T5 example (#113)
* Add skeleton

* Load all encoder weights

* Pass config to all modules, fix ln

* Load position bias embeddings

* Load decoder weights

* Move position biases to attention module

* translate pytorch to mx

* Fix default prompt

* Fix relative_attention_max_distance config

* No scaling, no encoder mask

* LM head

* Decode (broken after 1st token)

* Use position bias in all layers

* Utils to compare encoder output

* Fix layer norm

* Fix decoder mask

* Use position bias in decoder

* Concatenate tokens

* Remove prints

* Stop on eos

* Measure tokens/s

* with cache

* bug fix with bidirectional only for encoder, add offset to position bias

* format

* Fix T5.__call__

* Stream output

* Add argument to generate float16 npz

* Load config from HF to support any model

* Uncomment bidirectional param

* Add gitignore

* Add readme.md for t5

* Fix relative position scale

* Fix --encode-only

* Run hf_t5 with any model

* Add hf generation for comparison

* Fix type for attention mask

* Increase hf max_length

* Rescale output before projecting on vocab

* readme updates

* nits

* Pass ln2 to cross attention

* Fix example

* Fix attention for 3b model

* fp16, abstract tokenizer a bit, format

* clamp for low precision

* higher clipping, remove non-helpful casts

* default to fp32 for now

* Adds support for flan-t5

* Update t5 docs on variant support

* readme flan

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-18 20:25:34 -08:00

55 lines
1.9 KiB
Python

import argparse

from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration


def embed(t5_model: str):
    """Run only the T5 encoder and print the final hidden states for each input."""
    batch = [
        "translate English to German: That is good.",
        "This is an example of T5 working on MLX.",
    ]

    tokenizer = AutoTokenizer.from_pretrained(t5_model)
    torch_model = T5EncoderModel.from_pretrained(t5_model)
    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
    torch_forward = torch_model(**torch_tokens, output_hidden_states=True)
    torch_output = torch_forward.last_hidden_state.detach().numpy()

    print("\n HF T5:")
    for input_str, embedding in zip(batch, torch_output):
        print("Input:", input_str)
        print(embedding)
        print()


def generate(t5_model: str):
    """Run the full encoder-decoder model with greedy decoding and print the output."""
    prompt = (
        "translate English to German: As much as six inches of rain could fall in the "
        "New York City region through Monday morning, and officials warned of flooding "
        "along the coast."
    )
    tokenizer = AutoTokenizer.from_pretrained(t5_model)
    torch_model = T5ForConditionalGeneration.from_pretrained(t5_model)
    torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids
    outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the T5 model using Hugging Face Transformers."
    )
    parser.add_argument(
        "--encode-only",
        action="store_true",
        help="Only run the encoder and print the embeddings.",
        default=False,
    )
    parser.add_argument(
        "--model",
        default="t5-small",
        help="The Hugging Face name of the T5 model to run.",
    )
    args = parser.parse_args()

    if args.encode_only:
        embed(args.model)
    else:
        generate(args.model)
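
For a quick smoke test, the two entry points can also be called directly from a Python session instead of going through argparse. A minimal sketch, assuming the file is importable as hf_t5 from the same directory:

# Minimal sketch (assumes hf_t5.py is on the import path as `hf_t5`).
import hf_t5

hf_t5.embed("t5-small")     # encoder only: prints the last hidden states per input
hf_t5.generate("t5-small")  # full encoder-decoder generation with greedy decoding

Either call downloads the requested checkpoint from the Hugging Face Hub on first use, so the model name can be any T5 or Flan-T5 variant supported by Transformers.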