diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 83a629f6..ffdb02d7 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -282,13 +282,11 @@ if __name__ == "__main__": generate_progress_images(i + 1, flux, args) if (i + 1) % args.checkpoint_every == 0: - file_name = f"{i + 1:07d}_adapters.safetensors" - save_adapters(file_name, flux, args) + save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args) if (i + 1) % 10 == 0: losses = [] tic = time.time() - final_adapter = "final_adapter.safetensors" - save_adapters(final_adapter, flux, args) + save_adapters("final_adapters.safetensors", flux, args) print(f"Training successful. Saved final weights to {args.adapter_file}.") diff --git a/flux/flux/utils.py b/flux/flux/utils.py index ffd99176..2437f21f 100644 --- a/flux/flux/utils.py +++ b/flux/flux/utils.py @@ -222,12 +222,9 @@ def save_config( config (dict): The model configuration. config_path (Union[str, Path]): Model configuration file path. """ - # Clean unused keys - config.pop("_name_or_path", None) - - # sort the config for better readability + # Sort the config for better readability config = dict(sorted(config.items())) - # write the updated config to the config_path (if provided) + # Write the config to the provided file with open(config_path, "w") as fid: json.dump(config, fid, indent=4)