Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 04:14:38 +08:00
Allow the entire model to be targeted for LoRA and DoRA fine-tuning: LoRA and DoRA embeddings with small DoRALinear bug fix (#914)
* feature: LoRA adapter for Embeddings
* feature: wire LoRAEmbedding into the tuner; allow the embedding and non-model.layers Linear layers to be targeted for fine-tuning
* feature: DoRA adapter for Embeddings
* feature: wire in DoRAEmbedding
* bugfix: ensure self.m is recalculated when the linear layer is changed in DoRALinear.from_linear
* refactor: prefer from_base over from_linear or from_embedding; prefer fuse over to_linear or to_embedding
* cleanup: remove unused imports in test_dora.py
* refactor: remove unnecessary non_layer_modules
* cleanup: remove incorrect comments for LoRA embedding dropout; remove unnecessary parens in DoRA embedding dropout
* nits

Co-authored-by: Awni Hannun <awni@apple.com>
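The DoRALinear bugfix listed above concerns the magnitude vector self.m that DoRA keeps alongside its low-rank update: when the wrapped linear layer is replaced, that vector has to be recomputed from the new weights or it silently goes stale. A minimal sketch of that bookkeeping, not the mlx-examples implementation (the helper name and shapes are illustrative assumptions):

import mlx.core as mx
import mlx.nn as nn


def weight_magnitude(linear: nn.Linear) -> mx.array:
    # DoRA-style per-output magnitude: the L2 norm of each row of the
    # weight matrix, one scalar per output feature.
    return mx.linalg.norm(linear.weight, axis=1)


base = nn.Linear(16, 32)
m = weight_magnitude(base)   # computed when the adapter is first built

# If the wrapped layer is swapped out later (for example, replaced by a
# dequantized copy), the cached magnitudes no longer match the weights and
# must be recomputed; this is the invariant the bugfix restores.
base = nn.Linear(16, 32)
m = weight_magnitude(base)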
@@ -10,8 +10,8 @@ import mlx.optimizers as opt
 from mlx.utils import tree_flatten, tree_unflatten

 from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
-from .dora import DoRALinear
-from .lora import LoRALinear, LoRASwitchLinear
+from .dora import DoRAEmbedding, DoRALinear
+from .lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear


 def build_schedule(schedule_config: Dict):
@@ -71,12 +71,14 @@ def linear_to_lora_layers(
             if use_dora:
                 raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.")
             LoRALayer = LoRASwitchLinear
+        elif isinstance(layer, (nn.Embedding, nn.QuantizedEmbedding)):
+            LoRALayer = DoRAEmbedding if use_dora else LoRAEmbedding
         else:
             raise ValueError(
                 f"Can't convert layer of type {type(layer).__name__} to LoRA"
             )

-        return LoRALayer.from_linear(
+        return LoRALayer.from_base(
             layer,
             r=config["rank"],
             scale=config["scale"],
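The new nn.Embedding / nn.QuantizedEmbedding branch above dispatches to the embedding adapters introduced by this commit, and the from_base rename lets the same classmethod wrap either a linear or an embedding base layer. As a rough idea of what such an adapter computes (a hedged sketch only; the class name, initialization, and fuse helper below are illustrative, not the actual LoRAEmbedding in mlx-examples): the base embedding lookup is augmented with a low-rank per-token update, and fusing folds that update back into a plain embedding table.

import math

import mlx.core as mx
import mlx.nn as nn


class SketchLoRAEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, dims: int, r: int = 8, scale: float = 20.0):
        super().__init__()
        # Base table; a tuner would typically freeze these weights.
        self.embedding = nn.Embedding(num_embeddings, dims)
        self.scale = scale
        bound = 1 / math.sqrt(num_embeddings)
        self.lora_a = mx.random.uniform(-bound, bound, (num_embeddings, r))
        self.lora_b = mx.zeros((r, dims))  # zero init: the adapter starts as a no-op

    def __call__(self, x):
        y = self.embedding(x)              # ordinary lookup
        z = self.lora_a[x] @ self.lora_b   # low-rank correction for the same tokens
        return y + self.scale * z

    def fuse(self) -> nn.Embedding:
        # Fold the low-rank update into a plain nn.Embedding; this mirrors
        # the "prefer fuse" naming mentioned in the commit message.
        num_embeddings, dims = self.embedding.weight.shape
        fused = nn.Embedding(num_embeddings, dims)
        fused.weight = self.embedding.weight + self.scale * (self.lora_a @ self.lora_b)
        return fused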
@@ -130,7 +132,12 @@ def linear_to_lora_layers(

     for l in model.layers[num_layers - num_lora_layers :]:
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
-        l.update_modules(tree_unflatten(lora_layers))
+        if lora_layers:
+            l.update_modules(tree_unflatten(lora_layers))
+
+    lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys]
+    if lora_modules:
+        model.update_modules(tree_unflatten(lora_modules))


 def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
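The guarded calls above use the same pattern for per-layer modules and for modules that live directly on the model (such as an embedding outside model.layers): collect (name, replacement) pairs from named_modules, then only call update_modules when something actually matched. A small self-contained sketch of that mechanism (the toy block and key names are made up for illustration):

import mlx.nn as nn
from mlx.utils import tree_unflatten


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = nn.Module()
        self.self_attn.q_proj = nn.Linear(8, 8)
        self.mlp = nn.Module()
        self.mlp.up_proj = nn.Linear(8, 16)


keys = {"self_attn.q_proj"}
block = ToyBlock()

# Gather replacements for the matching submodules (here a fresh Linear stands
# in for the LoRA/DoRA wrapper that to_lora would return).
replacements = [(k, nn.Linear(8, 8)) for k, m in block.named_modules() if k in keys]

if replacements:  # the guard added in the hunk above: skip blocks with no match
    block.update_modules(tree_unflatten(replacements))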