From 9b53599e6c7584606f5072c53cb9cb66e17d8c85 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 9 Dec 2024 08:51:22 -0800 Subject: [PATCH] nits --- llms/mlx_lm/tokenizer_utils.py | 8 +++----- llms/mlx_lm/utils.py | 30 ++++++++++++++--------------- llms/tests/test_utils_load_model.py | 4 ++-- 3 files changed, 20 insertions(+), 22 deletions(-) diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index ca9bb792..10a257f6 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -327,7 +327,7 @@ def _is_bpe_decoder(decoder): return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel" -def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_id=None): +def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None): """Load a huggingface tokenizer and try to infer the type of streaming detokenizer to use. @@ -348,10 +348,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_id=None): elif _is_bpe_decoder(tokenizer_content["decoder"]): detokenizer_class = BPEStreamingDetokenizer - eos_token_ids = ( - set(eos_token_id) if isinstance(eos_token_id, list) else {eos_token_id} - ) - + if isinstance(eos_token_ids, int): + eos_token_ids = [eos_token_ids] return TokenizerWrapper( AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra), detokenizer_class, diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 32b2fd20..85981b7d 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -456,11 +456,11 @@ def load_model( lazy (bool): If False eval the model parameters to make sure they are loaded in memory before returning, otherwise they will be loaded when needed. Default: ``False`` - model_config (dict, optional): Configuration parameters for the model. - Defaults to an empty dictionary. + model_config (dict, optional): Optional configuration parameters for the + model. Defaults to an empty dictionary. get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional): A function that returns the model class and model args class given a config. - Defaults to the _get_classes function. + Defaults to the ``_get_classes`` function. Returns: nn.Module: The loaded and initialized model. @@ -469,6 +469,8 @@ def load_model( FileNotFoundError: If the weight files (.safetensors) are not found. ValueError: If the model class or args class are not found or cannot be instantiated. """ + config = load_config(model_path) + config.update(model_config) weight_files = glob.glob(str(model_path / "model*.safetensors")) @@ -484,15 +486,15 @@ def load_model( for wf in weight_files: weights.update(mx.load(wf)) - model_class, model_args_class = get_model_classes(config=model_config) + model_class, model_args_class = get_model_classes(config=config) - model_args = model_args_class.from_dict(model_config) + model_args = model_args_class.from_dict(config) model = model_class(model_args) if hasattr(model, "sanitize"): weights = model.sanitize(weights) - if (quantization := model_config.get("quantization", None)) is not None: + if (quantization := config.get("quantization", None)) is not None: # Handle legacy models which may not have everything quantized def class_predicate(p, m): if not hasattr(m, "to_quantized"): @@ -511,7 +513,7 @@ def load_model( mx.eval(model.parameters()) model.eval() - return model + return model, config def load( @@ -544,15 +546,12 @@ def load( """ model_path = get_model_path(path_or_hf_repo) - config = load_config(model_path) - config.update(model_config) - - model = load_model(model_path, lazy, config) + model, config = load_model(model_path, lazy) if adapter_path is not None: model = load_adapters(model, adapter_path) model.eval() tokenizer = load_tokenizer( - model_path, tokenizer_config, eos_token_id=config["eos_token_id"] + model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None) ) return model, tokenizer @@ -561,9 +560,10 @@ def load( def fetch_from_hub( model_path: Path, lazy: bool = False ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]: - config = load_config(model_path) - model = load_model(model_path, lazy, model_config=config) - tokenizer = load_tokenizer(model_path, eos_token_id=config["eos_token_id"]) + model, config = load_model(model_path, lazy) + tokenizer = load_tokenizer( + model_path, eos_token_ids=config.get("eos_token_id", None) + ) return model, config, tokenizer diff --git a/llms/tests/test_utils_load_model.py b/llms/tests/test_utils_load_model.py index 73ee1352..5821f9e9 100644 --- a/llms/tests/test_utils_load_model.py +++ b/llms/tests/test_utils_load_model.py @@ -32,7 +32,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase): return CustomQwenModel, CustomQwenConfig model_path = get_model_path(HF_MODEL_PATH) - model = load_model(model_path, get_model_classes=custom_get_classes) + model, _ = load_model(model_path, get_model_classes=custom_get_classes) self.assertIsInstance(model, CustomQwenModel) self.assertTrue(hasattr(model, "custom_attribute")) @@ -41,7 +41,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase): def test_load_model_with_default_get_classes(self): model_path = get_model_path(HF_MODEL_PATH) - model = load_model(model_path) + model, _ = load_model(model_path) self.assertIsInstance(model, Qwen2Model)