Save lora config (#636)

* lora config

* comments

* version bump
Author: Awni Hannun
Date: 2024-04-02 13:52:53 -07:00
Committed by: GitHub
Parent: d661440dbb
Commit: 2bd64b78cf
10 changed files with 73 additions and 90 deletions


@@ -1,4 +1,7 @@
-import os
+# Copyright © 2024 Apple Inc.
+
+import json
+import types
 from pathlib import Path
 from typing import Dict
 import mlx.core as mx
@@ -91,40 +94,28 @@ def linear_to_lora_layers(
         raise ValueError(f"Lora does not support {model.model_type}")

     for l in model.layers[num_layers - num_lora_layers :]:
-        modules = l.named_modules()
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
         l.update_modules(tree_unflatten(lora_layers))


-def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module:
+def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
     """
     Apply LoRA layers to the model.

     Args:
         model (nn.Module): The neural network model.
-        adapter_file (str): Path to the adapter configuration file.
+        adapter_path (str): Path to the adapter configuration file.

     Returns:
         nn.Module: The updated model with LoRA layers applied.
     """
-    if not os.path.exists(adapter_file):
-        raise FileNotFoundError(f"The adapter file does not exist: {adapter_file}")
-
-    adapters = list(mx.load(adapter_file).items())
-
-    linear_replacements = []
-    lora_layers = set(
-        [name.replace(".lora_a", "").replace(".lora_b", "") for name, _ in adapters]
-    )
-    for name, module in model.named_modules():
-        if name in lora_layers:
-            replacement_module = LoRALinear.from_linear(module)
-            linear_replacements.append((name, replacement_module))
-
-    model.update_modules(tree_unflatten(linear_replacements))
-    model.update(tree_unflatten(adapters))
-
+    adapter_path = Path(adapter_path)
+    if not adapter_path.exists():
+        raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
+
+    with open(adapter_path / "adapter_config.json", "r") as fid:
+        config = types.SimpleNamespace(**json.load(fid))
+
+    linear_to_lora_layers(model, config.lora_layers, config.lora_parameters)
+    model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
+
     return model
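
For context, this commit changes apply_lora_layers to take an adapter directory (holding adapter_config.json and adapters.safetensors) rather than a single weights file. A minimal usage sketch, assuming the mlx_lm module layout at the time of this commit; the model id and the adapters/ directory are hypothetical:

    from mlx_lm import load
    from mlx_lm.tuner.utils import apply_lora_layers

    # Load a base model (hypothetical Hub id).
    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")

    # "adapters/" is a hypothetical directory written out by LoRA training,
    # expected to contain adapter_config.json (lora_layers, lora_parameters)
    # and adapters.safetensors (the trained low-rank weights).
    model = apply_lora_layers(model, "adapters/")

Since load_weights is called with strict=False, the safetensors file only needs to supply the LoRA parameters rather than the model's full parameter set; the base weights are left untouched.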