Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01)
LoRA on all linear transformer block layers (#546)
* Add --lora-all-linear option to apply LoRA to all linear transformer block layers
* Moved to YAML config and added specification of rank & alpha
* nits in config, more tests
* nit
* run tests for prs

---------

Co-authored-by: Awni Hannun <awni@apple.com>
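The YAML schema itself is not shown in this commit, but the new code below reads exactly four values from the LoRA config: rank, alpha, scale, and an optional list of layer keys. A minimal sketch of such a configuration, written as a Python dict (any structure beyond these four keys is an assumption):

# Hypothetical LoRA config; only these entries are read by
# linear_to_lora_layers as of this commit.
lora_config = {
    "rank": 8,  # low-rank dimension r
    "alpha": 16,  # effective LoRA scale is scale * (alpha / rank)
    "scale": 10.0,
    # Optional: explicit module names to convert. If omitted, defaults
    # are chosen from model_type (e.g. q_proj/v_proj for llama-family).
    "keys": ["self_attn.q_proj", "self_attn.v_proj"],
}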
@@ -1,3 +1,5 @@
+# Copyright © 2024 Apple Inc.
+
 import math
 
 import mlx.core as mx
@@ -9,8 +11,8 @@ class LoRALinear(nn.Module):
     def from_linear(
         linear: nn.Linear,
         r: int = 8,
-        lora_alpha: float = 16,
-        lora_dropout: float = 0.0,
+        alpha: float = 16,
+        dropout: float = 0.0,
         scale: float = 10.0,
     ):
         # TODO remove when input_dims and output_dims are attributes
@@ -22,8 +24,8 @@ class LoRALinear(nn.Module):
             input_dims=input_dims,
             output_dims=output_dims,
             r=r,
-            lora_alpha=lora_alpha,
-            lora_dropout=lora_dropout,
+            alpha=alpha,
+            dropout=dropout,
             scale=scale,
         )
         lora_lin.linear = linear
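With the arguments renamed from lora_alpha/lora_dropout to alpha/dropout, a call to the factory now looks as follows. A small hypothetical example (the 512-dimensional layer is invented for illustration):

import mlx.nn as nn

base = nn.Linear(512, 512)
# Wraps the existing layer; with these values the low-rank update is
# scaled by 10.0 * (16 / 8) = 20.0.
lora_lin = LoRALinear.from_linear(base, r=8, alpha=16, dropout=0.0, scale=10.0)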
@@ -70,8 +72,8 @@ class LoRALinear(nn.Module):
         input_dims: int,
         output_dims: int,
         r: int = 8,
-        lora_alpha: float = 16,
-        lora_dropout: float = 0.0,
+        alpha: float = 16,
+        dropout: float = 0.0,
         scale: float = 10.0,
         bias: bool = False,
     ):
@@ -80,10 +82,10 @@ class LoRALinear(nn.Module):
         # Regular linear layer weights
         self.linear = nn.Linear(input_dims, output_dims, bias=bias)
 
-        self.lora_dropout = nn.Dropout(p=lora_dropout)
+        self.dropout = nn.Dropout(p=dropout)
 
         # Scale for low-rank update
-        self.scale = scale * (lora_alpha / r)
+        self.scale = scale * (alpha / r)
 
         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
@@ -99,5 +101,5 @@ class LoRALinear(nn.Module):
         if isinstance(self.linear, nn.QuantizedLinear):
             dtype = self.linear.scales.dtype
         y = self.linear(x.astype(dtype))
-        z = (self.lora_dropout(x) @ self.lora_a) @ self.lora_b
+        z = (self.dropout(x) @ self.lora_a) @ self.lora_b
         return y + self.scale * z
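The renames leave the computation itself unchanged: the layer returns the frozen linear output plus a scaled low-rank correction. A self-contained sketch of that forward pass with invented dimensions (the uniform bound for lora_a mirrors the context line above; initializing lora_b to zeros is the standard LoRA choice, so the correction starts inactive):

import math

import mlx.core as mx
import mlx.nn as nn

input_dims, output_dims, r, alpha, scale = 512, 512, 8, 16, 10.0
linear = nn.Linear(input_dims, output_dims)

# LoRA factors: a is small-uniform, b is zero at initialization.
init = 1 / math.sqrt(input_dims)
lora_a = mx.random.uniform(-init, init, (input_dims, r))
lora_b = mx.zeros((r, output_dims))

x = mx.random.normal((1, input_dims))
# y = frozen linear output + scale * (alpha / r) * low-rank update
y = linear(x) + (scale * (alpha / r)) * ((x @ lora_a) @ lora_b)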
@@ -1,3 +1,5 @@
+# Copyright © 2024 Apple Inc.
+
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -1,4 +1,5 @@
 import os
+from typing import Dict
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -7,7 +8,11 @@ from mlx.utils import tree_unflatten
 from .lora import LoRALinear
 
 
-def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
+def linear_to_lora_layers(
+    model: nn.Module,
+    num_lora_layers: int,
+    config: Dict,
+):
     """
     Convert some of the models linear layers to lora layers.
 
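The new signature makes the caller supply the LoRA hyperparameters explicitly instead of relying on hardcoded defaults. A hypothetical call (model stands for any loaded model with a supported model_type; the dict mirrors exactly the keys the function reads):

lora_config = {"rank": 8, "alpha": 16, "scale": 10.0}
# Omitting "keys" selects per-model defaults; passing "keys" applies
# LoRA to exactly the named submodules of each converted block.
linear_to_lora_layers(model, num_lora_layers=16, config=lora_config)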
@@ -15,16 +20,28 @@ def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
         model (nn.Module): The neural network model.
         num_lora_layers (int): The number of blocks to convert to lora layers
             starting from the last layer.
+        config (dict): More configuration parameters for LoRA, including the
+            rank, alpha, scale, and optional layer keys.
     """
 
-    def check_lora_layers(num_model):
-        if num_lora_layers > num_model:
-            raise ValueError(
-                f"Requested {num_lora_layers} LoRA layers "
-                f"but the model only has {num_model} layers."
-            )
+    num_layers = len(model.layers)
+    if num_lora_layers > num_layers:
+        raise ValueError(
+            f"Requested {num_lora_layers} LoRA layers "
+            f"but the model only has {num_layers} layers."
+        )
 
-    if model.model_type in [
+    to_lora = lambda lin: LoRALinear.from_linear(
+        lin, r=config["rank"], alpha=config["alpha"], scale=config["scale"]
+    )
+
+    # If the lora_parameters are set, we assume the keys
+    # are correct for the given model
+
+    keys = config.get("keys", None)
+    if keys is not None:
+        keys = set(keys)
+    elif model.model_type in [
         "mistral",
         "llama",
         "phi",
@@ -34,32 +51,21 @@ def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
         "gemma",
         "starcoder2",
     ]:
-        check_lora_layers(len(model.model.layers))
-
-        for l in model.model.layers[len(model.model.layers) - num_lora_layers :]:
-            l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
-            l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
-            if hasattr(l, "block_sparse_moe"):
-                l.block_sparse_moe.gate = LoRALinear.from_linear(
-                    l.block_sparse_moe.gate
-                )
+        keys = set(["self_attn.q_proj", "self_attn.v_proj"])
+        if model.model_type == "mixtral":
+            keys.add(["block_sparse_moe.gate"])
     elif model.model_type == "olmo":
-        check_lora_layers(len(model.model.transformer.blocks))
-
-        for l in model.model.transformer.blocks[
-            len(model.model.transformer.blocks) - num_lora_layers :
-        ]:
-            l.att_proj = LoRALinear.from_linear(l.att_proj)
+        keys = set(["att_proj"])
     elif model.model_type == "phi-msft":
-        check_lora_layers(len(model.transformer.h))
-
-        for l in model.transformer.h[len(model.transformer.h) - num_lora_layers :]:
-            l.mixer.Wqkv = LoRALinear.from_linear(l.mixer.Wqkv)
-            l.moe.gate = LoRALinear.from_linear(l.moe.gate)
-
+        keys = set(["mixer.Wqkv", "moe.gate"])
     else:
         raise ValueError(f"Lora does not support {model.model_type}")
 
+    for l in model.layers[num_layers - num_lora_layers :]:
+        modules = l.named_modules()
+        lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
+        l.update_modules(tree_unflatten(lora_layers))
+
 
 def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module:
     """