Change the eos-token argument for mlx_lm.generate (#1176)

This commit is contained in:
Angelos Katharopoulos 2025-01-05 22:26:05 -08:00 committed by GitHub
parent c4833a2f55
commit 25ec2d8c44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 4 deletions

View File

@ -43,10 +43,11 @@ def setup_arg_parser():
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--eos-token",
"--extra-eos-token",
type=str,
default=None,
help="End of sequence token for tokenizer",
nargs="+",
help="Add tokens in the list of eos tokens that stop generation.",
)
parser.add_argument(
"--system-prompt",
@ -161,8 +162,6 @@ def main():
{} if not using_cache else json.loads(metadata["tokenizer_config"])
)
tokenizer_config["trust_remote_code"] = True
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
model_path = args.model
if using_cache:
@ -181,6 +180,8 @@ def main():
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
)
for eos_token in args.extra_eos_token:
tokenizer.add_eos_token(eos_token)
if args.use_default_chat_template:
if tokenizer.chat_template is None:

View File

@ -266,6 +266,18 @@ class TokenizerWrapper:
else {tokenizer.eos_token_id}
)
def add_eos_token(self, token: str):
token_id = None
try:
token_id = int(token)
except ValueError:
token_id = self._tokenizer.convert_tokens_to_ids(token)
if token_id is None:
raise ValueError(f"'{token}' is not a token for this tokenizer")
self._eos_token_ids.add(token_id)
def __getattr__(self, attr):
if attr == "detokenizer":
return self._detokenizer