Use max tokens from options in mlx_lm evaluate (#1302)

This commit is contained in:
Awni Hannun 2025-02-26 15:46:16 -08:00 committed by GitHub
parent 56e60ad5a6
commit 0f240a4c7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -289,17 +289,15 @@ class MLXLM(LM):
contexts, options = zip(*[req.args for req in requests]) contexts, options = zip(*[req.args for req in requests])
# contrary to the doc the second element of the tuple contains # contrary to the doc the second element of the tuple contains
# {'do_sample': False, 'until': ['\n\n'], 'temperature': 0} # {'do_sample': False, 'until': ['\n\n'], 'temperature': 0}
keys = list(options[0].keys())
assert "until" in keys
untils = [x["until"] for x in options]
completions = [] completions = []
for context, until in tqdm(zip(contexts, untils), total=len(contexts)): for context, opt in tqdm(zip(contexts, options), total=len(contexts)):
until = opt["until"]
context = self.tokenizer.encode( context = self.tokenizer.encode(
context, add_special_tokens=not self.use_chat_template context, add_special_tokens=not self.use_chat_template
) )
max_tokens = min( max_tokens = min(
self._max_tokens, opt.get("max_gen_tokens", self._max_tokens),
self.tokenizer.model_max_length - len(context), self.tokenizer.model_max_length - len(context),
) )
text = "" text = ""
@@ -334,9 +332,9 @@ def main():
) )
parser.add_argument( parser.add_argument(
"--limit", "--limit",
default=1.0, default=100,
help="Limit the number of examples per task.", help="Limit the number of examples per task.",
type=float, type=int,
) )
parser.add_argument("--seed", type=int, default=123, help="Random seed.") parser.add_argument("--seed", type=int, default=123, help="Random seed.")
parser.add_argument( parser.add_argument(