diff --git a/t5/hf_t5.py b/t5/hf_t5.py index b39f807c..ddd99610 100644 --- a/t5/hf_t5.py +++ b/t5/hf_t5.py @@ -27,7 +27,7 @@ def generate(t5_model: str): tokenizer = AutoTokenizer.from_pretrained(t5_model) torch_model = T5ForConditionalGeneration.from_pretrained(t5_model) torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids - outputs = torch_model.generate(torch_tokens) + outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512) print(tokenizer.decode(outputs[0], skip_special_tokens=True))