mirror of https://github.com/ml-explore/mlx-examples.git
Adding full finetuning (#903)
* Adding full model weights fine-tuning
* Updating the LORA.md and ACKNOWLEDGMENTS.md files
* Removing --use-dora and --fulll-training and adding --fine-tune-type
* Some clean up
* Reformatting and fixing DoRA training
* Updated CONFIG_DEFAULTS
* Update config example
* Update in the config example file
* Update LORA.md
* Merge and commit
* Adding argument for DoRA linear layer
* Clean up
* Clean up in the example yaml file
* Fix
* Final fix before sending
* Small addition to the md file
* Fix for loading the fully trained model by saving all the files and configs correctly
* Clean up
* Removing the unnecessary files
* Changing lora layers back to 16
* Removed max file size
* Nits
* Resolve merge
* Some consistency changes

---------

Co-authored-by: Awni Hannun <awni@apple.com>
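For orientation, the --fine-tune-type option mentioned above folds the separate DoRA/full-training switches into a single selector. The snippet below is only a hedged sketch of that argparse pattern; the choice values and the default are assumptions, and the real definition lives in the training script's argument parser and CONFIG_DEFAULTS.

import argparse

# Hypothetical sketch: one selector instead of separate boolean flags.
# The available values ("lora", "dora", "full") and the default are assumed,
# not taken verbatim from this commit.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--fine-tune-type",
    type=str,
    choices=["lora", "dora", "full"],
    default="lora",
    help="Type of fine-tuning to perform: lora, dora, or full.",
)

args = parser.parse_args(["--fine-tune-type", "full"])
print(args.fine_tune_type)  # -> full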
@@ -8,7 +8,7 @@ from mlx.utils import tree_flatten, tree_unflatten
 from .gguf import convert_to_gguf
 from .tuner.dora import DoRAEmbedding, DoRALinear
 from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
-from .tuner.utils import apply_lora_layers, dequantize
+from .tuner.utils import dequantize, load_adapters
 from .utils import (
     fetch_from_hub,
     get_model_path,
@@ -29,7 +29,7 @@ def parse_arguments() -> argparse.Namespace:
     )
     parser.add_argument(
         "--save-path",
-        default="lora_fused_model",
+        default="fused_model",
         help="The path to save the fused model.",
     )
     parser.add_argument(
@@ -77,17 +77,14 @@ def main() -> None:
     model, config, tokenizer = fetch_from_hub(model_path)

     model.freeze()
-    model = apply_lora_layers(model, args.adapter_path)
+    model = load_adapters(model, args.adapter_path)

     fused_linears = [
-        (n, m.fuse())
-        for n, m in model.named_modules()
-        if isinstance(
-            m, (LoRASwitchLinear, LoRALinear, LoRAEmbedding, DoRALinear, DoRAEmbedding)
-        )
+        (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
     ]

-    model.update_modules(tree_unflatten(fused_linears))
+    if fused_linears:
+        model.update_modules(tree_unflatten(fused_linears))

     if args.de_quantize:
         print("De-quantizing model")
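The rewritten comprehension above fuses any module that exposes a fuse() method instead of matching a fixed list of LoRA/DoRA classes, and update_modules is now skipped when nothing was fused (as with a full fine-tune, which leaves no adapter modules behind). Below is a minimal, self-contained sketch of that duck-typing pattern using toy stand-in classes; the real script operates on an mlx.nn model via model.named_modules(), tree_unflatten, and update_modules.

# Toy classes for illustration only; these are not the mlx_lm adapter modules.
class PlainLinear:
    """A base layer with nothing to fuse."""


class AdapterLinear:
    """An adapter layer that can fold its weights back into a base layer."""

    def fuse(self):
        return PlainLinear()


modules = [
    ("layers.0.self_attn.q_proj", AdapterLinear()),
    ("layers.0.mlp.gate_proj", PlainLinear()),
]

# Duck-typed check: fuse anything that knows how to fuse itself, whatever its
# concrete class (LoRA, DoRA, switch variants, ...).
fused = [(name, m.fuse()) for name, m in modules if hasattr(m, "fuse")]

# After a full fine-tune there are no adapter modules, so the list is empty
# and the model is left as-is.
if fused:
    print(f"Fused {len(fused)} module(s)")
else:
    print("Nothing to fuse")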