mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Support loading a model with a custom get_model_classes function (#899)
* feature(mlx_lm): support loading a model with a custom get-classes function
* rename the parameter
This commit is contained in:
parent
cd8efc7fbc
commit
7a3ab1620a
@ -9,7 +9,7 @@ import shutil
|
|||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -355,6 +355,7 @@ def load_model(
|
|||||||
model_path: Path,
|
model_path: Path,
|
||||||
lazy: bool = False,
|
lazy: bool = False,
|
||||||
model_config: dict = {},
|
model_config: dict = {},
|
||||||
|
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
"""
|
"""
|
||||||
Load and initialize the model from a given path.
|
Load and initialize the model from a given path.
|
||||||
@ -364,8 +365,11 @@ def load_model(
|
|||||||
lazy (bool): If False eval the model parameters to make sure they are
|
lazy (bool): If False eval the model parameters to make sure they are
|
||||||
loaded in memory before returning, otherwise they will be loaded
|
loaded in memory before returning, otherwise they will be loaded
|
||||||
when needed. Default: ``False``
|
when needed. Default: ``False``
|
||||||
model_config(dict, optional): Configuration parameters for the model.
|
model_config (dict, optional): Configuration parameters for the model.
|
||||||
Defaults to an empty dictionary.
|
Defaults to an empty dictionary.
|
||||||
|
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
||||||
|
A function that returns the model class and model args class given a config.
|
||||||
|
Defaults to the _get_classes function.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
nn.Module: The loaded and initialized model.
|
nn.Module: The loaded and initialized model.
|
||||||
@ -392,7 +396,7 @@ def load_model(
|
|||||||
for wf in weight_files:
|
for wf in weight_files:
|
||||||
weights.update(mx.load(wf))
|
weights.update(mx.load(wf))
|
||||||
|
|
||||||
model_class, model_args_class = _get_classes(config=config)
|
model_class, model_args_class = get_model_classes(config=config)
|
||||||
|
|
||||||
model_args = model_args_class.from_dict(config)
|
model_args = model_args_class.from_dict(config)
|
||||||
model = model_class(model_args)
|
model = model_class(model_args)
|
||||||
|
50
llms/tests/test_utils_load_model.py
Normal file
50
llms/tests/test_utils_load_model.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import mlx.nn as nn
|
||||||
|
from mlx_lm.models.qwen2 import Model as Qwen2Model
|
||||||
|
from mlx_lm.utils import get_model_path, load_model
|
||||||
|
|
||||||
|
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadModelCustomGetClasses(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_load_model_with_custom_get_classes(self):
|
||||||
|
class CustomQwenModel(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__()
|
||||||
|
self.config = args
|
||||||
|
self.custom_attribute = "This is a custom model"
|
||||||
|
|
||||||
|
def load_weights(self, weights):
|
||||||
|
self.qwenWeights = weights
|
||||||
|
|
||||||
|
class CustomQwenConfig:
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, config):
|
||||||
|
instance = cls()
|
||||||
|
for k, v in config.items():
|
||||||
|
setattr(instance, k, v)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
def custom_get_classes(config):
|
||||||
|
return CustomQwenModel, CustomQwenConfig
|
||||||
|
|
||||||
|
model_path = get_model_path(HF_MODEL_PATH)
|
||||||
|
model = load_model(model_path, get_model_classes=custom_get_classes)
|
||||||
|
|
||||||
|
self.assertIsInstance(model, CustomQwenModel)
|
||||||
|
self.assertTrue(hasattr(model, "custom_attribute"))
|
||||||
|
self.assertEqual(model.custom_attribute, "This is a custom model")
|
||||||
|
self.assertTrue(hasattr(model, "qwenWeights"))
|
||||||
|
|
||||||
|
def test_load_model_with_default_get_classes(self):
|
||||||
|
model_path = get_model_path(HF_MODEL_PATH)
|
||||||
|
model = load_model(model_path)
|
||||||
|
|
||||||
|
self.assertIsInstance(model, Qwen2Model)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue
Block a user