diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 229ee238..cffa2a89 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -9,7 +9,7 @@ import shutil import time from pathlib import Path from textwrap import dedent -from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union import mlx.core as mx import mlx.nn as nn @@ -355,6 +355,7 @@ def load_model( model_path: Path, lazy: bool = False, model_config: dict = {}, + get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes, ) -> nn.Module: """ Load and initialize the model from a given path. @@ -364,8 +365,11 @@ def load_model( lazy (bool): If False eval the model parameters to make sure they are loaded in memory before returning, otherwise they will be loaded when needed. Default: ``False`` - model_config(dict, optional): Configuration parameters for the model. + model_config (dict, optional): Configuration parameters for the model. Defaults to an empty dictionary. + get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional): + A function that returns the model class and model args class given a config. + Defaults to the _get_classes function. Returns: nn.Module: The loaded and initialized model. @@ -392,7 +396,7 @@ def load_model( for wf in weight_files: weights.update(mx.load(wf)) - model_class, model_args_class = _get_classes(config=config) + model_class, model_args_class = get_model_classes(config=config) model_args = model_args_class.from_dict(config) model = model_class(model_args) diff --git a/llms/tests/test_utils_load_model.py b/llms/tests/test_utils_load_model.py new file mode 100644 index 00000000..73ee1352 --- /dev/null +++ b/llms/tests/test_utils_load_model.py @@ -0,0 +1,50 @@ +import unittest +from pathlib import Path + +import mlx.nn as nn +from mlx_lm.models.qwen2 import Model as Qwen2Model +from mlx_lm.utils import get_model_path, load_model + +HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" + + +class TestLoadModelCustomGetClasses(unittest.TestCase): + + def test_load_model_with_custom_get_classes(self): + class CustomQwenModel(nn.Module): + def __init__(self, args): + super().__init__() + self.config = args + self.custom_attribute = "This is a custom model" + + def load_weights(self, weights): + self.qwenWeights = weights + + class CustomQwenConfig: + @classmethod + def from_dict(cls, config): + instance = cls() + for k, v in config.items(): + setattr(instance, k, v) + return instance + + def custom_get_classes(config): + return CustomQwenModel, CustomQwenConfig + + model_path = get_model_path(HF_MODEL_PATH) + model = load_model(model_path, get_model_classes=custom_get_classes) + + self.assertIsInstance(model, CustomQwenModel) + self.assertTrue(hasattr(model, "custom_attribute")) + self.assertEqual(model.custom_attribute, "This is a custom model") + self.assertTrue(hasattr(model, "qwenWeights")) + + def test_load_model_with_default_get_classes(self): + model_path = get_model_path(HF_MODEL_PATH) + model = load_model(model_path) + + self.assertIsInstance(model, Qwen2Model) + + +if __name__ == "__main__": + unittest.main()