LoRA: Extract small function (#614)

* LoRA: Extract pre_processing_model function

* LoRA: Extract small functions (train_model, evaluate_model); see the sketch below

* move test case to test_tuner_utils.py

* nits

* nits

* remove extra param, validate at iteration 0

* version

* fix test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
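
The bullets above describe pulling the training and evaluation steps of the LoRA script out into small, testable functions. Below is a minimal illustration of that extraction pattern only: the names train_model and evaluate_model come from the commit message, but the toy model, data, and function bodies are stand-ins and are not the mlx_lm implementation.

# Illustrative sketch of the "extract small functions" refactor pattern.
# Only the names train_model / evaluate_model are taken from the commit
# message; everything else is a toy stand-in, not mlx_lm code.

def train_model(model: dict, train_set: list, iters: int = 10) -> dict:
    # Toy training loop: move a single parameter toward the data mean.
    for _ in range(iters):
        grad = model["w"] - sum(train_set) / len(train_set)
        model["w"] -= 0.5 * grad
    return model

def evaluate_model(model: dict, test_set: list) -> float:
    # Toy evaluation: mean squared error of the single parameter.
    return sum((model["w"] - x) ** 2 for x in test_set) / len(test_set)

def run() -> None:
    # Before the refactor this logic lived inline in one long script;
    # after it, the driver just wires the small functions together.
    model = train_model({"w": 0.0}, train_set=[1.0, 2.0, 3.0])
    print("test loss:", evaluate_model(model, test_set=[2.0]))

if __name__ == "__main__":
    run()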
Author: madroid
Date: 2024-06-02 21:38:42 +08:00
Committed by: GitHub
Parent: 81318ad4a8
Commit: c457a3f88b
10 changed files with 232 additions and 206 deletions

@@ -11,9 +11,8 @@ class DoRALinear(nn.Module):
     def from_linear(
         linear: nn.Linear,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
     ):
         # TODO support quantized weights in DoRALinear
         output_dims, input_dims = linear.weight.shape
@@ -23,7 +22,6 @@ class DoRALinear(nn.Module):
             input_dims=input_dims,
             output_dims=output_dims,
             r=r,
-            alpha=alpha,
             dropout=dropout,
             scale=scale,
         )
@@ -56,9 +54,8 @@ class DoRALinear(nn.Module):
         input_dims: int,
         output_dims: int,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
         bias: bool = False,
     ):
         super().__init__()
@@ -68,7 +65,7 @@ class DoRALinear(nn.Module):
         self.dropout = nn.Dropout(p=dropout)
 
         # Scale for low-rank update
-        self.scale = scale * (alpha / r)
+        self.scale = scale
 
         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
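
Taken together, these hunks drop the separate alpha argument and fold its effect into the scale default: under the old defaults the effective scale was scale * (alpha / r) = 10.0 * (16 / 8) = 20.0, which is exactly the new default, so default behavior is unchanged. A standalone sanity check of that arithmetic (plain Python, not part of the diff):

# Verify that the new default scale reproduces the old effective scale
# computed from the old defaults (scale=10.0, alpha=16, r=8).
old_scale, alpha, r = 10.0, 16, 8
new_default_scale = 20.0
assert old_scale * (alpha / r) == new_default_scale
print(old_scale * (alpha / r))  # 20.0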