mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Add optional EOS token for llava example (#753)
* add optional EOS token * add tokenizer config to align with MLX LM example * formatting fixes
This commit is contained in:
parent
c0019c4908
commit
bfbc0e434a
@ -43,6 +43,12 @@ def parse_arguments():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--temp", type=float, default=0.3, help="Temperature for sampling."
|
"--temp", type=float, default=0.3, help="Temperature for sampling."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--eos-token",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="End of sequence token for tokenizer",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -79,8 +85,8 @@ def prepare_inputs(processor, image, prompt):
|
|||||||
return input_ids, pixel_values
|
return input_ids, pixel_values
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_path):
|
def load_model(model_path, tokenizer_config={}):
|
||||||
processor = AutoProcessor.from_pretrained(model_path)
|
processor = AutoProcessor.from_pretrained(model_path, **tokenizer_config)
|
||||||
model = LlavaModel.from_pretrained(model_path)
|
model = LlavaModel.from_pretrained(model_path)
|
||||||
return processor, model
|
return processor, model
|
||||||
|
|
||||||
@ -93,7 +99,6 @@ def sample(logits, temperature=0.0):
|
|||||||
|
|
||||||
|
|
||||||
def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
|
def generate_text(input_ids, pixel_values, model, processor, max_tokens, temperature):
|
||||||
|
|
||||||
logits, cache = model(input_ids, pixel_values)
|
logits, cache = model(input_ids, pixel_values)
|
||||||
logits = logits[:, -1, :]
|
logits = logits[:, -1, :]
|
||||||
y = sample(logits, temperature=temperature)
|
y = sample(logits, temperature=temperature)
|
||||||
@ -113,7 +118,12 @@ def generate_text(input_ids, pixel_values, model, processor, max_tokens, tempera
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
processor, model = load_model(args.model)
|
|
||||||
|
tokenizer_config = {}
|
||||||
|
if args.eos_token is not None:
|
||||||
|
tokenizer_config["eos_token"] = args.eos_token
|
||||||
|
|
||||||
|
processor, model = load_model(args.model, tokenizer_config)
|
||||||
|
|
||||||
prompt = codecs.decode(args.prompt, "unicode_escape")
|
prompt = codecs.decode(args.prompt, "unicode_escape")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user