Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 12:49:50 +08:00
Add lr schedule
@@ -153,7 +153,10 @@ if __name__ == "__main__":
         "--lora-rank", type=int, default=32, help="LoRA rank for finetuning"
     )
     parser.add_argument(
-        "--learning-rate", type=float, default="1e-6", help="Learning rate for training"
+        "--warmup-steps", type=int, default=100, help="Learning rate warmup"
+    )
+    parser.add_argument(
+        "--learning-rate", type=float, default="1e-4", help="Learning rate for training"
     )
     parser.add_argument(
         "--grad-accumulate",
@@ -182,7 +185,12 @@ if __name__ == "__main__":
     )
     print(f"Training {trainable_params / 1024**2:.3f}M parameters")
 
-    optimizer = optim.Adam(learning_rate=args.learning_rate)
+    warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
+    cosine = optim.cosine_decay(
+        args.learning_rate, args.iterations // args.grad_accumulate
+    )
+    lr_schedule = optim.join_schedules([warmup, cosine], [args.warmup_steps])
+    optimizer = optim.Adam(learning_rate=lr_schedule)
     state = [flux.flow.state, optimizer.state, mx.random.state]
 
     @partial(mx.compile, inputs=state, outputs=state)
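
For reference, below is a minimal standalone sketch (not part of the commit) of how the joined warmup + cosine schedule behaves, with illustrative values standing in for the script's command-line arguments (peak learning rate 1e-4, 100 warmup steps, 1000 total optimizer steps):

# Minimal sketch of the schedule composition added above; the hyperparameter
# values here are illustrative, not taken from the commit.
import mlx.core as mx
import mlx.optimizers as optim

learning_rate = 1e-4
warmup_steps = 100
total_steps = 1000  # stands in for args.iterations // args.grad_accumulate

# Ramp linearly from 0 up to the peak learning rate over the warmup steps.
warmup = optim.linear_schedule(0, learning_rate, warmup_steps)
# Then decay with a cosine curve over the remaining steps.
cosine = optim.cosine_decay(learning_rate, total_steps)
# Use the warmup schedule first, switching to cosine decay at step `warmup_steps`.
lr_schedule = optim.join_schedules([warmup, cosine], [warmup_steps])

optimizer = optim.Adam(learning_rate=lr_schedule)

# A schedule is just a callable from step index to learning rate.
for step in (0, 50, 100, 500, 1000):
    print(step, lr_schedule(mx.array(step)).item())

Because the learning rate passed to Adam is a callable, the optimizer evaluates it against its internal step count on each update, so the warmup and decay happen automatically during training.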