mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Add support for byt5 models (#161)
* Add support for byt5 models * Remove unused import
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import argparse
|
||||
|
||||
from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration
|
||||
from transformers import AutoTokenizer, T5EncoderModel, AutoModelForSeq2SeqLM
|
||||
|
||||
|
||||
def embed(t5_model: str):
|
||||
@@ -25,7 +25,7 @@ def embed(t5_model: str):
|
||||
def generate(t5_model: str):
|
||||
prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast."
|
||||
tokenizer = AutoTokenizer.from_pretrained(t5_model)
|
||||
torch_model = T5ForConditionalGeneration.from_pretrained(t5_model)
|
||||
torch_model = AutoModelForSeq2SeqLM.from_pretrained(t5_model)
|
||||
torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids
|
||||
outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
|
||||
Reference in New Issue
Block a user