diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 0fa41ac0..5c605ed9 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -254,21 +254,29 @@ class TokenizerWrapper: huggingface tokenizer. """ - def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer): + def __init__( + self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer, eos_token_ids=None + ): self._tokenizer = tokenizer self._detokenizer = detokenizer_class(tokenizer) + self._eos_token_ids = eos_token_ids or [tokenizer.eos_token_id] def __getattr__(self, attr): if attr == "detokenizer": return self._detokenizer + elif attr == "eos_token_ids": + return self._eos_token_ids elif attr.startswith("_"): return self.__getattribute__(attr) else: return getattr(self._tokenizer, attr) def __setattr__(self, attr, value): - if attr == "detokenizer": - raise AttributeError("Cannot set the detokenizer.") + if attr in {"detokenizer", "eos_token_ids"}: + if attr == "detokenizer": + raise AttributeError("Cannot set the detokenizer.") + elif attr == "eos_token_ids": + self._eos_token_ids = value elif attr.startswith("_"): super().__setattr__(attr, value) else: @@ -315,7 +323,7 @@ def _is_bpe_decoder(decoder): return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel" -def load_tokenizer(model_path, tokenizer_config_extra={}): +def load_tokenizer(model_path, tokenizer_config_extra={}, model_config={}): """Load a huggingface tokenizer and try to infer the type of streaming detokenizer to use. @@ -336,7 +344,11 @@ def load_tokenizer(model_path, tokenizer_config_extra={}): elif _is_bpe_decoder(tokenizer_content["decoder"]): detokenizer_class = BPEStreamingDetokenizer + eos_token_id = model_config["eos_token_id"] + eos_token_ids = eos_token_id if isinstance(eos_token_id, list) else [eos_token_id] + return TokenizerWrapper( AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra), detokenizer_class, + eos_token_ids=eos_token_ids, ) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 86b786ce..17d47697 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -350,7 +350,7 @@ def stream_generate( prompt_time = time.perf_counter() - tic prompt_tps = prompt.size / prompt_time tic = time.perf_counter() - if token == tokenizer.eos_token_id: + if token in tokenizer.eos_token_ids: break detokenizer.add_token(token) @@ -470,9 +470,6 @@ def load_model( ValueError: If the model class or args class are not found or cannot be instantiated. """ - config = load_config(model_path) - config.update(model_config) - weight_files = glob.glob(str(model_path / "model*.safetensors")) if not weight_files: @@ -487,15 +484,15 @@ def load_model( for wf in weight_files: weights.update(mx.load(wf)) - model_class, model_args_class = get_model_classes(config=config) + model_class, model_args_class = get_model_classes(config=model_config) - model_args = model_args_class.from_dict(config) + model_args = model_args_class.from_dict(model_config) model = model_class(model_args) if hasattr(model, "sanitize"): weights = model.sanitize(weights) - if (quantization := config.get("quantization", None)) is not None: + if (quantization := model_config.get("quantization", None)) is not None: # Handle legacy models which may not have everything quantized def class_predicate(p, m): if not hasattr(m, "to_quantized"): @@ -547,11 +544,15 @@ def load( """ model_path = get_model_path(path_or_hf_repo) - model = load_model(model_path, lazy, model_config) + config = load_config(model_path) + config.update(model_config) + + model = load_model(model_path, lazy, config) if adapter_path is not None: model = load_adapters(model, adapter_path) model.eval() - tokenizer = load_tokenizer(model_path, tokenizer_config) + + tokenizer = load_tokenizer(model_path, tokenizer_config, model_config=config) return model, tokenizer @@ -559,9 +560,9 @@ def load( def fetch_from_hub( model_path: Path, lazy: bool = False ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]: - model = load_model(model_path, lazy) config = load_config(model_path) - tokenizer = load_tokenizer(model_path) + model = load_model(model_path, lazy, model_config=config) + tokenizer = load_tokenizer(model_path, model_config=config) return model, config, tokenizer