LoRA on all linear transformer block layers (#546)

* Add --lora-all-linear option to apply LoRA to all linear transformer block layers

* Moved to YAML config and added specification of rank & alpha (see the config sketch below)

* nits in config, more tests

* nit

* run tests for prs

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Chime Ogbuji
Date: 2024-03-12 10:37:40 -04:00
Committed by: GitHub
Parent: fe5edee360
Commit: e56d9015ef
8 changed files with 163 additions and 40 deletions
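
Per the second bullet of the commit message, rank and alpha are now specified in the YAML fine-tuning config rather than via CLI flags alone. A minimal sketch of what such a config section might look like; the key names are illustrative assumptions, not taken from this diff:

    # Hypothetical LoRA section of the YAML fine-tuning config. Key names are
    # assumed for illustration; only the rank/alpha/dropout/scale parameters
    # themselves are confirmed by the diff below.
    lora_parameters:
      rank: 8        # low-rank dimension r
      alpha: 16      # together with scale: effective scale = scale * (alpha / rank)
      dropout: 0.0
      scale: 10.0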

@@ -1,3 +1,5 @@
+# Copyright © 2024 Apple Inc.
+
 import math
 
 import mlx.core as mx
@@ -9,8 +11,8 @@ class LoRALinear(nn.Module):
     def from_linear(
         linear: nn.Linear,
         r: int = 8,
-        lora_alpha: float = 16,
-        lora_dropout: float = 0.0,
+        alpha: float = 16,
+        dropout: float = 0.0,
         scale: float = 10.0,
     ):
         # TODO remove when input_dims and output_dims are attributes
@@ -22,8 +24,8 @@ class LoRALinear(nn.Module):
             input_dims=input_dims,
             output_dims=output_dims,
             r=r,
-            lora_alpha=lora_alpha,
-            lora_dropout=lora_dropout,
+            alpha=alpha,
+            dropout=dropout,
             scale=scale,
         )
         lora_lin.linear = linear
@@ -70,8 +72,8 @@ class LoRALinear(nn.Module):
         input_dims: int,
         output_dims: int,
         r: int = 8,
-        lora_alpha: float = 16,
-        lora_dropout: float = 0.0,
+        alpha: float = 16,
+        dropout: float = 0.0,
         scale: float = 10.0,
         bias: bool = False,
     ):
@@ -80,10 +82,10 @@ class LoRALinear(nn.Module):
         # Regular linear layer weights
         self.linear = nn.Linear(input_dims, output_dims, bias=bias)
 
-        self.lora_dropout = nn.Dropout(p=lora_dropout)
+        self.dropout = nn.Dropout(p=dropout)
 
         # Scale for low-rank update
-        self.scale = scale * (lora_alpha / r)
+        self.scale = scale * (alpha / r)
 
         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
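
With the defaults shown in this hunk, the renamed scaling works out to scale * (alpha / r) = 10.0 * (16 / 8) = 20.0; a quick check:

    # Effective LoRA scale using the default values from the diff above.
    scale, alpha, r = 10.0, 16, 8
    effective_scale = scale * (alpha / r)  # 10.0 * 2.0 = 20.0
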
@@ -99,5 +101,5 @@ class LoRALinear(nn.Module):
         if isinstance(self.linear, nn.QuantizedLinear):
             dtype = self.linear.scales.dtype
         y = self.linear(x.astype(dtype))
-        z = (self.lora_dropout(x) @ self.lora_a) @ self.lora_b
+        z = (self.dropout(x) @ self.lora_a) @ self.lora_b
         return y + self.scale * z
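
Taken together, a minimal usage sketch of the renamed API; the import path mlx_lm.tuner.lora is an assumption, since the file path is not shown in this extraction:

    import mlx.core as mx
    import mlx.nn as nn

    # Assumed import path for the LoRALinear class shown in this diff.
    from mlx_lm.tuner.lora import LoRALinear

    # Wrap an existing linear layer using the renamed keyword arguments.
    base = nn.Linear(512, 512)
    lora = LoRALinear.from_linear(base, r=8, alpha=16, dropout=0.05, scale=10.0)

    x = mx.random.normal((2, 512))
    y = lora(x)  # base output + scale * (alpha / r) * ((dropout(x) @ lora_a) @ lora_b)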