Support loading a model with a custom get_model_classes (#899)

* feature(mlx_lm): support loading a model with a custom get-classes function

* rename the param
This commit is contained in:
Anchen 2024-07-26 04:01:17 +10:00 committed by GitHub
parent cd8efc7fbc
commit 7a3ab1620a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 57 additions and 3 deletions

View File

@ -9,7 +9,7 @@ import shutil
import time import time
from pathlib import Path from pathlib import Path
from textwrap import dedent from textwrap import dedent
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@ -355,6 +355,7 @@ def load_model(
model_path: Path, model_path: Path,
lazy: bool = False, lazy: bool = False,
model_config: dict = {}, model_config: dict = {},
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
) -> nn.Module: ) -> nn.Module:
""" """
Load and initialize the model from a given path. Load and initialize the model from a given path.
@ -366,6 +367,9 @@ def load_model(
when needed. Default: ``False`` when needed. Default: ``False``
model_config (dict, optional): Configuration parameters for the model. model_config (dict, optional): Configuration parameters for the model.
Defaults to an empty dictionary. Defaults to an empty dictionary.
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
A function that returns the model class and model args class given a config.
Defaults to the _get_classes function.
Returns: Returns:
nn.Module: The loaded and initialized model. nn.Module: The loaded and initialized model.
@ -392,7 +396,7 @@ def load_model(
for wf in weight_files: for wf in weight_files:
weights.update(mx.load(wf)) weights.update(mx.load(wf))
model_class, model_args_class = _get_classes(config=config) model_class, model_args_class = get_model_classes(config=config)
model_args = model_args_class.from_dict(config) model_args = model_args_class.from_dict(config)
model = model_class(model_args) model = model_class(model_args)

View File

@ -0,0 +1,50 @@
import unittest
from pathlib import Path
import mlx.nn as nn
from mlx_lm.models.qwen2 import Model as Qwen2Model
from mlx_lm.utils import get_model_path, load_model
# Model identifier that both tests below resolve with get_model_path before
# handing the resulting path to load_model.
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
class TestLoadModelCustomGetClasses(unittest.TestCase):
    """Exercise load_model's ``get_model_classes`` hook, custom and default."""

    def test_load_model_with_custom_get_classes(self):
        # Stand-in model: records what it is given instead of building layers.
        class CustomQwenModel(nn.Module):
            def __init__(self, args):
                super().__init__()
                self.config = args
                self.custom_attribute = "This is a custom model"

            def load_weights(self, weights):
                # Only stash the weights; the assertions below expect
                # load_model to have invoked this hook.
                self.qwenWeights = weights

        # Stand-in args class: mirrors every config entry as an attribute.
        class CustomQwenConfig:
            @classmethod
            def from_dict(cls, config):
                cfg = cls()
                for key, value in config.items():
                    setattr(cfg, key, value)
                return cfg

        def custom_get_classes(config):
            # The config is ignored: always hand back the stand-in pair.
            return CustomQwenModel, CustomQwenConfig

        loaded = load_model(
            get_model_path(HF_MODEL_PATH),
            get_model_classes=custom_get_classes,
        )

        self.assertIsInstance(loaded, CustomQwenModel)
        self.assertTrue(hasattr(loaded, "custom_attribute"))
        self.assertEqual(loaded.custom_attribute, "This is a custom model")
        self.assertTrue(hasattr(loaded, "qwenWeights"))

    def test_load_model_with_default_get_classes(self):
        # With no custom hook supplied, the Qwen2 model class is expected.
        loaded = load_model(get_model_path(HF_MODEL_PATH))
        self.assertIsInstance(loaded, Qwen2Model)
if __name__ == "__main__":
    # Run this module's test cases when it is executed as a script.
    unittest.main()