mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
feat(mlx-lm): add lora hypeparameters in lora layer (#366)
* feat(mlx-lm): add lora hypeparameters in lora layer * chore: address comments
This commit is contained in:
parent
5fc8668a53
commit
b1dec281b3
@ -242,6 +242,7 @@ if __name__ == "__main__":
|
||||
|
||||
if args.prompt is not None:
|
||||
print("Generating")
|
||||
model.eval()
|
||||
generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
|
@ -6,14 +6,25 @@ import mlx.nn as nn
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
@staticmethod
|
||||
def from_linear(linear: nn.Linear, rank: int = 8, scale: float = 20.0):
|
||||
def from_linear(
|
||||
linear: nn.Linear,
|
||||
r: int = 8,
|
||||
lora_alpha: float = 16,
|
||||
lora_dropout: float = 0.05,
|
||||
scale: float = 10.0,
|
||||
):
|
||||
# TODO remove when input_dims and output_dims are attributes
|
||||
# on linear and quantized linear
|
||||
output_dims, input_dims = linear.weight.shape
|
||||
if isinstance(linear, nn.QuantizedLinear):
|
||||
input_dims *= 32 // linear.bits
|
||||
lora_lin = LoRALinear(
|
||||
input_dims=input_dims, output_dims=output_dims, rank=rank, scale=scale
|
||||
input_dims=input_dims,
|
||||
output_dims=output_dims,
|
||||
r=r,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
scale=scale,
|
||||
)
|
||||
lora_lin.linear = linear
|
||||
return lora_lin
|
||||
@ -58,31 +69,35 @@ class LoRALinear(nn.Module):
|
||||
self,
|
||||
input_dims: int,
|
||||
output_dims: int,
|
||||
rank: int = 8,
|
||||
r: int = 8,
|
||||
lora_alpha: float = 16,
|
||||
lora_dropout: float = 0.0,
|
||||
scale: float = 10.0,
|
||||
bias: bool = False,
|
||||
scale: float = 20.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Regular linear layer weights
|
||||
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
|
||||
|
||||
self.lora_dropout = nn.Dropout(p=lora_dropout)
|
||||
|
||||
# Scale for low-rank update
|
||||
self.scale = scale
|
||||
self.scale = scale * (lora_alpha / r)
|
||||
|
||||
# Low rank lora weights
|
||||
scale = 1 / math.sqrt(input_dims)
|
||||
self.lora_a = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(input_dims, rank),
|
||||
shape=(input_dims, r),
|
||||
)
|
||||
self.lora_b = mx.zeros(shape=(rank, output_dims))
|
||||
self.lora_b = mx.zeros(shape=(r, output_dims))
|
||||
|
||||
def __call__(self, x):
|
||||
dtype = self.linear.weight.dtype
|
||||
if isinstance(self.linear, nn.QuantizedLinear):
|
||||
dtype = self.linear.scales.dtype
|
||||
y = self.linear(x.astype(dtype))
|
||||
z = (x @ self.lora_a) @ self.lora_b
|
||||
z = (self.lora_dropout(x) @ self.lora_a) @ self.lora_b
|
||||
return y + self.scale * z
|
||||
|
@ -278,6 +278,7 @@ def load(
|
||||
model = load_model(model_path)
|
||||
if adapter_file is not None:
|
||||
model = apply_lora_layers(model, adapter_file)
|
||||
model.eval()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
|
||||
return model, tokenizer
|
||||
|
Loading…
Reference in New Issue
Block a user