diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
index 7f0e77b9..ca9bb792 100644
--- a/llms/mlx_lm/tokenizer_utils.py
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -259,7 +259,11 @@ class TokenizerWrapper:
     ):
         self._tokenizer = tokenizer
         self._detokenizer = detokenizer_class(tokenizer)
-        self._eos_token_ids = set(eos_token_ids) if eos_token_ids is not None else {tokenizer.eos_token_id}
+        self._eos_token_ids = (
+            set(eos_token_ids)
+            if eos_token_ids is not None
+            else {tokenizer.eos_token_id}
+        )
 
     def __getattr__(self, attr):
         if attr == "detokenizer":
@@ -323,7 +327,7 @@ def _is_bpe_decoder(decoder):
     return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
 
 
-def load_tokenizer(model_path, tokenizer_config_extra={}, model_config={}):
+def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_id=None):
     """Load a huggingface tokenizer and try to infer the type of streaming
     detokenizer to use.
 
@@ -344,8 +348,9 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, model_config={}):
         elif _is_bpe_decoder(tokenizer_content["decoder"]):
             detokenizer_class = BPEStreamingDetokenizer
 
-    eos_token_id = model_config["eos_token_id"]
-    eos_token_ids = set(eos_token_id) if isinstance(eos_token_id, list) else {eos_token_id}
+    eos_token_ids = (
+        set(eos_token_id) if isinstance(eos_token_id, list) else {eos_token_id}
+    )
 
     return TokenizerWrapper(
         AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 17d47697..32b2fd20 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -551,8 +551,9 @@ def load(
     if adapter_path is not None:
         model = load_adapters(model, adapter_path)
         model.eval()
-
-    tokenizer = load_tokenizer(model_path, tokenizer_config, model_config=config)
+    tokenizer = load_tokenizer(
+        model_path, tokenizer_config, eos_token_id=config["eos_token_id"]
+    )
 
     return model, tokenizer
 
@@ -562,7 +563,7 @@ def fetch_from_hub(
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
     config = load_config(model_path)
     model = load_model(model_path, lazy, model_config=config)
-    tokenizer = load_tokenizer(model_path, model_config=config)
+    tokenizer = load_tokenizer(model_path, eos_token_id=config["eos_token_id"])
 
     return model, config, tokenizer
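
For reference, a minimal usage sketch of the updated `load_tokenizer` signature. This is a sketch rather than part of the change: the model directory below is hypothetical, and `load_config` is the existing helper in `mlx_lm.utils`.

```python
from pathlib import Path

from mlx_lm.tokenizer_utils import load_tokenizer
from mlx_lm.utils import load_config

# Hypothetical local directory holding a converted model: config.json plus
# the usual tokenizer files (tokenizer.json, tokenizer_config.json, ...).
model_path = Path("/path/to/mlx-model")

# config.json may store eos_token_id as a single int or as a list of ints;
# load_tokenizer normalizes either form into the set that TokenizerWrapper
# keeps as _eos_token_ids.
config = load_config(model_path)
tokenizer = load_tokenizer(model_path, eos_token_id=config["eos_token_id"])
```

Passing the raw config value through an explicit keyword keeps `load_tokenizer` decoupled from the shape of `config.json`, which is the point of dropping the `model_config` parameter.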