From d8680a89f986492dbc27c36af3294034db26458f Mon Sep 17 00:00:00 2001
From: Zheng Qu
Date: Wed, 17 Jan 2024 05:03:33 +0100
Subject: [PATCH] Add argument `--save-every N` to lora.py for saving model
 regularly (#310)

---
 lora/lora.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/lora/lora.py b/lora/lora.py
index fba22dd8..b522dfdb 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -96,6 +96,12 @@ def build_parser():
         default="adapters.npz",
         help="Save/load path for the trained adapter weights.",
     )
+    parser.add_argument(
+        "--save-every",
+        type=int,
+        default=100,
+        help="Save the model every N iterations.",
+    )
     parser.add_argument(
         "--test",
         action="store_true",
@@ -262,6 +268,13 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
 
             start = time.perf_counter()
 
+        # Save adapter weights if needed
+        if (it + 1) % args.save_every == 0:
+            mx.savez(
+                args.adapter_file, **dict(tree_flatten(model.trainable_parameters()))
+            )
+            print(f"Iter {it + 1}: Saved adapter weights to {args.adapter_file}.")
+
 
 def generate(model, prompt, tokenizer, args):
     print(prompt, end="", flush=True)
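
For context, a minimal sketch of how the new flag might be exercised from the
command line. The model path is a placeholder, and the surrounding flags are
assumed to match the script's existing parser rather than anything shown in
this patch:

    # Hypothetical invocation: write adapter weights every 200 iterations
    # instead of the default 100 (paths and flags other than --save-every
    # are illustrative).
    python lora.py --model <path-to-model> --train \
        --save-every 200 --adapter-file adapters.npz

With the default of 100, the adapter weights are written to the
`--adapter-file` path every 100 training iterations, so an interrupted run
loses at most N iterations of progress instead of the whole session.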