support load model by custom get_model_classes (#899)

* feature(mlx_lm): support load model by custom get classes

* rename the param
This commit is contained in:
Anchen
2024-07-26 04:01:17 +10:00
committed by GitHub
parent cd8efc7fbc
commit 7a3ab1620a
2 changed files with 57 additions and 3 deletions

View File

@@ -9,7 +9,7 @@ import shutil
import time
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union
import mlx.core as mx
import mlx.nn as nn
@@ -355,6 +355,7 @@ def load_model(
model_path: Path,
lazy: bool = False,
model_config: dict = {},
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
) -> nn.Module:
"""
Load and initialize the model from a given path.
@@ -364,8 +365,11 @@ def load_model(
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
model_config(dict, optional): Configuration parameters for the model.
model_config (dict, optional): Configuration parameters for the model.
Defaults to an empty dictionary.
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
A function that returns the model class and model args class given a config.
Defaults to the _get_classes function.
Returns:
nn.Module: The loaded and initialized model.
@@ -392,7 +396,7 @@ def load_model(
for wf in weight_files:
weights.update(mx.load(wf))
model_class, model_args_class = _get_classes(config=config)
model_class, model_args_class = get_model_classes(config=config)
model_args = model_args_class.from_dict(config)
model = model_class(model_args)