mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Add hf generation for comparison
This commit is contained in:
parent
305a52dde8
commit
5ae339f6d2
24
t5/hf_t5.py
24
t5/hf_t5.py
@@ -1,9 +1,9 @@
|
|||||||
from transformers import T5EncoderModel, AutoTokenizer
|
from transformers import T5ForConditionalGeneration, T5EncoderModel, AutoTokenizer
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
def run(t5_model: str):
|
def embed(t5_model: str):
|
||||||
batch = [
|
batch = [
|
||||||
"translate English to German: That is good.",
|
"translate English to German: That is good.",
|
||||||
"This is an example of T5 working on MLX.",
|
"This is an example of T5 working on MLX.",
|
||||||
@@ -22,15 +22,33 @@ def run(t5_model: str):
|
|||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def generate(t5_model: str):
    """Translate a fixed English prompt to German with a Hugging Face T5 model.

    Loads the tokenizer and full encoder-decoder weights named by
    ``t5_model``, runs ``generate`` on the tokenized prompt, and prints the
    decoded text (special tokens stripped). Serves as the reference output to
    compare the MLX implementation against.
    """
    prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast."
    # Tokenizer and seq2seq model pulled from the Hugging Face hub by name.
    tokenizer = AutoTokenizer.from_pretrained(t5_model)
    model = T5ForConditionalGeneration.from_pretrained(t5_model)
    # PyTorch token ids for the single prompt (padding is a no-op here).
    input_ids = tokenizer(prompt, return_tensors="pt", padding=True).input_ids
    generated = model.generate(input_ids)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # CLI entry point: by default run full seq2seq generation; with
    # --encode-only run just the encoder and print the embeddings.
    parser = argparse.ArgumentParser(
        description="Run the T5 model using Hugging Face Transformers."
    )
    parser.add_argument(
        "--encode-only",
        action="store_true",
        # NOTE: store_true already implies default=False; the redundant
        # explicit default was dropped.
        help="Only run the encoder and print the embeddings.",
    )
    parser.add_argument(
        "--model",
        default="t5-small",
        # Fixed stale copy-paste: this script runs the model, it does not
        # save/convert it.
        help="The huggingface name of the T5 model to run.",
    )
    args = parser.parse_args()
    if args.encode_only:
        embed(args.model)
    else:
        generate(args.model)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user