mirror of https://github.com/ml-explore/mlx-examples.git
Add lr schedule
This commit is contained in:
parent e7751e4c29
commit f2ccad52f4
@@ -153,7 +153,10 @@ if __name__ == "__main__":
         "--lora-rank", type=int, default=32, help="LoRA rank for finetuning"
     )
     parser.add_argument(
-        "--learning-rate", type=float, default="1e-6", help="Learning rate for training"
+        "--warmup-steps", type=int, default=100, help="Learning rate warmup"
+    )
+    parser.add_argument(
+        "--learning-rate", type=float, default="1e-4", help="Learning rate for training"
     )
     parser.add_argument(
         "--grad-accumulate",
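Note on this hunk: the default learning rate is raised from 1e-6 to 1e-4, which is safe because the warmup added below avoids starting at the peak rate. One subtlety: argparse applies `type` to string defaults, so `default="1e-4"` still produces a float. A minimal, self-contained sketch of just these two arguments (the bare `ArgumentParser` here is illustrative, not the script's full parser):

# Sketch of the two arguments this hunk touches; defaults mirror the diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--warmup-steps", type=int, default=100, help="Learning rate warmup"
)
parser.add_argument(
    "--learning-rate", type=float, default="1e-4", help="Learning rate for training"
)

args = parser.parse_args([])  # parse no CLI args, take the defaults
# argparse runs string defaults through `type`, so "1e-4" becomes float 1e-4.
print(args.warmup_steps, args.learning_rate)  # 100 0.0001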
@@ -182,7 +185,12 @@ if __name__ == "__main__":
     )
     print(f"Training {trainable_params / 1024**2:.3f}M parameters")

-    optimizer = optim.Adam(learning_rate=args.learning_rate)
+    warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
+    cosine = optim.cosine_decay(
+        args.learning_rate, args.iterations // args.grad_accumulate
+    )
+    lr_schedule = optim.join_schedules([warmup, cosine], [args.warmup_steps])
+    optimizer = optim.Adam(learning_rate=lr_schedule)
     state = [flux.flow.state, optimizer.state, mx.random.state]

     @partial(mx.compile, inputs=state, outputs=state)
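For reference, a minimal sketch of how the joined schedule behaves when evaluated on its own, using the same `mlx.optimizers` helpers the diff uses. The values mirror the new defaults; `total_steps = 1000` is a hypothetical stand-in for `args.iterations // args.grad_accumulate`:

# Sketch: evaluate the warmup + cosine schedule from this commit standalone.
import mlx.core as mx
import mlx.optimizers as optim

learning_rate = 1e-4  # matches the new --learning-rate default
warmup_steps = 100    # matches the new --warmup-steps default
total_steps = 1000    # hypothetical; the script uses iterations // grad_accumulate

# Linear ramp from 0 to the peak lr over the first warmup_steps optimizer
# updates, then cosine decay from the peak over the remaining updates.
warmup = optim.linear_schedule(0, learning_rate, warmup_steps)
cosine = optim.cosine_decay(learning_rate, total_steps)
lr_schedule = optim.join_schedules([warmup, cosine], [warmup_steps])

for step in (0, 50, 100, 500, 1000):
    print(f"step {step:4d}: lr = {lr_schedule(mx.array(step)).item():.2e}")

# Passing the schedule callable to the optimizer, as the diff does, makes
# the learning rate advance automatically with each optimizer.update call:
optimizer = optim.Adam(learning_rate=lr_schedule)

`join_schedules` switches from warmup to cosine at the `warmup_steps` boundary and offsets the step it passes to the second schedule, so the cosine decay picks up at the peak learning rate exactly where the warmup ends.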