diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 4a4dbb08..59fc38a5 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -157,7 +157,8 @@ def save_adapters(iteration, flux, args): ) -if __name__ == "__main__": +def setup_arg_parser(): + """Set up and return the argument parser.""" parser = argparse.ArgumentParser( description="Finetune Flux to generate images with a specific subject" ) @@ -247,7 +248,11 @@ if __name__ == "__main__": ) parser.add_argument("dataset") + return parser + +if __name__ == "__main__": + parser = setup_arg_parser() args = parser.parse_args() # Load the model and set it up for LoRA training. We use the same random