diff --git a/t5/hf_t5.py b/t5/hf_t5.py
new file mode 100644
index 00000000..a1910afb
--- /dev/null
+++ b/t5/hf_t5.py
@@ -0,0 +1,39 @@
+import argparse
+
+from transformers import AutoTokenizer, T5EncoderModel
+
+
+def run(t5_model: str):
+    batch = [
+        "translate English to German: That is good.",
+        "This is an example of T5 working on MLX.",
+    ]
+
+    tokenizer = AutoTokenizer.from_pretrained(t5_model)
+    torch_model = T5EncoderModel.from_pretrained(t5_model)
+    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
+    torch_forward = torch_model(**torch_tokens)
+    torch_output = torch_forward.last_hidden_state.detach().numpy()
+
+    print("\n HF T5:")
+    for input_str, embedding in zip(batch, torch_output):
+        print("Input:", input_str)
+        print(embedding)
+        print()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Run the T5 model using Hugging Face Transformers."
+    )
+    parser.add_argument(
+        "--model",
+        choices=[
+            "t5-small",
+        ],
+        default="t5-small",
+        help="The huggingface name of the T5 model to use.",
+    )
+    args = parser.parse_args()
+
+    run(args.model)
diff --git a/t5/t5.py b/t5/t5.py
index a59e5ab4..63222eeb 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -355,6 +355,12 @@ if __name__ == "__main__":
         help="",
         default="translate English to German: That is good.",
     )
+    parser.add_argument(
+        "--encode-only",
+        action="store_true",
+        default=False,
+        help="Only run the encoder and print its output (skip decoding).",
+    )
     parser.add_argument(
         "--max_tokens",
         "-m",
@@ -384,6 +390,14 @@ if __name__ == "__main__":
 
     prompt = mx.array(prompt)
 
+    if args.encode_only:
+        print("[INFO] Encoding with T5...", flush=True)
+        print(args.prompt, end="", flush=True)
+        embeddings = model.wte(prompt)
+        encoder_output = model.encoder(embeddings, mask=None)
+        print(encoder_output, flush=True)
+        exit(0)
+
     print("[INFO] Generating with T5...", flush=True)
     print(args.prompt, end="", flush=True)