mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 11:45:16 +08:00
Remove model_config & add eos_token_id
This commit is contained in:
parent
1cfb005647
commit
facaf01b8d
@ -259,7 +259,11 @@ class TokenizerWrapper:
|
|||||||
):
|
):
|
||||||
self._tokenizer = tokenizer
|
self._tokenizer = tokenizer
|
||||||
self._detokenizer = detokenizer_class(tokenizer)
|
self._detokenizer = detokenizer_class(tokenizer)
|
||||||
self._eos_token_ids = set(eos_token_ids) if eos_token_ids is not None else {tokenizer.eos_token_id}
|
self._eos_token_ids = (
|
||||||
|
set(eos_token_ids)
|
||||||
|
if eos_token_ids is not None
|
||||||
|
else {tokenizer.eos_token_id}
|
||||||
|
)
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
def __getattr__(self, attr):
|
||||||
if attr == "detokenizer":
|
if attr == "detokenizer":
|
||||||
@ -323,7 +327,7 @@ def _is_bpe_decoder(decoder):
|
|||||||
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
|
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(model_path, tokenizer_config_extra={}, model_config={}):
|
def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_id=None):
|
||||||
"""Load a huggingface tokenizer and try to infer the type of streaming
|
"""Load a huggingface tokenizer and try to infer the type of streaming
|
||||||
detokenizer to use.
|
detokenizer to use.
|
||||||
|
|
||||||
@ -344,8 +348,9 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, model_config={}):
|
|||||||
elif _is_bpe_decoder(tokenizer_content["decoder"]):
|
elif _is_bpe_decoder(tokenizer_content["decoder"]):
|
||||||
detokenizer_class = BPEStreamingDetokenizer
|
detokenizer_class = BPEStreamingDetokenizer
|
||||||
|
|
||||||
eos_token_id = model_config["eos_token_id"]
|
eos_token_ids = (
|
||||||
eos_token_ids = set(eos_token_id) if isinstance(eos_token_id, list) else {eos_token_id}
|
set(eos_token_id) if isinstance(eos_token_id, list) else {eos_token_id}
|
||||||
|
)
|
||||||
|
|
||||||
return TokenizerWrapper(
|
return TokenizerWrapper(
|
||||||
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
|
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
|
||||||
|
@ -551,8 +551,9 @@ def load(
|
|||||||
if adapter_path is not None:
|
if adapter_path is not None:
|
||||||
model = load_adapters(model, adapter_path)
|
model = load_adapters(model, adapter_path)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
tokenizer = load_tokenizer(
|
||||||
tokenizer = load_tokenizer(model_path, tokenizer_config, model_config=config)
|
model_path, tokenizer_config, eos_token_id=config["eos_token_id"]
|
||||||
|
)
|
||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
@ -562,7 +563,7 @@ def fetch_from_hub(
|
|||||||
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
||||||
config = load_config(model_path)
|
config = load_config(model_path)
|
||||||
model = load_model(model_path, lazy, model_config=config)
|
model = load_model(model_path, lazy, model_config=config)
|
||||||
tokenizer = load_tokenizer(model_path, model_config=config)
|
tokenizer = load_tokenizer(model_path, eos_token_id=config["eos_token_id"])
|
||||||
return model, config, tokenizer
|
return model, config, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user