Allow the entire model to be targed for LoRA and DoRA fine tuning: LoRA and DoRA embeddings with small DoRALinear bug fix (#914)

* feature: LoRA adapter for Embeddings

* feature: wire in LoRAEmbedding into the tuner. Allow the embedding and non model.layers Linear layers to be targeted for fine tuning

* feature: DoRA adapter for Embeddings

* feature: wire in DoRAEmbedding

* bugfix: ensure self.m is recalculated when the linear layer is changed in DoRALinear.from_linear

* refactor: prefer from_base over from_linear or from_embedding. prefer fuse over to_linear or to_embedding

* cleanup: remove unused imports in test_dora.py

* refactor: remove unnecessary non_layer_modules

* cleanup: remove wrong comments for lora embedding dropout. remove uncessary parens in dora embedding dropout

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Zai Thottakath
2024-08-16 09:38:36 -05:00
committed by GitHub
parent c50971e860
commit 4e01700816
5 changed files with 306 additions and 21 deletions

View File

@@ -6,8 +6,8 @@ from pathlib import Path
from mlx.utils import tree_flatten, tree_unflatten
from .gguf import convert_to_gguf
from .tuner.dora import DoRALinear
from .tuner.lora import LoRALinear, LoRASwitchLinear
from .tuner.dora import DoRAEmbedding, DoRALinear
from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
from .tuner.utils import apply_lora_layers, dequantize
from .utils import (
fetch_from_hub,
@@ -80,9 +80,11 @@ def main() -> None:
model = apply_lora_layers(model, args.adapter_path)
fused_linears = [
(n, m.to_linear())
(n, m.fuse())
for n, m in model.named_modules()
if isinstance(m, (LoRASwitchLinear, LoRALinear, DoRALinear))
if isinstance(
m, (LoRASwitchLinear, LoRALinear, LoRAEmbedding, DoRALinear, DoRAEmbedding)
)
]
model.update_modules(tree_unflatten(fused_linears))

View File

@@ -8,7 +8,7 @@ import mlx.nn as nn
class DoRALinear(nn.Module):
@staticmethod
def from_linear(
def from_base(
linear: nn.Linear,
r: int = 8,
dropout: float = 0.0,
@@ -25,10 +25,10 @@ class DoRALinear(nn.Module):
dropout=dropout,
scale=scale,
)
dora_lin.linear = linear
dora_lin.set_linear(linear)
return dora_lin
def to_linear(self, de_quantize: bool = False):
def fuse(self, de_quantize: bool = False):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
@@ -61,7 +61,7 @@ class DoRALinear(nn.Module):
super().__init__()
# Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
self.set_linear(nn.Linear(input_dims, output_dims, bias=bias))
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
@@ -75,6 +75,9 @@ class DoRALinear(nn.Module):
shape=(input_dims, r),
)
self.lora_b = mx.zeros(shape=(r, output_dims))
def set_linear(self, linear: nn.Linear):
self.linear = linear
self.m = mx.linalg.norm(self.linear.weight, axis=1)
def __call__(self, x):
@@ -93,3 +96,102 @@ class DoRALinear(nn.Module):
if "bias" in self.linear:
out = out + self.linear.bias
return out
class DoRAEmbedding(nn.Module):
def from_base(
embedding: nn.Embedding,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
):
num_embeddings, dims = embedding.weight.shape
# TODO support quantized weights in DoRALinear
if isinstance(embedding, nn.QuantizedLinear):
raise ValueError("DoRAEmbedding does not yet support quantization.")
dora_embedding = DoRAEmbedding(
num_embeddings=num_embeddings,
dims=dims,
r=r,
dropout=dropout,
scale=scale,
)
dora_embedding.set_embedding(embedding)
return dora_embedding
def fuse(self, de_quantize: bool = False):
embedding = self.embedding
weight = embedding.weight
# Use the same type as the linear weight if not quantized
dtype = weight.dtype
num_embeddings, dims = weight.shape
fused_embedding = nn.Embedding(num_embeddings, dims)
lora_a = (self.scale * self.lora_a).astype(dtype)
lora_b = self.lora_b.astype(dtype)
weight = weight + lora_a @ lora_b
norm_scale = self.m / mx.linalg.norm(weight, axis=1)
fused_embedding.weight = norm_scale[:, None] * weight
return fused_embedding
def __init__(
self,
num_embeddings: int,
dims: int,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
):
super().__init__()
# Regular embedding layer weights
self.set_embedding(nn.Embedding(num_embeddings, dims))
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(num_embeddings)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(num_embeddings, r),
)
self.lora_b = mx.zeros(shape=(r, dims))
def set_embedding(self, embedding: nn.Module):
self.embedding = embedding
self.m = mx.linalg.norm(embedding.weight, axis=1)
def __call__(self, x):
y = self.embedding(x)
z = self.scale * self.lora_a[x] @ self.lora_b
out = y + self.dropout(z).astype(y.dtype)
# Compute the norm of the adapted weights for the individual embeddings
adapted = y + z
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=-1))
# Remove the norm and scale by the learned magnitude
out = (self.m[x] / denom)[..., None] * out
return out
def as_linear(self, x):
y = self.embedding.as_linear(x)
z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T
out = y + (self.scale * z).astype(x.dtype)
# Compute the norm of the adapted weights
adapted = self.embedding.weight + (self.scale * self.lora_a) @ self.lora_b
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
# Remove the norm and scale by the learned magnitude
out = (self.m / denom) * out
return out

View File

@@ -10,7 +10,7 @@ from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
class LoRALinear(nn.Module):
@staticmethod
def from_linear(
def from_base(
linear: nn.Linear,
r: int = 8,
dropout: float = 0.0,
@@ -31,7 +31,7 @@ class LoRALinear(nn.Module):
lora_lin.linear = linear
return lora_lin
def to_linear(self, de_quantize: bool = False):
def fuse(self, de_quantize: bool = False):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
@@ -41,7 +41,7 @@ class LoRALinear(nn.Module):
dtype = weight.dtype
if is_quantized:
dtype = mx.float16
dtype = linear.scales.dtype
weight = mx.dequantize(
weight,
linear.scales,
@@ -103,7 +103,7 @@ class LoRALinear(nn.Module):
class LoRASwitchLinear(nn.Module):
@staticmethod
def from_linear(
def from_base(
linear: nn.Module,
r: int = 8,
dropout: float = 0.0,
@@ -120,7 +120,7 @@ class LoRASwitchLinear(nn.Module):
lora_lin.linear = linear
return lora_lin
def to_linear(self, de_quantize: bool = False):
def fuse(self, de_quantize: bool = False):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
@@ -191,3 +191,95 @@ class LoRASwitchLinear(nn.Module):
z = z[..., None, :] @ self.lora_b[indices].swapaxes(-2, -1)
return y + (self.scale * z).astype(x.dtype)
class LoRAEmbedding(nn.Module):
@staticmethod
def from_base(
embedding: nn.Embedding,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
):
num_embeddings, dims = embedding.weight.shape
if isinstance(embedding, nn.QuantizedEmbedding):
dims *= 32 // embedding.bits
lora_embedding = LoRAEmbedding(
num_embeddings=num_embeddings,
dims=dims,
r=r,
dropout=dropout,
scale=scale,
)
lora_embedding.embedding = embedding
return lora_embedding
def fuse(self, de_quantize: bool = False):
embedding = self.embedding
weight = embedding.weight
is_quantized = isinstance(embedding, nn.QuantizedEmbedding)
# Use the same type as the linear weight if not quantized
dtype = weight.dtype
if is_quantized:
dtype = embedding.scales.dtype
weight = mx.dequantize(
weight,
embedding.scales,
embedding.biases,
embedding.group_size,
embedding.bits,
)
num_embeddings, dims = weight.shape
fused_embedding = nn.Embedding(num_embeddings, dims)
lora_a = (self.scale * self.lora_a).astype(dtype)
lora_b = self.lora_b.astype(dtype)
fused_embedding.weight = weight + lora_a @ lora_b
if is_quantized and not de_quantize:
fused_embedding = nn.QuantizedEmbedding.from_embedding(
fused_embedding,
embedding.group_size,
embedding.bits,
)
return fused_embedding
def __init__(
self,
num_embeddings: int,
dims: int,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
):
super().__init__()
# Regular embedding layer
self.embedding = nn.Embedding(num_embeddings, dims)
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(num_embeddings)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(num_embeddings, r),
)
self.lora_b = mx.zeros(shape=(r, dims))
def __call__(self, x):
y = self.embedding(x)
z = self.dropout(self.lora_a[x] @ self.lora_b)
out = y + (self.scale * z).astype(y.dtype)
return out
def as_linear(self, x):
y = self.embedding.as_linear(x)
z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T
return y + (self.scale * z).astype(x.dtype)

View File

@@ -10,8 +10,8 @@ import mlx.optimizers as opt
from mlx.utils import tree_flatten, tree_unflatten
from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
from .dora import DoRALinear
from .lora import LoRALinear, LoRASwitchLinear
from .dora import DoRAEmbedding, DoRALinear
from .lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
def build_schedule(schedule_config: Dict):
@@ -71,12 +71,14 @@ def linear_to_lora_layers(
if use_dora:
raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.")
LoRALayer = LoRASwitchLinear
elif isinstance(layer, (nn.Embedding, nn.QuantizedEmbedding)):
LoRALayer = DoRAEmbedding if use_dora else LoRAEmbedding
else:
raise ValueError(
f"Can't convert layer of type {type(layer).__name__} to LoRA"
)
return LoRALayer.from_linear(
return LoRALayer.from_base(
layer,
r=config["rank"],
scale=config["scale"],
@@ -130,7 +132,12 @@ def linear_to_lora_layers(
for l in model.layers[num_layers - num_lora_layers :]:
lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
l.update_modules(tree_unflatten(lora_layers))
if lora_layers:
l.update_modules(tree_unflatten(lora_layers))
lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys]
if lora_modules:
model.update_modules(tree_unflatten(lora_modules))
def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module: