diff --git a/flux/mlx_flux/dreambooth.py b/flux/mlx_flux/dreambooth.py index 91049cb1..2ea83eb6 100644 --- a/flux/mlx_flux/dreambooth.py +++ b/flux/mlx_flux/dreambooth.py @@ -155,7 +155,7 @@ def setup_arg_parser(): return parser -if __name__ == "__main__": +def main(): parser = setup_arg_parser() args = parser.parse_args() @@ -285,3 +285,7 @@ if __name__ == "__main__": if (i + 1) % 10 == 0: losses = [] tic = time.time() + + +if __name__ == "__main__": + main()