Allow the entire model to be targed for LoRA and DoRA fine tuning: LoRA and DoRA embeddings with small DoRALinear bug fix (#914)

* feature: LoRA adapter for Embeddings

* feature: wire in LoRAEmbedding into the tuner. Allow the embedding and non model.layers Linear layers to be targeted for fine tuning

* feature: DoRA adapter for Embeddings

* feature: wire in DoRAEmbedding

* bugfix: ensure self.m is recalculated when the linear layer is changed in DoRALinear.from_linear

* refactor: prefer from_base over from_linear or from_embedding. prefer fuse over to_linear or to_embedding

* cleanup: remove unused imports in test_dora.py

* refactor: remove unnecessary non_layer_modules

* cleanup: remove wrong comments for lora embedding dropout. remove uncessary parens in dora embedding dropout

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Zai Thottakath
2024-08-16 09:38:36 -05:00
committed by GitHub
parent c50971e860
commit 4e01700816
5 changed files with 306 additions and 21 deletions

View File

@@ -10,8 +10,8 @@ import mlx.optimizers as opt
from mlx.utils import tree_flatten, tree_unflatten
from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
from .dora import DoRALinear
from .lora import LoRALinear, LoRASwitchLinear
from .dora import DoRAEmbedding, DoRALinear
from .lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
def build_schedule(schedule_config: Dict):
@@ -71,12 +71,14 @@ def linear_to_lora_layers(
if use_dora:
raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.")
LoRALayer = LoRASwitchLinear
elif isinstance(layer, (nn.Embedding, nn.QuantizedEmbedding)):
LoRALayer = DoRAEmbedding if use_dora else LoRAEmbedding
else:
raise ValueError(
f"Can't convert layer of type {type(layer).__name__} to LoRA"
)
return LoRALayer.from_linear(
return LoRALayer.from_base(
layer,
r=config["rank"],
scale=config["scale"],
@@ -130,7 +132,12 @@ def linear_to_lora_layers(
for l in model.layers[num_layers - num_lora_layers :]:
lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
l.update_modules(tree_unflatten(lora_layers))
if lora_layers:
l.update_modules(tree_unflatten(lora_layers))
lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys]
if lora_modules:
model.update_modules(tree_unflatten(lora_modules))
def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module: