Mirror of https://github.com/ml-explore/mlx-examples.git
LoRA: Extract small function (#614)
* LoRA: Extract pre_processing_model function
* LoRA: Extract small functions (train_model, evaluate_model)
* move test case to test_tuner_utils.py
* nits
* nits
* remove extra param, validate at iteration 0
* version
* fix test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
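For context, the extraction described above can be sketched as follows. Only the names pre_processing_model, train_model, and evaluate_model (and the test_tuner_utils.py test location) come from the commit message; the signatures and bodies below are hypothetical placeholders, not the repository's actual code:

# Hypothetical sketch only: the monolithic LoRA driver is split into small
# helpers so each step can be unit-tested in isolation.

def pre_processing_model(model, args):
    # Placeholder: prepare the model for tuning (e.g. freeze base weights
    # and attach adapter layers) according to args.
    ...

def train_model(args, model, tokenizer, train_set, valid_set):
    # Placeholder: run the fine-tuning loop, including a validation pass
    # at iteration 0 as noted in the commit message.
    ...

def evaluate_model(args, model, tokenizer, test_set):
    # Placeholder: compute test loss for the tuned model.
    ...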
@@ -11,9 +11,8 @@ class DoRALinear(nn.Module):
     def from_linear(
         linear: nn.Linear,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
     ):
         # TODO support quantized weights in DoRALinear
         output_dims, input_dims = linear.weight.shape
@@ -23,7 +22,6 @@ class DoRALinear(nn.Module):
             input_dims=input_dims,
             output_dims=output_dims,
             r=r,
-            alpha=alpha,
             dropout=dropout,
             scale=scale,
         )
@@ -56,9 +54,8 @@ class DoRALinear(nn.Module):
         input_dims: int,
         output_dims: int,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
         bias: bool = False,
     ):
         super().__init__()
@@ -68,7 +65,7 @@ class DoRALinear(nn.Module):
         self.dropout = nn.Dropout(p=dropout)

         # Scale for low-rank update
-        self.scale = scale * (alpha / r)
+        self.scale = scale

         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
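Note on the scale change: the default behaviour is numerically unchanged. The old code applied an effective low-rank scale of scale * (alpha / r) = 10.0 * (16 / 8) = 20.0; the new code applies scale = 20.0 directly and drops the alpha parameter, which, in the hunks shown, only entered the layer through that product. A minimal worked check of the equivalence, using just the default values visible in the diff (illustrative only, no dependence on mlx_lm itself):

# Effective scaling of the low-rank update, before and after this change.
# All values come from the diff above.
r, alpha = 8, 16

old_scale_param = 10.0
old_effective_scale = old_scale_param * (alpha / r)  # 10.0 * (16 / 8) = 20.0

new_scale_param = 20.0
new_effective_scale = new_scale_param  # applied directly; no alpha factor

assert old_effective_scale == new_effective_scale == 20.0

A non-default alpha can still be reproduced under the new signature by folding the ratio into the scale argument, i.e. passing scale = old_scale * alpha / r.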