mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Support loading a model via a custom get_model_classes function (#899)
* feature(mlx_lm): support loading a model with a custom get-classes function
* rename the param
This commit is contained in:
@@ -9,7 +9,7 @@ import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Type, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -355,6 +355,7 @@ def load_model(
|
||||
model_path: Path,
|
||||
lazy: bool = False,
|
||||
model_config: dict = {},
|
||||
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Load and initialize the model from a given path.
|
||||
@@ -364,8 +365,11 @@ def load_model(
|
||||
lazy (bool): If False eval the model parameters to make sure they are
|
||||
loaded in memory before returning, otherwise they will be loaded
|
||||
when needed. Default: ``False``
|
||||
model_config(dict, optional): Configuration parameters for the model.
|
||||
model_config (dict, optional): Configuration parameters for the model.
|
||||
Defaults to an empty dictionary.
|
||||
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
||||
A function that returns the model class and model args class given a config.
|
||||
Defaults to the _get_classes function.
|
||||
|
||||
Returns:
|
||||
nn.Module: The loaded and initialized model.
|
||||
@@ -392,7 +396,7 @@ def load_model(
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf))
|
||||
|
||||
model_class, model_args_class = _get_classes(config=config)
|
||||
model_class, model_args_class = get_model_classes(config=config)
|
||||
|
||||
model_args = model_args_class.from_dict(config)
|
||||
model = model_class(model_args)
|
||||
|
Reference in New Issue
Block a user