Add lr schedule

Angelos Katharopoulos 2024-10-04 18:09:01 -07:00
parent e7751e4c29
commit f2ccad52f4


@@ -153,7 +153,10 @@ if __name__ == "__main__":
         "--lora-rank", type=int, default=32, help="LoRA rank for finetuning"
     )
     parser.add_argument(
-        "--learning-rate", type=float, default="1e-6", help="Learning rate for training"
+        "--warmup-steps", type=int, default=100, help="Learning rate warmup"
     )
     parser.add_argument(
+        "--learning-rate", type=float, default="1e-4", help="Learning rate for training"
+    )
+    parser.add_argument(
         "--grad-accumulate",
@@ -182,7 +185,12 @@ if __name__ == "__main__":
     )
     print(f"Training {trainable_params / 1024**2:.3f}M parameters")
-    optimizer = optim.Adam(learning_rate=args.learning_rate)
+    warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
+    cosine = optim.cosine_decay(
+        args.learning_rate, args.iterations // args.grad_accumulate
+    )
+    lr_schedule = optim.join_schedules([warmup, cosine], [args.warmup_steps])
+    optimizer = optim.Adam(learning_rate=lr_schedule)
     state = [flux.flow.state, optimizer.state, mx.random.state]
     @partial(mx.compile, inputs=state, outputs=state)
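
For reference, a minimal standalone sketch of the warmup-plus-cosine schedule this commit wires up. The concrete values below (peak learning rate 1e-4, 100 warmup steps, 1000 decay steps) are illustrative stand-ins for the parsed arguments (args.learning_rate, args.warmup_steps, args.iterations // args.grad_accumulate), not values taken from the commit:

import mlx.core as mx
import mlx.optimizers as optim

learning_rate = 1e-4  # stands in for args.learning_rate
warmup_steps = 100    # stands in for args.warmup_steps
decay_steps = 1000    # stands in for args.iterations // args.grad_accumulate

# Linear ramp from 0 up to the peak learning rate over the warmup steps.
warmup = optim.linear_schedule(0.0, learning_rate, warmup_steps)
# Cosine decay from the peak learning rate over the training steps.
cosine = optim.cosine_decay(learning_rate, decay_steps)
# Use the warmup schedule until `warmup_steps`, then hand off to the cosine decay.
lr_schedule = optim.join_schedules([warmup, cosine], [warmup_steps])

# Passing the schedule as `learning_rate` makes Adam query it with its own
# step counter on every update.
optimizer = optim.Adam(learning_rate=lr_schedule)

# Sample the schedule at a few steps to see the ramp and the decay.
for step in (0, 50, 100, 550, 1100):
    print(step, lr_schedule(mx.array(step)).item())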