Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 04:14:38 +08:00)
LoRA: Extract small function (#614)
* LoRA: Extract pre_processing_model function
* LoRA: Extract small functions (train_model, evaluate_model)
* move test case to test_tuner_utils.py
* nits
* nits
* remove extra param, validate at it 0
* version
* fix test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
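For context, the refactor described above splits the monolithic LoRA script flow into the three helpers named in the commit message. The sketch below only illustrates that structure; the signatures, argument names, and the run() wiring are assumptions, not the repository's actual code:

# Illustrative sketch only: pre_processing_model, train_model, and evaluate_model
# are the helper names taken from the commit message, but every signature and
# body here is an assumption, not the actual mlx_lm implementation.

def pre_processing_model(model, args):
    """Prepare the base model for LoRA fine-tuning (e.g. freeze base weights
    and swap the targeted linear layers for LoRA layers)."""
    ...

def train_model(args, model, tokenizer, train_set, valid_set):
    """Run the fine-tuning loop, reporting validation loss periodically."""
    ...

def evaluate_model(args, model, tokenizer, test_set):
    """Compute loss on the held-out test set after training."""
    ...

def run(args, model, tokenizer, train_set, valid_set, test_set):
    # Hypothetical top-level wiring of the extracted helpers.
    pre_processing_model(model, args)
    if args.train:
        train_model(args, model, tokenizer, train_set, valid_set)
    if args.test:
        evaluate_model(args, model, tokenizer, test_set)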
@@ -13,9 +13,8 @@ class LoRALinear(nn.Module):
     def from_linear(
         linear: nn.Linear,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
     ):
         # TODO remove when input_dims and output_dims are attributes
         # on linear and quantized linear
@@ -26,7 +25,6 @@ class LoRALinear(nn.Module):
             input_dims=input_dims,
             output_dims=output_dims,
             r=r,
-            alpha=alpha,
             dropout=dropout,
             scale=scale,
         )
@@ -74,9 +72,8 @@ class LoRALinear(nn.Module):
         input_dims: int,
         output_dims: int,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
         bias: bool = False,
     ):
         super().__init__()
@@ -87,7 +84,7 @@ class LoRALinear(nn.Module):
         self.dropout = nn.Dropout(p=dropout)
 
         # Scale for low-rank update
-        self.scale = scale * (alpha / r)
+        self.scale = scale
 
         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
@@ -109,16 +106,14 @@ class LoRASwitchLinear(nn.Module):
     def from_linear(
         linear: nn.Module,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
     ):
         lora_lin = LoRASwitchLinear(
             input_dims=linear.input_dims,
             output_dims=linear.output_dims,
             num_experts=linear.num_experts,
             r=r,
-            alpha=alpha,
             dropout=dropout,
             scale=scale,
         )
@@ -163,9 +158,8 @@ class LoRASwitchLinear(nn.Module):
         output_dims: int,
         num_experts: int,
         r: int = 8,
-        alpha: float = 16,
         dropout: float = 0.0,
-        scale: float = 10.0,
+        scale: float = 20.0,
         bias: bool = False,
     ):
         super().__init__()
@@ -176,7 +170,7 @@ class LoRASwitchLinear(nn.Module):
         self.dropout = nn.Dropout(p=dropout)
 
         # Scale for low-rank update
-        self.scale = scale * (alpha / r)
+        self.scale = scale
 
         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
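Note that the effective LoRA scaling is preserved by this change: the removed code stored scale * (alpha / r), and with the old defaults (scale=10.0, alpha=16, r=8) that works out to 20.0, which is exactly the new default for the single scale parameter. A standalone check of that arithmetic (not repository code):

# Verify the old and new defaults yield the same effective scale.
old_scale, alpha, r = 10.0, 16, 8
old_effective = old_scale * (alpha / r)   # what the removed code computed
new_effective = 20.0                      # new `scale` default, stored directly
assert old_effective == new_effective == 20.0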