diff --git a/llava/generate.py b/llava/generate.py
index 9535bab9..8067839e 100644
--- a/llava/generate.py
+++ b/llava/generate.py
@@ -43,6 +43,12 @@ def parse_arguments():
     parser.add_argument(
         "--temp", type=float, default=0.3, help="Temperature for sampling."
     )
+    parser.add_argument(
+        "--eos-token",
+        type=str,
+        default=None,
+        help="End of sequence token for tokenizer",
+    )
     return parser.parse_args()
 
 
@@ -79,8 +85,8 @@ def prepare_inputs(processor, image, prompt):
     return input_ids, pixel_values
 
 
-def load_model(model_path):
-    processor = AutoProcessor.from_pretrained(model_path)
+def load_model(model_path, tokenizer_config={}):
+    processor = AutoProcessor.from_pretrained(model_path, **tokenizer_config)
     model = LlavaModel.from_pretrained(model_path)
     return processor, model
 
@@ -93,7 +99,6 @@ def sample(logits, temperature=0.0):
 
 
 def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
-
     logits, cache = model(input_ids, pixel_values)
     logits = logits[:, -1, :]
     y = sample(logits, temperature=temperature)
@@ -113,7 +118,12 @@ def generate_text(input_ids, pixel_values, model, processor, max_tokens, tempera
 
 def main():
     args = parse_arguments()
-    processor, model = load_model(args.model)
+
+    tokenizer_config = {}
+    if args.eos_token is not None:
+        tokenizer_config["eos_token"] = args.eos_token
+
+    processor, model = load_model(args.model, tokenizer_config)
 
     prompt = codecs.decode(args.prompt, "unicode_escape")
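
For context, a minimal sketch (not part of this diff, and with a placeholder
model id and token string) of what the new plumbing does:
AutoProcessor.from_pretrained forwards extra keyword arguments to its
component tokenizer, so passing eos_token overrides the end-of-sequence
token stored in the checkpoint's tokenizer_config.json.

    # Hypothetical illustration of the kwargs forwarding used above.
    from transformers import AutoProcessor

    # Equivalent to running generate.py with: --eos-token "<|end|>"
    tokenizer_config = {"eos_token": "<|end|>"}

    # The model id is a placeholder; any LLaVA checkpoint behaves the same.
    processor = AutoProcessor.from_pretrained(
        "llava-hf/llava-1.5-7b-hf", **tokenizer_config
    )
    print(processor.tokenizer.eos_token)  # -> "<|end|>"

This is useful for checkpoints whose generation actually stops on a token
other than the one recorded in the saved tokenizer config, since generation
loops typically break when the sampled token equals the tokenizer's
eos_token id.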