From b1dec281b3a1d5482ccfc703473eb09e8f51a9bf Mon Sep 17 00:00:00 2001
From: Anchen
Date: Thu, 25 Jan 2024 03:11:25 +1100
Subject: [PATCH] feat(mlx-lm): add lora hyperparameters in lora layer (#366)

* feat(mlx-lm): add lora hyperparameters in lora layer

* chore: address comments
---
 llms/mlx_lm/lora.py       |  1 +
 llms/mlx_lm/tuner/lora.py | 31 +++++++++++++++++++++++--------
 llms/mlx_lm/utils.py      |  1 +
 3 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 2bcb8099..ce4d1854 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -242,6 +242,7 @@ if __name__ == "__main__":
 
     if args.prompt is not None:
         print("Generating")
+        model.eval()
         generate(
             model=model,
             tokenizer=tokenizer,
diff --git a/llms/mlx_lm/tuner/lora.py b/llms/mlx_lm/tuner/lora.py
index f0ec601b..2a64e5a0 100644
--- a/llms/mlx_lm/tuner/lora.py
+++ b/llms/mlx_lm/tuner/lora.py
@@ -6,14 +6,25 @@ import mlx.nn as nn
 
 class LoRALinear(nn.Module):
     @staticmethod
-    def from_linear(linear: nn.Linear, rank: int = 8, scale: float = 20.0):
+    def from_linear(
+        linear: nn.Linear,
+        r: int = 8,
+        lora_alpha: float = 16,
+        lora_dropout: float = 0.05,
+        scale: float = 10.0,
+    ):
         # TODO remove when input_dims and output_dims are attributes
         # on linear and quantized linear
         output_dims, input_dims = linear.weight.shape
         if isinstance(linear, nn.QuantizedLinear):
             input_dims *= 32 // linear.bits
         lora_lin = LoRALinear(
-            input_dims=input_dims, output_dims=output_dims, rank=rank, scale=scale
+            input_dims=input_dims,
+            output_dims=output_dims,
+            r=r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            scale=scale,
         )
         lora_lin.linear = linear
         return lora_lin
@@ -58,31 +69,35 @@ class LoRALinear(nn.Module):
         self,
         input_dims: int,
         output_dims: int,
-        rank: int = 8,
+        r: int = 8,
+        lora_alpha: float = 16,
+        lora_dropout: float = 0.0,
+        scale: float = 10.0,
         bias: bool = False,
-        scale: float = 20.0,
     ):
         super().__init__()
 
         # Regular linear layer weights
         self.linear = nn.Linear(input_dims, output_dims, bias=bias)
 
+        self.lora_dropout = nn.Dropout(p=lora_dropout)
+
         # Scale for low-rank update
-        self.scale = scale
+        self.scale = scale * (lora_alpha / r)
 
         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
         self.lora_a = mx.random.uniform(
             low=-scale,
             high=scale,
-            shape=(input_dims, rank),
+            shape=(input_dims, r),
         )
-        self.lora_b = mx.zeros(shape=(rank, output_dims))
+        self.lora_b = mx.zeros(shape=(r, output_dims))
 
     def __call__(self, x):
         dtype = self.linear.weight.dtype
         if isinstance(self.linear, nn.QuantizedLinear):
             dtype = self.linear.scales.dtype
         y = self.linear(x.astype(dtype))
-        z = (x @ self.lora_a) @ self.lora_b
+        z = (self.lora_dropout(x) @ self.lora_a) @ self.lora_b
         return y + self.scale * z
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index ab5b99af..d670ee71 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -278,6 +278,7 @@ def load(
     model = load_model(model_path)
    if adapter_file is not None:
         model = apply_lora_layers(model, adapter_file)
+        model.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
     return model, tokenizer
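
A minimal usage sketch (not part of the patch) of the new LoRALinear hyperparameters; it assumes the installed package exposes this module as mlx_lm.tuner.lora, and the layer dimensions below are arbitrary example values:

import mlx.core as mx
import mlx.nn as nn

from mlx_lm.tuner.lora import LoRALinear  # assumed import path for llms/mlx_lm/tuner/lora.py

# Wrap a plain linear layer using the defaults introduced by this patch:
# r=8, lora_alpha=16, lora_dropout=0.05, scale=10.0.
base = nn.Linear(256, 256)
lora = LoRALinear.from_linear(base, r=8, lora_alpha=16, lora_dropout=0.05, scale=10.0)

# Effective multiplier on the low-rank update is scale * (lora_alpha / r)
# = 10.0 * (16 / 8) = 20.0, matching the previous hard-coded default of 20.0.
print(lora.scale)

# Forward pass: linear(x) + scale * ((dropout(x) @ lora_a) @ lora_b)
x = mx.random.normal((1, 256))
print(lora(x).shape)  # same shape as x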