From 7a3ab1620a6b853f76fcdacdc35df308a255036e Mon Sep 17 00:00:00 2001
From: Anchen
Date: Fri, 26 Jul 2024 04:01:17 +1000
Subject: [PATCH] support loading a model via a custom get_model_classes (#899)
* feature(mlx_lm): support loading a model via a custom get_model_classes function
* rename the parameter
---
llms/mlx_lm/utils.py | 10 ++++--
llms/tests/test_utils_load_model.py | 50 +++++++++++++++++++++++++++++
2 files changed, 57 insertions(+), 3 deletions(-)
create mode 100644 llms/tests/test_utils_load_model.py
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 229ee238..cffa2a89 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -9,7 +9,7 @@ import shutil
import time
from pathlib import Path
from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union
import mlx.core as mx
import mlx.nn as nn
@@ -355,6 +355,7 @@ def load_model(
model_path: Path,
lazy: bool = False,
model_config: dict = {},
+ get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
) -> nn.Module:
"""
Load and initialize the model from a given path.
@@ -364,8 +365,11 @@ def load_model(
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
- model_config(dict, optional): Configuration parameters for the model.
+ model_config (dict, optional): Configuration parameters for the model.
Defaults to an empty dictionary.
+ get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
+ A function that returns the model class and model args class given a config.
+ Defaults to the _get_classes function.
Returns:
nn.Module: The loaded and initialized model.
@@ -392,7 +396,7 @@ def load_model(
for wf in weight_files:
weights.update(mx.load(wf))
- model_class, model_args_class = _get_classes(config=config)
+ model_class, model_args_class = get_model_classes(config=config)
model_args = model_args_class.from_dict(config)
model = model_class(model_args)
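
Reviewer note (not part of the patch): the sketch below shows one way a caller could use the new get_model_classes hook, falling back to the stock _get_classes resolver for every architecture it does not override. The model_type value "my-arch", the classes MyModel and MyModelArgs, and the repo id "some-org/some-model" are hypothetical placeholders, not names from this repository.

    import mlx.nn as nn

    from mlx_lm.utils import _get_classes, get_model_path, load_model


    class MyModelArgs:
        @classmethod
        def from_dict(cls, config):
            # Minimal stand-in: copy every config field onto the instance.
            instance = cls()
            for k, v in config.items():
                setattr(instance, k, v)
            return instance


    class MyModel(nn.Module):
        def __init__(self, args):
            super().__init__()
            self.args = args

        def load_weights(self, weights, strict=True):
            # A real model would map the (name, array) pairs onto its layers;
            # this stub just keeps them around.
            self.raw_weights = weights


    def my_get_classes(config: dict):
        # Route one hypothetical architecture to the custom classes and let
        # every other model_type fall through to the default registry.
        if config.get("model_type") == "my-arch":
            return MyModel, MyModelArgs
        return _get_classes(config=config)


    model = load_model(
        get_model_path("some-org/some-model"),  # hypothetical repo id
        get_model_classes=my_get_classes,
    )
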
diff --git a/llms/tests/test_utils_load_model.py b/llms/tests/test_utils_load_model.py
new file mode 100644
index 00000000..73ee1352
--- /dev/null
+++ b/llms/tests/test_utils_load_model.py
@@ -0,0 +1,50 @@
+import unittest
+from pathlib import Path
+
+import mlx.nn as nn
+from mlx_lm.models.qwen2 import Model as Qwen2Model
+from mlx_lm.utils import get_model_path, load_model
+
+HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+
+
+class TestLoadModelCustomGetClasses(unittest.TestCase):
+
+ def test_load_model_with_custom_get_classes(self):
+ class CustomQwenModel(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+ self.config = args
+ self.custom_attribute = "This is a custom model"
+
+ def load_weights(self, weights):
+ self.qwenWeights = weights
+
+ class CustomQwenConfig:
+ @classmethod
+ def from_dict(cls, config):
+ instance = cls()
+ for k, v in config.items():
+ setattr(instance, k, v)
+ return instance
+
+ def custom_get_classes(config):
+ return CustomQwenModel, CustomQwenConfig
+
+ model_path = get_model_path(HF_MODEL_PATH)
+ model = load_model(model_path, get_model_classes=custom_get_classes)
+
+ self.assertIsInstance(model, CustomQwenModel)
+ self.assertTrue(hasattr(model, "custom_attribute"))
+ self.assertEqual(model.custom_attribute, "This is a custom model")
+ self.assertTrue(hasattr(model, "qwenWeights"))
+
+ def test_load_model_with_default_get_classes(self):
+ model_path = get_model_path(HF_MODEL_PATH)
+ model = load_model(model_path)
+
+ self.assertIsInstance(model, Qwen2Model)
+
+
+if __name__ == "__main__":
+ unittest.main()
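
Reviewer note (not part of the patch): both tests resolve HF_MODEL_PATH through get_model_path, so the first run will download the mlx-community/Qwen1.5-0.5B-Chat-4bit checkpoint from the Hugging Face Hub. Assuming the usual path-based unittest invocation, one way to run just this file from the llms/ directory is:

    python -m unittest tests/test_utils_load_model.py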