Add HF generation for comparison

This commit is contained in:
Juarez Bochi 2023-12-18 11:35:16 -05:00
parent 305a52dde8
commit 5ae339f6d2
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@ -1,9 +1,9 @@
from transformers import T5EncoderModel, AutoTokenizer from transformers import T5ForConditionalGeneration, T5EncoderModel, AutoTokenizer
import argparse import argparse
def run(t5_model: str): def embed(t5_model: str):
batch = [ batch = [
"translate English to German: That is good.", "translate English to German: That is good.",
"This is an example of T5 working on MLX.", "This is an example of T5 working on MLX.",
@ -22,15 +22,33 @@ def run(t5_model: str):
print() print()
def generate(t5_model: str):
    """Run greedy generation with the Hugging Face T5 model and print the result.

    Serves as a reference output to compare the MLX port against.

    Args:
        t5_model: Hugging Face hub name (or local path) of the T5 checkpoint.
    """
    prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast."
    tokenizer = AutoTokenizer.from_pretrained(t5_model)
    torch_model = T5ForConditionalGeneration.from_pretrained(t5_model)
    torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids
    # Without an explicit limit, generate() stops at the model config's default
    # max_length (20 tokens), truncating this long translation mid-sentence.
    outputs = torch_model.generate(torch_tokens, max_new_tokens=128)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
if __name__ == "__main__":
    # CLI entry point: choose between encoder-only embedding and full generation.
    arg_parser = argparse.ArgumentParser(
        description="Run the T5 model using Hugging Face Transformers."
    )
    arg_parser.add_argument(
        "--encode-only",
        action="store_true",
        help="Only run the encoder and print the embeddings.",
        default=False,
    )
    arg_parser.add_argument(
        "--model",
        default="t5-small",
        help="The huggingface name of the T5 model to save.",
    )
    cli_args = arg_parser.parse_args()
    # Dispatch: encoder embeddings only, or the full encoder-decoder generation.
    if cli_args.encode_only:
        embed(cli_args.model)
    else:
        generate(cli_args.model)