Add argument --save-every N to lora.py for saving model regularly (#310)

Zheng Qu 2024-01-17 05:03:33 +01:00 committed by GitHub
parent b4c20cc7f7
commit d8680a89f9


@@ -96,6 +96,12 @@ def build_parser():
         default="adapters.npz",
         help="Save/load path for the trained adapter weights.",
     )
+    parser.add_argument(
+        "--save-every",
+        type=int,
+        default=100,
+        help="Save the model every N iterations.",
+    )
     parser.add_argument(
         "--test",
         action="store_true",
@@ -262,6 +268,13 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
             start = time.perf_counter()
 
+        # Save adapter weights if needed
+        if (it + 1) % args.save_every == 0:
+            mx.savez(
+                args.adapter_file, **dict(tree_flatten(model.trainable_parameters()))
+            )
+            print(f"Iter {it + 1}: Saved adapter weights to {args.adapter_file}.")
+
 
 def generate(model, prompt, tokenizer, args):
     print(prompt, end="", flush=True)
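
With this change, --save-every controls how often adapter weights are checkpointed to the --adapter-file path during training (every 100 iterations by default). A minimal invocation sketch; only --save-every and --adapter-file appear in this diff, while the --model and --train flags are assumed from the rest of lora.py's interface:

    python lora.py --model <model-path> --train --save-every 200 --adapter-file adapters.npz

Because the same --adapter-file path is reused for each periodic save, every checkpoint overwrites the previous one rather than producing per-iteration files.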