LoRA: Extract small function (#614)

* LoRA: Extract pre_processing_model function

* LoRA: Extract small functions (train_model, evaluate_model)

* move test case to test_tuner_utils.py

* nits

* nits

* remove extra param, validate at iteration 0

* version

* fix test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: madroid
Date: 2024-06-02 21:38:42 +08:00
Committed by: GitHub
Parent: 81318ad4a8
Commit: c457a3f88b
10 changed files with 232 additions and 206 deletions
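The bullets above describe moving model preparation and the train/test steps out of the main LoRA script into named helpers; the diffs for those files are not shown in this excerpt. The sketch below only illustrates the shape of such a split: the function names come from the commit message, while every signature and docstring is an assumption, not the code added here.

# Hypothetical sketch only -- not the actual mlx-lm code in this commit.
def pre_processing_model(model, args):
    """Freeze the base model and attach LoRA layers before tuning."""
    ...

def train_model(args, model, tokenizer, train_set, valid_set):
    """Run the fine-tuning loop, validating periodically."""
    ...

def evaluate_model(args, model, tokenizer, test_set):
    """Report test loss for the tuned model."""
    ...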


@@ -13,9 +13,8 @@ class LoRALinear(nn.Module):
     def from_linear(
         linear: nn.Linear,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
     ):
         # TODO remove when input_dims and output_dims are attributes
         # on linear and quantized linear
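The alpha argument is removed because its effect is folded directly into scale: the layer previously multiplied scale by alpha / r, and with the old defaults that product is exactly the new default. A quick check of the arithmetic:

# With the old defaults, the effective multiplier was already 20:
old_scale, alpha, r = 10.0, 16, 8
assert old_scale * (alpha / r) == 20.0  # equals the new `scale` default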
@@ -26,7 +25,6 @@ class LoRALinear(nn.Module):
             input_dims=input_dims,
             output_dims=output_dims,
             r=r,
-            alpha=alpha,
             dropout=dropout,
             scale=scale,
         )
@@ -74,9 +72,8 @@ class LoRALinear(nn.Module):
         input_dims: int,
         output_dims: int,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
         bias: bool = False,
     ):
         super().__init__()
@@ -87,7 +84,7 @@ class LoRALinear(nn.Module):
         self.dropout = nn.Dropout(p=dropout)
 
         # Scale for low-rank update
-        self.scale = scale * (alpha / r)
+        self.scale = scale
 
         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
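After this change, self.scale is applied verbatim to the low-rank update. The __call__ body is not part of this hunk; the NumPy sketch below only shows where that multiplier enters the usual LoRA computation.

import numpy as np

def lora_forward(x, W, A, B, scale=20.0):
    """Illustrative LoRA forward pass: frozen base projection plus a scaled
    low-rank update. x: (batch, d_in), W: (d_out, d_in), A: (d_in, r), B: (r, d_out)."""
    base = x @ W.T                 # frozen pretrained projection
    update = (x @ A) @ B           # low-rank adaptation
    return base + scale * update   # `scale` applied directly, no alpha / r factor

x = np.random.randn(2, 8)
W = np.random.randn(4, 8)
A = np.random.randn(8, 2) * 0.01
B = np.zeros((2, 4))               # B starts at zero, so the update is initially a no-op
print(lora_forward(x, W, A, B).shape)  # (2, 4)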
@@ -109,16 +106,14 @@ class LoRASwitchLinear(nn.Module):
     def from_linear(
         linear: nn.Module,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
     ):
         lora_lin = LoRASwitchLinear(
             input_dims=linear.input_dims,
             output_dims=linear.output_dims,
             num_experts=linear.num_experts,
             r=r,
-            alpha=alpha,
             dropout=dropout,
             scale=scale,
         )
@@ -163,9 +158,8 @@ class LoRASwitchLinear(nn.Module):
         output_dims: int,
         num_experts: int,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
         bias: bool = False,
     ):
         super().__init__()
@@ -176,7 +170,7 @@ class LoRASwitchLinear(nn.Module):
         self.dropout = nn.Dropout(p=dropout)
 
         # Scale for low-rank update
-        self.scale = scale * (alpha / r)
+        self.scale = scale
 
         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
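For callers, the migration is mechanical: drop the alpha keyword and, if a configuration relied on a non-default alpha, fold it into scale by hand. A hedged usage sketch (the import path is assumed from the mlx-examples layout, not shown in this diff):

# from mlx_lm.tuner.lora import LoRALinear   # assumed module path
# lora_lin = LoRALinear.from_linear(linear, r=8, dropout=0.0, scale=20.0)

def folded_scale(old_scale: float, alpha: float, r: int) -> float:
    """Translate an old (scale, alpha, r) configuration to the new single scale."""
    return old_scale * (alpha / r)

print(folded_scale(10.0, 32, 8))  # an old alpha=32, r=8 run now needs scale=40.0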