mirror of https://github.com/ml-explore/mlx-examples.git
@@ -9,7 +9,7 @@ import shutil
 import time
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -354,7 +354,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
 def load(
     path_or_hf_repo: str,
     tokenizer_config={},
-    adapter_file: Optional[str] = None,
+    adapter_path: Optional[str] = None,
     lazy: bool = False,
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
@@ -364,8 +364,8 @@ def load(
         path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
             Defaults to an empty dictionary.
-        adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.
-            Defaults to None.
+        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
+            to the model. Default: ``None``.
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
@@ -379,8 +379,8 @@ def load(
     model_path = get_model_path(path_or_hf_repo)
 
     model = load_model(model_path, lazy)
-    if adapter_file is not None:
-        model = apply_lora_layers(model, adapter_file)
+    if adapter_path is not None:
+        model = apply_lora_layers(model, adapter_path)
     model.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
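
The diff renames load's adapter_file argument to adapter_path, so callers now point at the LoRA adapters by path rather than passing a single adapter file. A minimal usage sketch of the renamed keyword follows, assuming the function is exposed as mlx_lm.load (as in this repo); the model repo and adapter directory names are placeholders, not values from the diff:

    # Sketch only: "mlx-community/Mistral-7B-Instruct-v0.2-4bit" and "adapters"
    # are placeholder names; substitute a real model repo and adapter directory.
    from mlx_lm import load

    # Before this change the keyword was adapter_file; it is now adapter_path.
    model, tokenizer = load(
        "mlx-community/Mistral-7B-Instruct-v0.2-4bit",
        adapter_path="adapters",
    )

Leaving adapter_path at its default of None loads the base model unchanged, matching the `if adapter_path is not None` guard in the new code.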