From b1dec281b3a1d5482ccfc703473eb09e8f51a9bf Mon Sep 17 00:00:00 2001
From: Anchen
Date: Thu, 25 Jan 2024 03:11:25 +1100
Subject: [PATCH] feat(mlx-lm): add lora hyperparameters in lora layer (#366)
* feat(mlx-lm): add lora hyperparameters in lora layer
* chore: address comments
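The new arguments follow the usual LoRA naming: r is the rank of the low-rank
update, lora_alpha together with r sets the effective scaling
scale * (lora_alpha / r), and lora_dropout is applied only to the input of the
low-rank branch. With the defaults below (scale=10.0, lora_alpha=16, r=8) the
effective scale is 10.0 * 16 / 8 = 20.0, so the previous default behaviour is
unchanged. Rough sketch of the resulting forward pass, using the names from
the diff below:

    # effective scaling with the new defaults: 10.0 * (16 / 8) == 20.0
    self.scale = scale * (lora_alpha / r)

    # the low-rank branch now sees a dropped-out copy of the input
    y = self.linear(x)
    z = (self.lora_dropout(x) @ self.lora_a) @ self.lora_b
    return y + self.scale * z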
---
llms/mlx_lm/lora.py | 1 +
llms/mlx_lm/tuner/lora.py | 31 +++++++++++++++++++++++--------
llms/mlx_lm/utils.py | 1 +
3 files changed, 25 insertions(+), 8 deletions(-)
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 2bcb8099..ce4d1854 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -242,6 +242,7 @@ if __name__ == "__main__":
 
     if args.prompt is not None:
         print("Generating")
+        model.eval()
         generate(
             model=model,
             tokenizer=tokenizer,
diff --git a/llms/mlx_lm/tuner/lora.py b/llms/mlx_lm/tuner/lora.py
index f0ec601b..2a64e5a0 100644
--- a/llms/mlx_lm/tuner/lora.py
+++ b/llms/mlx_lm/tuner/lora.py
@@ -6,14 +6,25 @@ import mlx.nn as nn
 
 class LoRALinear(nn.Module):
     @staticmethod
-    def from_linear(linear: nn.Linear, rank: int = 8, scale: float = 20.0):
+    def from_linear(
+        linear: nn.Linear,
+        r: int = 8,
+        lora_alpha: float = 16,
+        lora_dropout: float = 0.05,
+        scale: float = 10.0,
+    ):
         # TODO remove when input_dims and output_dims are attributes
         # on linear and quantized linear
         output_dims, input_dims = linear.weight.shape
         if isinstance(linear, nn.QuantizedLinear):
             input_dims *= 32 // linear.bits
         lora_lin = LoRALinear(
-            input_dims=input_dims, output_dims=output_dims, rank=rank, scale=scale
+            input_dims=input_dims,
+            output_dims=output_dims,
+            r=r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            scale=scale,
         )
         lora_lin.linear = linear
         return lora_lin
@@ -58,31 +69,35 @@ class LoRALinear(nn.Module):
         self,
         input_dims: int,
         output_dims: int,
-        rank: int = 8,
+        r: int = 8,
+        lora_alpha: float = 16,
+        lora_dropout: float = 0.0,
+        scale: float = 10.0,
         bias: bool = False,
-        scale: float = 20.0,
     ):
         super().__init__()
 
         # Regular linear layer weights
         self.linear = nn.Linear(input_dims, output_dims, bias=bias)
 
+        self.lora_dropout = nn.Dropout(p=lora_dropout)
+
         # Scale for low-rank update
-        self.scale = scale
+        self.scale = scale * (lora_alpha / r)
 
         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
         self.lora_a = mx.random.uniform(
             low=-scale,
             high=scale,
-            shape=(input_dims, rank),
+            shape=(input_dims, r),
         )
-        self.lora_b = mx.zeros(shape=(rank, output_dims))
+        self.lora_b = mx.zeros(shape=(r, output_dims))
 
     def __call__(self, x):
         dtype = self.linear.weight.dtype
         if isinstance(self.linear, nn.QuantizedLinear):
             dtype = self.linear.scales.dtype
         y = self.linear(x.astype(dtype))
-        z = (x @ self.lora_a) @ self.lora_b
+        z = (self.lora_dropout(x) @ self.lora_a) @ self.lora_b
         return y + self.scale * z
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index ab5b99af..d670ee71 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -278,6 +278,7 @@ def load(
     model = load_model(model_path)
     if adapter_file is not None:
         model = apply_lora_layers(model, adapter_file)
+        model.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
     return model, tokenizer
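
Usage sketch of the new hyperparameters (a minimal example, assuming modules
start in training mode and that nn.Module.eval() disables nn.Dropout, which is
what the added model.eval() calls rely on):

    import mlx.core as mx
    import mlx.nn as nn
    from mlx_lm.tuner.lora import LoRALinear

    # Wrap a plain linear layer with the new LoRA hyperparameters.
    layer = LoRALinear.from_linear(
        nn.Linear(512, 512), r=8, lora_alpha=16, lora_dropout=0.05
    )

    x = mx.random.normal((2, 512))
    y_train = layer(x)  # lora_dropout is active while training

    layer.eval()        # turn dropout off before generation
    y_infer = layer(x)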