diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 993f216a..079011f0 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -1,10 +1,11 @@
 import mlx.core as mx
+import mlx.nn as nn
 from mlx.utils import tree_unflatten
 
 from .lora import LoRALinear
 
 
-def apply_lora_layers(model, adapter_file: str):
+def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module:
     adapters = list(mx.load(adapter_file).items())
     linear_replacements = {}
     lora_layers = set(
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 522208c1..4a53ee9d 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -13,6 +13,7 @@ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
 
 # Local imports
 from .models import llama, mixtral, phi2, plamo, qwen
+from .tuner.utils import apply_lora_layers
 
 # Constants
 MODEL_MAPPING = {
@@ -98,11 +99,14 @@ def generate_step(
     """
 
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
+        softmax_logits = mx.softmax(logits)
+
         if temp == 0:
             token = mx.argmax(logits, axis=-1)
         else:
             token = mx.random.categorical(logits * (1 / temp))
-        prob = mx.softmax(logits / temp)[0, token]
+
+        prob = softmax_logits[0, token]
         return token, prob
 
     y = prompt
@@ -237,7 +241,7 @@ def load_model(model_path: Path) -> nn.Module:
 
 
 def load(
-    path_or_hf_repo: str, tokenizer_config={}
+    path_or_hf_repo: str, tokenizer_config={}, adapter_file: str = None
 ) -> Tuple[nn.Module, PreTrainedTokenizer]:
     """
     Load the model and tokenizer from a given path or a huggingface repository.
@@ -246,8 +250,10 @@
         model_path (Path): The path or the huggingface repository to load the model from.
         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
             Defaults to an empty dictionary.
+        adapter_file (str, optional): Path to the adapter file. If provided, applies LoRA layers to the model.
+            Defaults to None.
     Returns:
-        nn.Module: The loaded model.
+        Tuple[nn.Module, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer.
 
     Raises:
         FileNotFoundError: If config file or safetensors are not found.
@@ -256,6 +262,9 @@
     model_path = get_model_path(path_or_hf_repo)
 
     model = load_model(model_path)
+    if adapter_file is not None:
+        model = apply_lora_layers(model, adapter_file)
+
     tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
 
     return model, tokenizer
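
The new adapter_file argument to load() can be exercised as below: a minimal
sketch, assuming the package is importable as mlx_lm and that a LoRA adapter
file saved during fine-tuning exists; the model repo name and "adapters.npz"
are placeholder values, not part of this diff.

    from mlx_lm.utils import load

    # Placeholder model repo and adapter path (hypothetical values).
    # When adapter_file is given, load() applies the LoRA layers internally
    # via apply_lora_layers() before returning the model.
    model, tokenizer = load(
        "mistralai/Mistral-7B-v0.1",
        adapter_file="adapters.npz",
    )

    # Equivalent two-step form using the tuner utility directly:
    #   from mlx_lm.tuner.utils import apply_lora_layers
    #   model, tokenizer = load("mistralai/Mistral-7B-v0.1")
    #   model = apply_lora_layers(model, "adapters.npz")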