Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 04:14:38 +08:00)
Lazy import + refactor Lora layer addition (#426)
* lazy model import in mlx_lm
* change lora loading
* fix olmo lora
* remove a bunch of unused stuff from plamo
* move phixtral to mlx-lm and out of llms/
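For context, the "lazy model import" mentioned above refers to importing an architecture module only when a model of that type is actually loaded, rather than importing every supported architecture at package import time. The snippet below is a minimal sketch of that pattern, not the exact code from this commit; the helper name _get_model_classes and the mlx_lm.models.<model_type> module layout are assumptions for illustration.

import importlib


def _get_model_classes(config: dict):
    # Resolve the architecture module lazily from the config's model_type,
    # so unsupported or unused architectures are never imported.
    model_type = config["model_type"]
    arch = importlib.import_module(f"mlx_lm.models.{model_type}")
    return arch.Model, arch.ModelArgs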
@@ -7,6 +7,44 @@ from mlx.utils import tree_unflatten

from .lora import LoRALinear


def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
    """
    Convert some of the model's linear layers to LoRA layers.

    Args:
        model (nn.Module): The neural network model.
        num_lora_layers (int): The number of blocks to convert to LoRA layers,
            starting from the last layer.
    """
    if model.model_type in [
        "mistral",
        "llama",
        "phi",
        "mixtral",
        "stablelm_epoch",
        "qwen2",
    ]:
        for l in model.model.layers[len(model.model.layers) - num_lora_layers :]:
            l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
            l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
            if hasattr(l, "block_sparse_moe"):
                l.block_sparse_moe.gate = LoRALinear.from_linear(
                    l.block_sparse_moe.gate
                )
    elif model.model_type == "olmo":
        for l in model.model.transformer.blocks[
            len(model.model.transformer.blocks) - num_lora_layers :
        ]:
            l.att_proj = LoRALinear.from_linear(l.att_proj)
    elif model.model_type == "phi-msft":
        for l in model.transformer.h[len(model.transformer.h) - num_lora_layers :]:
            l.mixer.Wqkv = LoRALinear.from_linear(l.mixer.Wqkv)
            l.moe.gate = LoRALinear.from_linear(l.moe.gate)
    else:
        raise ValueError(f"Lora does not support {model.model_type}")


def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module:
    """
    Apply LoRA layers to the model.