From bfbc0e434a02e96a822ccf7bbcc7ae49ca951db6 Mon Sep 17 00:00:00 2001
From: Albert Avetisian
Date: Wed, 8 May 2024 09:04:36 -0400
Subject: [PATCH] Add optional EOS token for llava example (#753)

* add optional EOS token

* add tokenizer config to align with MLX LM example

* formatting fixes
---
 llava/generate.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/llava/generate.py b/llava/generate.py
index 9535bab9..8067839e 100644
--- a/llava/generate.py
+++ b/llava/generate.py
@@ -43,6 +43,12 @@ def parse_arguments():
     parser.add_argument(
         "--temp", type=float, default=0.3, help="Temperature for sampling."
     )
+    parser.add_argument(
+        "--eos-token",
+        type=str,
+        default=None,
+        help="End of sequence token for tokenizer",
+    )
     return parser.parse_args()
 
 
@@ -79,8 +85,8 @@ def prepare_inputs(processor, image, prompt):
     return input_ids, pixel_values
 
 
-def load_model(model_path):
-    processor = AutoProcessor.from_pretrained(model_path)
+def load_model(model_path, tokenizer_config={}):
+    processor = AutoProcessor.from_pretrained(model_path, **tokenizer_config)
     model = LlavaModel.from_pretrained(model_path)
     return processor, model
 
@@ -93,7 +99,6 @@ def sample(logits, temperature=0.0):
 
 
 def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
-
     logits, cache = model(input_ids, pixel_values)
     logits = logits[:, -1, :]
     y = sample(logits, temperature=temperature)
@@ -113,7 +118,12 @@ def generate_text(input_ids, pixel_values, model, processor, max_tokens, tempera
 def main():
     args = parse_arguments()
-    processor, model = load_model(args.model)
+
+    tokenizer_config = {}
+    if args.eos_token is not None:
+        tokenizer_config["eos_token"] = args.eos_token
+
+    processor, model = load_model(args.model, tokenizer_config)
 
     prompt = codecs.decode(args.prompt, "unicode_escape")
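
Usage note (not part of the patch): a minimal sketch of the flow the change
introduces, assuming the transformers AutoProcessor API already used by the
example; the model id and token value below are illustrative placeholders,
not values taken from the patch.

    from transformers import AutoProcessor

    # Mirrors the patched main()/load_model(): an EOS token supplied on the
    # command line is forwarded to the processor's tokenizer as a keyword
    # argument, overriding the checkpoint's bundled special-token config.
    eos_token = "</s>"  # e.g. the value passed via --eos-token; defaults to None

    tokenizer_config = {}
    if eos_token is not None:
        tokenizer_config["eos_token"] = eos_token

    processor = AutoProcessor.from_pretrained(
        "llava-hf/llava-1.5-7b-hf",  # placeholder model id
        **tokenizer_config,
    )

At the command line this corresponds to an invocation along the lines of
python llava/generate.py --model <model> --image <image> --prompt <prompt>
--eos-token "</s>" (the --model and --prompt flags appear in the script; the
image argument is assumed from prepare_inputs rather than shown in this
diff). Overriding the EOS token is useful when a checkpoint's tokenizer
config lacks or mislabels its end-of-sequence token, in which case
generation would otherwise run to max_tokens instead of stopping.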