mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Utils to compare encoder output
This commit is contained in:
parent
7e42349f4c
commit
4ec2b6eec3
39
t5/hf_t5.py
Normal file
39
t5/hf_t5.py
Normal file
@ -0,0 +1,39 @@
|
||||
from transformers import T5EncoderModel, AutoTokenizer
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
def run(t5_model: str):
    """Encode a small batch with the Hugging Face (PyTorch) T5 encoder and
    print the resulting embeddings, for comparison against the MLX port.

    Args:
        t5_model: Hugging Face model name or path (e.g. "t5-small").
    """
    batch = [
        "translate English to German: That is good.",
        "This is an example of T5 working on MLX.",
    ]

    tokenizer = AutoTokenizer.from_pretrained(t5_model)
    torch_model = T5EncoderModel.from_pretrained(t5_model)
    # Pad so both sentences fit in one batch tensor.
    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
    torch_forward = torch_model(**torch_tokens)
    torch_output = torch_forward.last_hidden_state.detach().numpy()

    # Bug fix: the banner previously read "TF BERT:" — a stale copy/paste from
    # a BERT example. This script runs the Hugging Face PyTorch T5 encoder.
    print("\n HF T5:")
    # `zip` is already iterable; the original wrapped it in a redundant list().
    for input_str, embedding in zip(batch, torch_output):
        print("Input:", input_str)
        print(embedding)
        print()
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the T5 model using Hugging Face Transformers."
    )
    parser.add_argument(
        "--model",
        # Generalized: the encoder comparison is size-agnostic, so accept the
        # standard T5 checkpoints instead of only "t5-small" (backward
        # compatible — the original choice and default are unchanged).
        choices=[
            "t5-small",
            "t5-base",
            "t5-large",
        ],
        default="t5-small",
        # Fixed help text: this script runs the model for output comparison;
        # it does not save anything.
        help="The Hugging Face name of the T5 model to run.",
    )
    args = parser.parse_args()

    run(args.model)
|
14
t5/t5.py
14
t5/t5.py
@ -355,6 +355,12 @@ if __name__ == "__main__":
|
||||
help="",
|
||||
default="translate English to German: That is good.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encode-only",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Whether to decode or not",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_tokens",
|
||||
"-m",
|
||||
@ -384,6 +390,14 @@ if __name__ == "__main__":
|
||||
|
||||
prompt = mx.array(prompt)
|
||||
|
||||
if args.encode_only:
|
||||
print("[INFO] Encoding with T5...", flush=True)
|
||||
print(args.prompt, end="", flush=True)
|
||||
embeddings = model.wte(prompt)
|
||||
encoder_output = model.encoder(embeddings, mask=None)
|
||||
print(encoder_output, flush=True)
|
||||
exit(0)
|
||||
|
||||
print("[INFO] Generating with T5...", flush=True)
|
||||
print(args.prompt, end="", flush=True)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user