Utils to compare encoder output

This commit is contained in:
Juarez Bochi 2023-12-17 07:20:24 -05:00
parent 7e42349f4c
commit 4ec2b6eec3
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6
2 changed files with 53 additions and 0 deletions

39
t5/hf_t5.py Normal file
View File

@ -0,0 +1,39 @@
from transformers import T5EncoderModel, AutoTokenizer
import argparse
def run(t5_model: str) -> None:
    """Encode a fixed batch with the Hugging Face T5 encoder and print embeddings.

    Loads the tokenizer and encoder-only T5 model named by ``t5_model``,
    runs a two-sentence reference batch through it, and prints each input
    string followed by its last-hidden-state embedding as a numpy array
    (used to compare against the MLX port's encoder output).

    Args:
        t5_model: Hugging Face hub name of the T5 checkpoint (e.g. "t5-small").
    """
    batch = [
        "translate English to German: That is good.",
        "This is an example of T5 working on MLX.",
    ]
    tokenizer = AutoTokenizer.from_pretrained(t5_model)
    torch_model = T5EncoderModel.from_pretrained(t5_model)
    # Pad so both sentences fit in a single batch tensor.
    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
    torch_forward = torch_model(**torch_tokens)
    torch_output = torch_forward.last_hidden_state.detach().numpy()
    # Fixed label: this prints Hugging Face T5 reference output — the original
    # "TF BERT" header was copied from a different example script.
    print("\n HF T5:")
    # zip() is iterated directly; wrapping it in list() was a needless copy.
    for input_str, embedding in zip(batch, torch_output):
        print("Input:", input_str)
        print(embedding)
        print()
if __name__ == "__main__":
    # CLI entry point: choose a T5 checkpoint, then dump its encoder output.
    arg_parser = argparse.ArgumentParser(
        description="Run the T5 model using Hugging Face Transformers."
    )
    arg_parser.add_argument(
        "--model",
        default="t5-small",
        choices=["t5-small"],
        help="The huggingface name of the T5 model to save.",
    )
    cli_args = arg_parser.parse_args()
    run(cli_args.model)

View File

@ -355,6 +355,12 @@ if __name__ == "__main__":
help="",
default="translate English to German: That is good.",
)
# Flag to stop after the encoder pass instead of running full generation.
parser.add_argument(
    "--encode-only",
    action="store_true",
    # store_true already defaults to False, so the explicit default was redundant.
    help="Run only the encoder and print its output, skipping decoding",
)
parser.add_argument(
"--max_tokens",
"-m",
@ -384,6 +390,14 @@ if __name__ == "__main__":
# NOTE(review): this hunk sits inside the __main__ section of the MLX t5.py
# script; `mx`, `model`, `args`, and the token list `prompt` come from code
# above this excerpt. Indentation was stripped by the diff rendering.
prompt = mx.array(prompt)
if args.encode_only:
# Encoder-only path: embed the prompt tokens, run the encoder stack,
# print the raw encoder output, and exit before any decoding.
print("[INFO] Encoding with T5...", flush=True)
print(args.prompt, end="", flush=True)
# presumably model.wte is the token-embedding table shared with the decoder — TODO confirm
embeddings = model.wte(prompt)
# mask=None: no attention mask for a single unpadded prompt — verify against encoder signature
encoder_output = model.encoder(embeddings, mask=None)
print(encoder_output, flush=True)
exit(0)
# Default path: continue into autoregressive generation below this hunk.
print("[INFO] Generating with T5...", flush=True)
print(args.prompt, end="", flush=True)