Remove model_config & add eos_token_id

madroid 2024-12-08 12:08:09 +08:00
parent 1cfb005647
commit facaf01b8d
2 changed files with 13 additions and 7 deletions

View File

@@ -259,7 +259,11 @@ class TokenizerWrapper:
     ):
         self._tokenizer = tokenizer
         self._detokenizer = detokenizer_class(tokenizer)
-        self._eos_token_ids = set(eos_token_ids) if eos_token_ids is not None else {tokenizer.eos_token_id}
+        self._eos_token_ids = (
+            set(eos_token_ids)
+            if eos_token_ids is not None
+            else {tokenizer.eos_token_id}
+        )
 
     def __getattr__(self, attr):
         if attr == "detokenizer":
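
For context, a minimal standalone sketch of what the new assignment computes, using a made-up stand-in tokenizer (only the attribute and parameter names are taken from the hunk above; everything else is illustrative):

# Standalone sketch, not library code: reproduces the _eos_token_ids expression above.
class FakeTokenizer:
    eos_token_id = 2  # invented id standing in for a real tokenizer's EOS token

def eos_ids(tokenizer, eos_token_ids=None):
    # Explicit ids become a set; otherwise fall back to the tokenizer's single EOS id.
    return (
        set(eos_token_ids)
        if eos_token_ids is not None
        else {tokenizer.eos_token_id}
    )

print(eos_ids(FakeTokenizer()))              # {2}
print(eos_ids(FakeTokenizer(), [2, 32000]))  # {2, 32000}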
@@ -323,7 +327,7 @@ def _is_bpe_decoder(decoder):
     return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
 
 
-def load_tokenizer(model_path, tokenizer_config_extra={}, model_config={}):
+def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_id=None):
     """Load a huggingface tokenizer and try to infer the type of streaming
     detokenizer to use.
@@ -344,8 +348,9 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, model_config={}):
             elif _is_bpe_decoder(tokenizer_content["decoder"]):
                 detokenizer_class = BPEStreamingDetokenizer
 
-    eos_token_id = model_config["eos_token_id"]
-    eos_token_ids = set(eos_token_id) if isinstance(eos_token_id, list) else {eos_token_id}
+    eos_token_ids = (
+        set(eos_token_id) if isinstance(eos_token_id, list) else {eos_token_id}
+    )
     return TokenizerWrapper(
         AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
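
A hypothetical call under the new load_tokenizer signature. The function and parameter names come from the diff; the import path is not shown there, so assume load_tokenizer is already in scope, and treat the model path and token ids as placeholders:

from pathlib import Path

model_path = Path("path/to/local/model")  # placeholder; a real local model directory

# A single stop id or a list of them is accepted; the hunk above turns either into a set.
tok = load_tokenizer(model_path, eos_token_id=2)
tok = load_tokenizer(model_path, eos_token_id=[2, 32000])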

View File

@@ -551,8 +551,9 @@ def load(
     if adapter_path is not None:
         model = load_adapters(model, adapter_path)
         model.eval()
-    tokenizer = load_tokenizer(model_path, tokenizer_config, model_config=config)
+    tokenizer = load_tokenizer(
+        model_path, tokenizer_config, eos_token_id=config["eos_token_id"]
+    )
 
     return model, tokenizer
@@ -562,7 +563,7 @@ def fetch_from_hub(
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
     config = load_config(model_path)
     model = load_model(model_path, lazy, model_config=config)
-    tokenizer = load_tokenizer(model_path, model_config=config)
+    tokenizer = load_tokenizer(model_path, eos_token_id=config["eos_token_id"])
     return model, config, tokenizer
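
Both call sites now read eos_token_id straight from the model config. A quick illustration of why the int-or-list normalization in load_tokenizer matters (the config values below are invented; real ones come from a model's config.json):

# Invented examples of what config["eos_token_id"] can hold.
single = {"eos_token_id": 2}                  # one stop token
several = {"eos_token_id": [128001, 128009]}  # multiple stop tokens

for config in (single, several):
    eos = config["eos_token_id"]
    # Same normalization the new load_tokenizer applies before wrapping the tokenizer.
    print(set(eos) if isinstance(eos, list) else {eos})
# -> {2}
# -> {128001, 128009}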