Adding full finetuning (#903)

* Adding full model weights finetuning

* Updating the LORA.md and ACKNOWLEDGMENTS.md files.

* removing --use-dora and --full-training in favor of a single --fine-tune-type argument (see the usage sketch after this list)

* some clean up

* reformatting and fixing DoRA training

* updated CONFIG_DEFAULTS

* update config example

* update the config example file

* Update LORA.md

* merge and commit

* adding argument for the DoRA linear layer

* clean up

* clean up in the example yaml file

* fix

* final fix before sending

* small addition to the md file

* fix loading of the fully trained model by saving all the files and configs correctly

* clean up

* removing the unnecessary files

* changing LoRA layers back to 16

* removed max file size

* nits

* resolve merge

* some consistency changes
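
The flag consolidation is the headline change: one argument now selects the tuning mode instead of separate boolean switches. A hedged usage sketch (the model and data paths are placeholders, and the flag spelling is taken from the commit message rather than verified against the final CLI):

    python -m mlx_lm.lora --model <path_or_hf_repo> --train --data <path_to_data> --fine-tune-type full

The accepted values would presumably be lora, dora, and full, mirroring the three tuning modes this PR touches.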

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Gökdeniz Gülmez
Date: 2024-09-30 02:12:47 +02:00
Committed by: GitHub
Parent: 7ec2021bb9
Commit: 50e5ca81a8
9 changed files with 79 additions and 70 deletions

llms/mlx_lm/fuse.py

@@ -8,7 +8,7 @@ from mlx.utils import tree_flatten, tree_unflatten
 from .gguf import convert_to_gguf
 from .tuner.dora import DoRAEmbedding, DoRALinear
 from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
-from .tuner.utils import apply_lora_layers, dequantize
+from .tuner.utils import dequantize, load_adapters
 from .utils import (
     fetch_from_hub,
     get_model_path,
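
The import swap above reflects that the old helper was LoRA-specific, while the new load_adapters must handle LoRA, DoRA, and full fine-tuning alike. A minimal sketch of the shape such a helper could take (the file names, config keys, and conversion step are assumptions for illustration, not the exact mlx-lm implementation):

    import json
    from pathlib import Path

    def load_adapters(model, adapter_path: str):
        """Attach trained adapters (or full weights) from adapter_path onto model."""
        path = Path(adapter_path)
        # The saved adapter config records which fine-tune type produced the weights.
        with open(path / "adapter_config.json") as f:
            config = json.load(f)
        if config.get("fine_tune_type", "lora") != "full":
            # LoRA/DoRA: convert the targeted linear layers to their adapter
            # variants before loading the low-rank weights (details omitted).
            ...
        model.load_weights(str(path / "adapters.safetensors"), strict=False)
        return model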
@@ -29,7 +29,7 @@ def parse_arguments() -> argparse.Namespace:
     )
     parser.add_argument(
         "--save-path",
-        default="lora_fused_model",
+        default="fused_model",
         help="The path to save the fused model.",
     )
     parser.add_argument(
@@ -77,17 +77,14 @@ def main() -> None:
     model, config, tokenizer = fetch_from_hub(model_path)
     model.freeze()
-    model = apply_lora_layers(model, args.adapter_path)
+    model = load_adapters(model, args.adapter_path)
     fused_linears = [
-        (n, m.fuse())
-        for n, m in model.named_modules()
-        if isinstance(
-            m, (LoRASwitchLinear, LoRALinear, LoRAEmbedding, DoRALinear, DoRAEmbedding)
-        )
+        (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
     ]
-    model.update_modules(tree_unflatten(fused_linears))
+    if fused_linears:
+        model.update_modules(tree_unflatten(fused_linears))
     if args.de_quantize:
         print("De-quantizing model")