diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
index 0fa41ac0..10a257f6 100644
--- a/llms/mlx_lm/tokenizer_utils.py
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -254,21 +254,33 @@ class TokenizerWrapper:
     huggingface tokenizer.
     """

-    def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer):
+    def __init__(
+        self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer, eos_token_ids=None
+    ):
         self._tokenizer = tokenizer
         self._detokenizer = detokenizer_class(tokenizer)
+        self._eos_token_ids = (
+            set(eos_token_ids)
+            if eos_token_ids is not None
+            else {tokenizer.eos_token_id}
+        )

     def __getattr__(self, attr):
         if attr == "detokenizer":
             return self._detokenizer
+        elif attr == "eos_token_ids":
+            return self._eos_token_ids
         elif attr.startswith("_"):
             return self.__getattribute__(attr)
         else:
             return getattr(self._tokenizer, attr)

     def __setattr__(self, attr, value):
-        if attr == "detokenizer":
-            raise AttributeError("Cannot set the detokenizer.")
+        if attr in {"detokenizer", "eos_token_ids"}:
+            if attr == "detokenizer":
+                raise AttributeError("Cannot set the detokenizer.")
+            elif attr == "eos_token_ids":
+                self._eos_token_ids = set(value) if value is not None else set()
         elif attr.startswith("_"):
             super().__setattr__(attr, value)
         else:
@@ -315,7 +327,7 @@ def _is_bpe_decoder(decoder):
     return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"


-def load_tokenizer(model_path, tokenizer_config_extra={}):
+def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
     """Load a huggingface tokenizer and try to infer the type of streaming
     detokenizer to use.

@@ -336,7 +348,10 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):
         elif _is_bpe_decoder(tokenizer_content["decoder"]):
             detokenizer_class = BPEStreamingDetokenizer

+    if isinstance(eos_token_ids, int):
+        eos_token_ids = [eos_token_ids]
     return TokenizerWrapper(
         AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
         detokenizer_class,
+        eos_token_ids=eos_token_ids,
     )
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 66a106a1..d81bb66a 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -361,7 +361,7 @@ def stream_generate(
             prompt_time = time.perf_counter() - tic
             prompt_tps = prompt.size / prompt_time
             tic = time.perf_counter()
-        if token == tokenizer.eos_token_id:
+        if token in tokenizer.eos_token_ids:
             break

         detokenizer.add_token(token)
@@ -467,19 +467,18 @@ def load_model(
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
-        model_config (dict, optional): Configuration parameters for the model.
-            Defaults to an empty dictionary.
+        model_config (dict, optional): Optional configuration parameters for the
+            model. Defaults to an empty dictionary.
         get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
             A function that returns the model class and model args class given a config.
-            Defaults to the _get_classes function.
+            Defaults to the ``_get_classes`` function.

     Returns:
         nn.Module: The loaded and initialized model.

     Raises:
         FileNotFoundError: If the weight files (.safetensors) are not found.
         ValueError: If the model class or args class are not found or cannot be instantiated.
     """
-
     config = load_config(model_path)
     config.update(model_config)
@@ -530,7 +529,7 @@ def load_model(
     mx.eval(model.parameters())
     model.eval()

-    return model
+    return model, config


 def load(
@@ -563,11 +562,13 @@ def load(
     """
     model_path = get_model_path(path_or_hf_repo)

-    model = load_model(model_path, lazy, model_config)
+    model, config = load_model(model_path, lazy)
     if adapter_path is not None:
         model = load_adapters(model, adapter_path)
         model.eval()

-    tokenizer = load_tokenizer(model_path, tokenizer_config)
+    tokenizer = load_tokenizer(
+        model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None)
+    )

     return model, tokenizer
@@ -575,9 +576,10 @@ def load(
 def fetch_from_hub(
     model_path: Path, lazy: bool = False
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
-    model = load_model(model_path, lazy)
-    config = load_config(model_path)
-    tokenizer = load_tokenizer(model_path)
+    model, config = load_model(model_path, lazy)
+    tokenizer = load_tokenizer(
+        model_path, eos_token_ids=config.get("eos_token_id", None)
+    )

     return model, config, tokenizer

diff --git a/llms/tests/test_utils_load_model.py b/llms/tests/test_utils_load_model.py
index 73ee1352..5821f9e9 100644
--- a/llms/tests/test_utils_load_model.py
+++ b/llms/tests/test_utils_load_model.py
@@ -32,7 +32,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
             return CustomQwenModel, CustomQwenConfig

         model_path = get_model_path(HF_MODEL_PATH)
-        model = load_model(model_path, get_model_classes=custom_get_classes)
+        model, _ = load_model(model_path, get_model_classes=custom_get_classes)

         self.assertIsInstance(model, CustomQwenModel)
         self.assertTrue(hasattr(model, "custom_attribute"))
@@ -41,7 +41,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):

     def test_load_model_with_default_get_classes(self):
         model_path = get_model_path(HF_MODEL_PATH)

-        model = load_model(model_path)
+        model, _ = load_model(model_path)

         self.assertIsInstance(model, Qwen2Model)