feat(mlx-lm): add lora hyperparameters in lora layer (#366)

* feat(mlx-lm): add lora hyperparameters in lora layer

* chore: address comments
Author: Anchen, 2024-01-25 03:11:25 +11:00, committed by GitHub
parent 5fc8668a53
commit b1dec281b3
3 changed files with 25 additions and 8 deletions

Changed file 1 of 3

@@ -242,6 +242,7 @@ if __name__ == "__main__":
     if args.prompt is not None:
         print("Generating")
+        model.eval()
         generate(
             model=model,
             tokenizer=tokenizer,
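For context on why the new `model.eval()` call matters: the LoRA layer changed below now contains an `nn.Dropout`, which should be inactive during generation. A minimal sketch, assuming MLX's `nn.Dropout` follows the module's train/eval flag:

```python
import mlx.core as mx
import mlx.nn as nn

drop = nn.Dropout(p=0.5)
x = mx.ones((8,))

drop.train()
print(drop(x))  # roughly half the entries zeroed, survivors scaled by 1 / (1 - p)

drop.eval()
print(drop(x))  # identity: dropout is a no-op outside of training
```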

Changed file 2 of 3

@@ -6,14 +6,25 @@ import mlx.nn as nn
 class LoRALinear(nn.Module):
     @staticmethod
-    def from_linear(linear: nn.Linear, rank: int = 8, scale: float = 20.0):
+    def from_linear(
+        linear: nn.Linear,
+        r: int = 8,
+        lora_alpha: float = 16,
+        lora_dropout: float = 0.05,
+        scale: float = 10.0,
+    ):
         # TODO remove when input_dims and output_dims are attributes
         # on linear and quantized linear
         output_dims, input_dims = linear.weight.shape
         if isinstance(linear, nn.QuantizedLinear):
             input_dims *= 32 // linear.bits
         lora_lin = LoRALinear(
-            input_dims=input_dims, output_dims=output_dims, rank=rank, scale=scale
+            input_dims=input_dims,
+            output_dims=output_dims,
+            r=r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            scale=scale,
         )
         lora_lin.linear = linear
         return lora_lin
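A hypothetical usage sketch of the updated `from_linear` signature; the layer shape and hyperparameter values here are placeholders for illustration, not taken from this repo:

```python
import mlx.nn as nn

# Wrap an existing projection with a LoRA adapter using the new knobs.
proj = nn.Linear(4096, 4096, bias=False)
lora_proj = LoRALinear.from_linear(
    proj,
    r=8,                # rank of the low-rank update
    lora_alpha=16,      # together with `scale`, sets the update magnitude
    lora_dropout=0.05,  # dropout applied only to the LoRA path's input
    scale=10.0,
)
# Forward pass per this diff:
#   proj(x) + scale * (lora_alpha / r) * ((dropout(x) @ lora_a) @ lora_b)
```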
@@ -58,31 +69,35 @@ class LoRALinear(nn.Module):
         self,
         input_dims: int,
         output_dims: int,
-        rank: int = 8,
+        r: int = 8,
+        lora_alpha: float = 16,
+        lora_dropout: float = 0.0,
+        scale: float = 10.0,
         bias: bool = False,
-        scale: float = 20.0,
     ):
         super().__init__()

         # Regular linear layer weights
         self.linear = nn.Linear(input_dims, output_dims, bias=bias)

+        self.lora_dropout = nn.Dropout(p=lora_dropout)
+
         # Scale for low-rank update
-        self.scale = scale
+        self.scale = scale * (lora_alpha / r)

         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
         self.lora_a = mx.random.uniform(
             low=-scale,
             high=scale,
-            shape=(input_dims, rank),
+            shape=(input_dims, r),
         )
-        self.lora_b = mx.zeros(shape=(rank, output_dims))
+        self.lora_b = mx.zeros(shape=(r, output_dims))

     def __call__(self, x):
         dtype = self.linear.weight.dtype
         if isinstance(self.linear, nn.QuantizedLinear):
             dtype = self.linear.scales.dtype
         y = self.linear(x.astype(dtype))
-        z = (x @ self.lora_a) @ self.lora_b
+        z = (self.lora_dropout(x) @ self.lora_a) @ self.lora_b
         return y + self.scale * z
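Note on the new scaling: with the defaults in this diff, the effective multiplier on the low-rank update is unchanged from the old hard-coded value, since 10.0 * (16 / 8) == 20.0; tuning `lora_alpha` or `r` now adjusts it in the usual LoRA fashion. A quick sanity check:

```python
# Effective scale with this diff's defaults.
scale, lora_alpha, r = 10.0, 16, 8
effective = scale * (lora_alpha / r)
assert effective == 20.0  # same as the previous fixed scale of 20.0
```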

Changed file 3 of 3

@@ -278,6 +278,7 @@ def load(
     model = load_model(model_path)
     if adapter_file is not None:
         model = apply_lora_layers(model, adapter_file)
+        model.eval()

     tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
     return model, tokenizer
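A hedged end-to-end sketch of how these pieces fit together; the model path, adapter filename, and `generate` keyword names are assumptions for illustration, not verified against the rest of the package:

```python
from mlx_lm import load, generate

# load() applies the adapter and, with this change, puts the model in eval
# mode, so the new lora_dropout is disabled during generation.
model, tokenizer = load("path/to/mlx-model", adapter_file="adapters.npz")
print(generate(model, tokenizer, prompt="Explain LoRA in one sentence.", max_tokens=100))
```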