diff --git a/flux/dreambooth.py b/flux/dreambooth.py
index 2c83458d..e9995aea 100644
--- a/flux/dreambooth.py
+++ b/flux/dreambooth.py
@@ -153,7 +153,10 @@ if __name__ == "__main__":
         "--lora-rank", type=int, default=32, help="LoRA rank for finetuning"
     )
     parser.add_argument(
-        "--learning-rate", type=float, default="1e-6", help="Learning rate for training"
+        "--warmup-steps", type=int, default=100, help="Number of learning rate warmup steps"
+    )
+    parser.add_argument(
+        "--learning-rate", type=float, default=1e-4, help="Learning rate for training"
     )
     parser.add_argument(
         "--grad-accumulate",
@@ -182,7 +185,12 @@ if __name__ == "__main__":
     )
     print(f"Training {trainable_params / 1024**2:.3f}M parameters")
 
-    optimizer = optim.Adam(learning_rate=args.learning_rate)
+    warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
+    cosine = optim.cosine_decay(
+        args.learning_rate, args.iterations // args.grad_accumulate
+    )
+    lr_schedule = optim.join_schedules([warmup, cosine], [args.warmup_steps])
+    optimizer = optim.Adam(learning_rate=lr_schedule)
 
     state = [flux.flow.state, optimizer.state, mx.random.state]
 
     @partial(mx.compile, inputs=state, outputs=state)
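
For context, a minimal self-contained sketch of how the schedule built in the second hunk behaves, assuming MLX's mlx.optimizers module. The peak learning rate, warmup length, and total step count below are illustrative placeholders, not values taken from the script; in the diff the total comes from args.iterations // args.grad_accumulate.

# Sketch of the warmup + cosine schedule from the diff above.
# All three constants here are illustrative assumptions.
import mlx.core as mx
import mlx.optimizers as optim

peak_lr = 1e-4       # stands in for args.learning_rate
warmup_steps = 100   # stands in for args.warmup_steps
total_steps = 600    # stands in for args.iterations // args.grad_accumulate

# Linear ramp from 0 to peak_lr over the first warmup_steps updates ...
warmup = optim.linear_schedule(0.0, peak_lr, warmup_steps)
# ... then a cosine decay from peak_lr toward 0 over total_steps updates.
cosine = optim.cosine_decay(peak_lr, total_steps)
# join_schedules switches at the boundary and restarts the step count for
# the second schedule, so the cosine segment begins at its peak value.
lr_schedule = optim.join_schedules([warmup, cosine], [warmup_steps])

# A schedule is just a callable from step number to learning rate.
for step in (0, 50, 100, 350, 700):
    print(f"step {step:4d}: lr = {lr_schedule(mx.array(step)).item():.2e}")

# Passing the callable (instead of a float) to the optimizer makes MLX
# advance the schedule by one step per optimizer update.
optimizer = optim.Adam(learning_rate=lr_schedule)

Note that the boundary passed to join_schedules only controls where the handoff happens; because the step counter resets there, the cosine in the diff decays over the full args.iterations // args.grad_accumulate updates counted from the end of warmup.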