diff --git a/t5/hf_t5.py b/t5/hf_t5.py
index da269e2e..b39f807c 100644
--- a/t5/hf_t5.py
+++ b/t5/hf_t5.py
@@ -1,9 +1,9 @@
-from transformers import T5EncoderModel, AutoTokenizer
+from transformers import T5ForConditionalGeneration, T5EncoderModel, AutoTokenizer
 
 import argparse
 
 
-def run(t5_model: str):
+def embed(t5_model: str):
     batch = [
         "translate English to German: That is good.",
         "This is an example of T5 working on MLX.",
@@ -22,15 +22,33 @@ def run(t5_model: str):
     print()
 
 
+def generate(t5_model: str):
+    prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast."
+    tokenizer = AutoTokenizer.from_pretrained(t5_model)
+    torch_model = T5ForConditionalGeneration.from_pretrained(t5_model)
+    torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids
+    outputs = torch_model.generate(torch_tokens)
+    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Run the T5 model using Hugging Face Transformers."
     )
+    parser.add_argument(
+        "--encode-only",
+        action="store_true",
+        help="Only run the encoder and print the embeddings.",
+        default=False,
+    )
     parser.add_argument(
         "--model",
         default="t5-small",
         help="The huggingface name of the T5 model to save.",
     )
 
     args = parser.parse_args()
+    if args.encode_only:
+        embed(args.model)
+    else:
+        generate(args.model)
-    run(args.model)
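
For a quick local sanity check of the new default path, the snippet below mirrors what the added generate() helper does, independently of this patch. It is a sketch, not part of the diff; it assumes the Hugging Face transformers package is installed and that the default t5-small checkpoint can be downloaded.

# Sketch (outside this patch): same flow as the new generate() helper,
# run against the default t5-small checkpoint.
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Tokenize a prompt, run the full encoder-decoder model, and decode the output tokens.
tokens = tokenizer(
    "translate English to German: That is good.", return_tensors="pt"
).input_ids
outputs = model.generate(tokens)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))  # likely "Das ist gut."

Passing --encode-only to the script keeps the previous behavior (encoder embeddings only, now in embed()); otherwise the new generate() path runs.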