mirror of https://github.com/ml-explore/mlx-examples.git
feat(mlx-lm): add lora hyperparameters in lora layer (#366)
* feat(mlx-lm): add lora hyperparameters in lora layer
* chore: address comments
parent 5fc8668a53
commit b1dec281b3
@@ -242,6 +242,7 @@ if __name__ == "__main__":
 
     if args.prompt is not None:
         print("Generating")
+        model.eval()
         generate(
             model=model,
             tokenizer=tokenizer,

@@ -6,14 +6,25 @@ import mlx.nn as nn
 
 class LoRALinear(nn.Module):
     @staticmethod
-    def from_linear(linear: nn.Linear, rank: int = 8, scale: float = 20.0):
+    def from_linear(
+        linear: nn.Linear,
+        r: int = 8,
+        lora_alpha: float = 16,
+        lora_dropout: float = 0.05,
+        scale: float = 10.0,
+    ):
         # TODO remove when input_dims and output_dims are attributes
         # on linear and quantized linear
         output_dims, input_dims = linear.weight.shape
         if isinstance(linear, nn.QuantizedLinear):
             input_dims *= 32 // linear.bits
         lora_lin = LoRALinear(
-            input_dims=input_dims, output_dims=output_dims, rank=rank, scale=scale
+            input_dims=input_dims,
+            output_dims=output_dims,
+            r=r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            scale=scale,
         )
         lora_lin.linear = linear
         return lora_lin
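
For reference, a minimal usage sketch of the new `from_linear` signature (not part of the commit; it assumes `LoRALinear` is importable from the patched module, and the layer size and hyperparameter values are purely illustrative):

import mlx.core as mx
import mlx.nn as nn

# Illustrative only: wrap an existing projection with the LoRA hyperparameters
# introduced in this commit (r, lora_alpha, lora_dropout).
base = nn.Linear(4096, 4096, bias=False)
lora_proj = LoRALinear.from_linear(
    base,
    r=8,                # low-rank dimension (replaces the old `rank` argument)
    lora_alpha=16,      # effective scale becomes scale * (lora_alpha / r)
    lora_dropout=0.05,  # dropout applied to the input of the low-rank path
    scale=10.0,
)
x = mx.random.normal(shape=(1, 4096))
y = lora_proj(x)        # base(x) + scale * (lora_alpha / r) * ((dropout(x) @ A) @ B)

The `model.eval()` calls added elsewhere in this commit matter for the same reason: `nn.Dropout` only randomizes in training mode, so switching to eval mode keeps generation deterministic once `lora_dropout` is present.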

@@ -58,31 +69,35 @@ class LoRALinear(nn.Module):
         self,
         input_dims: int,
         output_dims: int,
-        rank: int = 8,
+        r: int = 8,
+        lora_alpha: float = 16,
+        lora_dropout: float = 0.0,
+        scale: float = 10.0,
         bias: bool = False,
-        scale: float = 20.0,
     ):
         super().__init__()
 
         # Regular linear layer weights
         self.linear = nn.Linear(input_dims, output_dims, bias=bias)
 
+        self.lora_dropout = nn.Dropout(p=lora_dropout)
+
         # Scale for low-rank update
-        self.scale = scale
+        self.scale = scale * (lora_alpha / r)
 
         # Low rank lora weights
         scale = 1 / math.sqrt(input_dims)
         self.lora_a = mx.random.uniform(
             low=-scale,
             high=scale,
-            shape=(input_dims, rank),
+            shape=(input_dims, r),
         )
-        self.lora_b = mx.zeros(shape=(rank, output_dims))
+        self.lora_b = mx.zeros(shape=(r, output_dims))
 
     def __call__(self, x):
         dtype = self.linear.weight.dtype
         if isinstance(self.linear, nn.QuantizedLinear):
             dtype = self.linear.scales.dtype
         y = self.linear(x.astype(dtype))
-        z = (x @ self.lora_a) @ self.lora_b
+        z = (self.lora_dropout(x) @ self.lora_a) @ self.lora_b
         return y + self.scale * z
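
As a quick check on the defaults shown above (illustrative, not part of the diff): with `scale=10.0`, `lora_alpha=16`, and `r=8`, the effective multiplier on the low-rank update equals the previous default `scale=20.0`, so the default scaling is preserved under the new parameterization.

# Default effective scale under the new parameterization (values from the diff).
scale, lora_alpha, r = 10.0, 16, 8
effective_scale = scale * (lora_alpha / r)
assert effective_scale == 20.0  # matches the old default scale of 20.0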

@@ -278,6 +278,7 @@ def load(
     model = load_model(model_path)
     if adapter_file is not None:
         model = apply_lora_layers(model, adapter_file)
+        model.eval()
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_config)
     return model, tokenizer