Change the eos-token argument for mlx_lm.generate (#1176)

This commit is contained in:
Angelos Katharopoulos 2025-01-05 22:26:05 -08:00 committed by GitHub
parent c4833a2f55
commit 25ec2d8c44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 4 deletions

View File

@ -43,10 +43,11 @@ def setup_arg_parser():
help="Optional path for the trained adapter weights and config.", help="Optional path for the trained adapter weights and config.",
) )
parser.add_argument( parser.add_argument(
"--eos-token", "--extra-eos-token",
type=str, type=str,
default=None, default=None,
help="End of sequence token for tokenizer", nargs="+",
help="Add tokens in the list of eos tokens that stop generation.",
) )
parser.add_argument( parser.add_argument(
"--system-prompt", "--system-prompt",
@ -161,8 +162,6 @@ def main():
{} if not using_cache else json.loads(metadata["tokenizer_config"]) {} if not using_cache else json.loads(metadata["tokenizer_config"])
) )
tokenizer_config["trust_remote_code"] = True tokenizer_config["trust_remote_code"] = True
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
model_path = args.model model_path = args.model
if using_cache: if using_cache:
@ -181,6 +180,8 @@ def main():
adapter_path=args.adapter_path, adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config, tokenizer_config=tokenizer_config,
) )
for eos_token in args.extra_eos_token:
tokenizer.add_eos_token(eos_token)
if args.use_default_chat_template: if args.use_default_chat_template:
if tokenizer.chat_template is None: if tokenizer.chat_template is None:

View File

@ -266,6 +266,18 @@ class TokenizerWrapper:
else {tokenizer.eos_token_id} else {tokenizer.eos_token_id}
) )
def add_eos_token(self, token: str):
token_id = None
try:
token_id = int(token)
except ValueError:
token_id = self._tokenizer.convert_tokens_to_ids(token)
if token_id is None:
raise ValueError(f"'{token}' is not a token for this tokenizer")
self._eos_token_ids.add(token_id)
def __getattr__(self, attr): def __getattr__(self, attr):
if attr == "detokenizer": if attr == "detokenizer":
return self._detokenizer return self._detokenizer