mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Change the eos-token argument for mlx_lm.generate (#1176)
This commit is contained in:
parent
c4833a2f55
commit
25ec2d8c44
@ -43,10 +43,11 @@ def setup_arg_parser():
|
|||||||
help="Optional path for the trained adapter weights and config.",
|
help="Optional path for the trained adapter weights and config.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--eos-token",
|
"--extra-eos-token",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="End of sequence token for tokenizer",
|
nargs="+",
|
||||||
|
help="Add tokens in the list of eos tokens that stop generation.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--system-prompt",
|
"--system-prompt",
|
||||||
@ -161,8 +162,6 @@ def main():
|
|||||||
{} if not using_cache else json.loads(metadata["tokenizer_config"])
|
{} if not using_cache else json.loads(metadata["tokenizer_config"])
|
||||||
)
|
)
|
||||||
tokenizer_config["trust_remote_code"] = True
|
tokenizer_config["trust_remote_code"] = True
|
||||||
if args.eos_token is not None:
|
|
||||||
tokenizer_config["eos_token"] = args.eos_token
|
|
||||||
|
|
||||||
model_path = args.model
|
model_path = args.model
|
||||||
if using_cache:
|
if using_cache:
|
||||||
@ -181,6 +180,8 @@ def main():
|
|||||||
adapter_path=args.adapter_path,
|
adapter_path=args.adapter_path,
|
||||||
tokenizer_config=tokenizer_config,
|
tokenizer_config=tokenizer_config,
|
||||||
)
|
)
|
||||||
|
for eos_token in args.extra_eos_token:
|
||||||
|
tokenizer.add_eos_token(eos_token)
|
||||||
|
|
||||||
if args.use_default_chat_template:
|
if args.use_default_chat_template:
|
||||||
if tokenizer.chat_template is None:
|
if tokenizer.chat_template is None:
|
||||||
|
@ -266,6 +266,18 @@ class TokenizerWrapper:
|
|||||||
else {tokenizer.eos_token_id}
|
else {tokenizer.eos_token_id}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def add_eos_token(self, token: str):
|
||||||
|
token_id = None
|
||||||
|
try:
|
||||||
|
token_id = int(token)
|
||||||
|
except ValueError:
|
||||||
|
token_id = self._tokenizer.convert_tokens_to_ids(token)
|
||||||
|
|
||||||
|
if token_id is None:
|
||||||
|
raise ValueError(f"'{token}' is not a token for this tokenizer")
|
||||||
|
|
||||||
|
self._eos_token_ids.add(token_id)
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
def __getattr__(self, attr):
|
||||||
if attr == "detokenizer":
|
if attr == "detokenizer":
|
||||||
return self._detokenizer
|
return self._detokenizer
|
||||||
|
Loading…
Reference in New Issue
Block a user