Add optional EOS token for llava example (#753)

* add optional EOS token

* add tokenizer config to align with MLX LM example

* formtatting fixes
This commit is contained in:
Albert Avetisian 2024-05-08 09:04:36 -04:00 committed by GitHub
parent c0019c4908
commit bfbc0e434a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -43,6 +43,12 @@ def parse_arguments():
parser.add_argument( parser.add_argument(
"--temp", type=float, default=0.3, help="Temperature for sampling." "--temp", type=float, default=0.3, help="Temperature for sampling."
) )
parser.add_argument(
"--eos-token",
type=str,
default=None,
help="End of sequence token for tokenizer",
)
return parser.parse_args() return parser.parse_args()
@ -79,8 +85,8 @@ def prepare_inputs(processor, image, prompt):
return input_ids, pixel_values return input_ids, pixel_values
def load_model(model_path): def load_model(model_path, tokenizer_config={}):
processor = AutoProcessor.from_pretrained(model_path) processor = AutoProcessor.from_pretrained(model_path, **tokenizer_config)
model = LlavaModel.from_pretrained(model_path) model = LlavaModel.from_pretrained(model_path)
return processor, model return processor, model
@ -93,7 +99,6 @@ def sample(logits, temperature=0.0):
def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature): def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
logits, cache = model(input_ids, pixel_values) logits, cache = model(input_ids, pixel_values)
logits = logits[:, -1, :] logits = logits[:, -1, :]
y = sample(logits, temperature=temperature) y = sample(logits, temperature=temperature)
@ -113,7 +118,12 @@ def generate_text(input_ids, pixel_values, model, processor, max_tokens, tempera
def main(): def main():
args = parse_arguments() args = parse_arguments()
processor, model = load_model(args.model)
tokenizer_config = {}
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
processor, model = load_model(args.model, tokenizer_config)
prompt = codecs.decode(args.prompt, "unicode_escape") prompt = codecs.decode(args.prompt, "unicode_escape")