Add support for byt5 models (#161)

* Add support for byt5 models

* Remove unused import
This commit is contained in:
Juarez Bochi
2023-12-21 11:46:36 -05:00
committed by GitHub
parent 6c574dbecf
commit 4c9db80ed2
2 changed files with 8 additions and 7 deletions

View File

@@ -1,6 +1,6 @@
import argparse
from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration
from transformers import AutoTokenizer, T5EncoderModel, AutoModelForSeq2SeqLM
def embed(t5_model: str):
@@ -25,7 +25,7 @@ def embed(t5_model: str):
def generate(t5_model: str):
prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast."
tokenizer = AutoTokenizer.from_pretrained(t5_model)
torch_model = T5ForConditionalGeneration.from_pretrained(t5_model)
torch_model = AutoModelForSeq2SeqLM.from_pretrained(t5_model)
torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids
outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))