diff --git a/lora/lora.py b/lora/lora.py
index fba22dd8..b522dfdb 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -96,6 +96,12 @@ def build_parser():
         default="adapters.npz",
         help="Save/load path for the trained adapter weights.",
     )
+    parser.add_argument(
+        "--save-every",
+        type=int,
+        default=100,
+        help="Save the model every N iterations.",
+    )
     parser.add_argument(
         "--test",
         action="store_true",
@@ -262,6 +268,13 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
             start = time.perf_counter()
 
+        # Save adapter weights if needed
+        if (it + 1) % args.save_every == 0:
+            mx.savez(
+                args.adapter_file, **dict(tree_flatten(model.trainable_parameters()))
+            )
+            print(f"Iter {it + 1}: Saved adapter weights to {args.adapter_file}.")
+
 
 def generate(model, prompt, tokenizer, args):
     print(prompt, end="", flush=True)