Lazy import + refactor Lora layer addition (#426)

* lazy model import in mlx_lm

* change lora loading

* fix olmo lora

* remove a bunch of unused stuff from plamo

* move phixtral to mlx-lm and out of llms/
Awni Hannun authored on 2024-02-12 10:51:02 -08:00; committed by GitHub
commit d4666615bb (parent 4576946151)
15 changed files with 127 additions and 393 deletions
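The first bullet in the commit message refers to the lazy model import in mlx_lm; that change is not part of the hunk shown below. As a rough sketch of the general pattern only (the module path and helper name here are assumptions for illustration, not the code added by this commit), a per-architecture module can be imported on first use rather than at package import time:

import importlib
from types import ModuleType


def _lazy_model_module(model_type: str) -> ModuleType:
    # Import the architecture module only when a model of this type is
    # requested, instead of importing every supported model up front.
    try:
        return importlib.import_module(f"mlx_lm.models.{model_type}")
    except ImportError as exc:
        raise ValueError(f"Model type {model_type} not supported.") from exc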


@@ -7,6 +7,44 @@ from mlx.utils import tree_unflatten
from .lora import LoRALinear
def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
    """
    Convert some of the model's linear layers to LoRA layers.

    Args:
        model (nn.Module): The neural network model.
        num_lora_layers (int): The number of blocks to convert to LoRA layers,
            starting from the last layer.
    """
    if model.model_type in [
        "mistral",
        "llama",
        "phi",
        "mixtral",
        "stablelm_epoch",
        "qwen2",
    ]:
        for l in model.model.layers[len(model.model.layers) - num_lora_layers :]:
            l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
            l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
            if hasattr(l, "block_sparse_moe"):
                l.block_sparse_moe.gate = LoRALinear.from_linear(
                    l.block_sparse_moe.gate
                )
    elif model.model_type == "olmo":
        for l in model.model.transformer.blocks[
            len(model.model.transformer.blocks) - num_lora_layers :
        ]:
            l.att_proj = LoRALinear.from_linear(l.att_proj)
    elif model.model_type == "phi-msft":
        for l in model.transformer.h[len(model.transformer.h) - num_lora_layers :]:
            l.mixer.Wqkv = LoRALinear.from_linear(l.mixer.Wqkv)
            l.moe.gate = LoRALinear.from_linear(l.moe.gate)
    else:
        raise ValueError(f"Lora does not support {model.model_type}")
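
# Illustrative usage sketch, not part of this commit's diff. The `load_model`
# helper and checkpoint path below are assumptions for the example rather than
# a statement of the mlx_lm API. For a supported `model_type`, the call above
# replaces the q_proj/v_proj projections (and MoE gates, where present) in the
# last `num_lora_layers` blocks:
#
#     model = load_model("path/to/mistral")  # hypothetical loader returning nn.Module
#     linear_to_lora_layers(model, num_lora_layers=8)
#     # the last 8 transformer blocks now use LoRALinear for q_proj / v_proj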

def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module:
    """
    Apply LoRA layers to the model.