mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Add argument --save-every N
to lora.py for saving model regularly (#310)
This commit is contained in:
parent
b4c20cc7f7
commit
d8680a89f9
13
lora/lora.py
13
lora/lora.py
@ -96,6 +96,12 @@ def build_parser():
|
||||
default="adapters.npz",
|
||||
help="Save/load path for the trained adapter weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-every",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Save the model every N iterations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test",
|
||||
action="store_true",
|
||||
@ -262,6 +268,13 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
# Save adapter weights if needed
|
||||
if (it + 1) % args.save_every == 0:
|
||||
mx.savez(
|
||||
args.adapter_file, **dict(tree_flatten(model.trainable_parameters()))
|
||||
)
|
||||
print(f"Iter {it + 1}: Saved adapter weights to {args.adapter_file}.")
|
||||
|
||||
|
||||
def generate(model, prompt, tokenizer, args):
|
||||
print(prompt, end="", flush=True)
|
||||
|
Loading…
Reference in New Issue
Block a user