diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 1ea66384..3301edae 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -43,10 +43,11 @@ def setup_arg_parser():
         help="Optional path for the trained adapter weights and config.",
     )
     parser.add_argument(
-        "--eos-token",
+        "--extra-eos-token",
         type=str,
-        default=None,
-        help="End of sequence token for tokenizer",
+        default=(),
+        nargs="+",
+        help="Add tokens to the list of EOS tokens that stop generation.",
     )
     parser.add_argument(
         "--system-prompt",
@@ -161,8 +162,6 @@ def main():
         {} if not using_cache else json.loads(metadata["tokenizer_config"])
     )
     tokenizer_config["trust_remote_code"] = True
-    if args.eos_token is not None:
-        tokenizer_config["eos_token"] = args.eos_token
 
     model_path = args.model
     if using_cache:
@@ -181,6 +180,8 @@ def main():
         adapter_path=args.adapter_path,
         tokenizer_config=tokenizer_config,
     )
+    for eos_token in args.extra_eos_token:
+        tokenizer.add_eos_token(eos_token)
 
     if args.use_default_chat_template:
         if tokenizer.chat_template is None:
diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
index ca3d6c06..1b5bdd77 100644
--- a/llms/mlx_lm/tokenizer_utils.py
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -266,6 +266,20 @@ class TokenizerWrapper:
            else {tokenizer.eos_token_id}
         )
 
+    def add_eos_token(self, token: str):
+        # Accept either an integer token id or a literal token string.
+        token_id = None
+        try:
+            token_id = int(token)
+        except ValueError:
+            token_id = self._tokenizer.convert_tokens_to_ids(token)
+
+        # convert_tokens_to_ids may return None for unknown token strings.
+        if token_id is None:
+            raise ValueError(f"'{token}' is not a token for this tokenizer")
+
+        self._eos_token_ids.add(token_id)
+
     def __getattr__(self, attr):
         if attr == "detokenizer":
             return self._detokenizer
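
For reference, a minimal usage sketch of the new method from Python, assuming the existing mlx_lm entry points load() and generate(); the model path and the specific stop token below are purely illustrative:

    from mlx_lm import load, generate

    # load() returns the model and a TokenizerWrapper.
    model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")  # illustrative model

    # add_eos_token resolves either form to a token id and adds it to the
    # set of stop tokens consulted during generation.
    tokenizer.add_eos_token("<|eot_id|>")  # by token string
    tokenizer.add_eos_token("128009")      # the same token, by integer id

    text = generate(model, tokenizer, prompt="Hello", max_tokens=64)

The same effect is available from the command line via the new flag, e.g. --extra-eos-token "<|eot_id|>", which accepts one or more tokens thanks to nargs="+".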