Mirror of https://github.com/ml-explore/mlx-examples.git
Add argument --save-every N to lora.py for saving model regularly (#310)
commit d8680a89f9 (parent b4c20cc7f7)

Changed file: lora/lora.py (13 additions)
@@ -96,6 +96,12 @@ def build_parser():
         default="adapters.npz",
         help="Save/load path for the trained adapter weights.",
     )
+    parser.add_argument(
+        "--save-every",
+        type=int,
+        default=100,
+        help="Save the model every N iterations.",
+    )
     parser.add_argument(
         "--test",
         action="store_true",
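For context, the new flag behaves like any other argparse option; the following is a minimal standalone sketch (not the full build_parser from lora.py) showing how --save-every parses into args.save_every:

import argparse

# Standalone sketch mirroring the argument added in the hunk above.
parser = argparse.ArgumentParser(description="LoRA fine-tuning (sketch)")
parser.add_argument(
    "--save-every",
    type=int,
    default=100,
    help="Save the model every N iterations.",
)

args = parser.parse_args(["--save-every", "200"])
print(args.save_every)  # 200 -- argparse exposes --save-every as args.save_every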
@@ -262,6 +268,13 @@ def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
 
             start = time.perf_counter()
 
+        # Save adapter weights if needed
+        if (it + 1) % args.save_every == 0:
+            mx.savez(
+                args.adapter_file, **dict(tree_flatten(model.trainable_parameters()))
+            )
+            print(f"Iter {it + 1}: Saved adapter weights to {args.adapter_file}.")
+
 
 def generate(model, prompt, tokenizer, args):
     print(prompt, end="", flush=True)
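The block added above checkpoints the flattened trainable parameters with mx.savez every args.save_every iterations. As a rough illustration of how such a checkpoint can be written and read back, here is a minimal round-trip sketch that uses a plain nn.Linear as a stand-in for the LoRA-wrapped model; restoring weights via mx.load, tree_unflatten, and Module.update is an assumption about the surrounding workflow, not part of this commit:

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten

# Stand-in module; in lora.py this would be the LoRA-wrapped language model.
model = nn.Linear(4, 4)

# Save flattened trainable parameters, as the new block in train() does.
mx.savez("adapters.npz", **dict(tree_flatten(model.trainable_parameters())))

# Restore into a freshly constructed module of the same shape.
restored = nn.Linear(4, 4)
weights = mx.load("adapters.npz")
restored.update(tree_unflatten(list(weights.items())))
mx.eval(restored.parameters())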