mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Add optional EOS token for llava example (#753)
* add optional EOS token * add tokenizer config to align with MLX LM example * formatting fixes
This commit is contained in:
parent
c0019c4908
commit
bfbc0e434a
@ -43,6 +43,12 @@ def parse_arguments():
|
||||
parser.add_argument(
|
||||
"--temp", type=float, default=0.3, help="Temperature for sampling."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eos-token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="End of sequence token for tokenizer",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -79,8 +85,8 @@ def prepare_inputs(processor, image, prompt):
|
||||
return input_ids, pixel_values
|
||||
|
||||
|
||||
def load_model(model_path):
|
||||
processor = AutoProcessor.from_pretrained(model_path)
|
||||
def load_model(model_path, tokenizer_config={}):
|
||||
processor = AutoProcessor.from_pretrained(model_path, **tokenizer_config)
|
||||
model = LlavaModel.from_pretrained(model_path)
|
||||
return processor, model
|
||||
|
||||
@ -93,7 +99,6 @@ def sample(logits, temperature=0.0):
|
||||
|
||||
|
||||
def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
|
||||
|
||||
logits, cache = model(input_ids, pixel_values)
|
||||
logits = logits[:, -1, :]
|
||||
y = sample(logits, temperature=temperature)
|
||||
@ -113,7 +118,12 @@ def generate_text(input_ids, pixel_values, model, processor, max_tokens, tempera
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
processor, model = load_model(args.model)
|
||||
|
||||
tokenizer_config = {}
|
||||
if args.eos_token is not None:
|
||||
tokenizer_config["eos_token"] = args.eos_token
|
||||
|
||||
processor, model = load_model(args.model, tokenizer_config)
|
||||
|
||||
prompt = codecs.decode(args.prompt, "unicode_escape")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user