From ebf314bdcc1506e74381b5aa831c911b9d8c0af4 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 8 Nov 2024 13:01:51 -0800 Subject: [PATCH] Nits --- flux/dreambooth.py | 6 ++---- flux/flux/utils.py | 7 ++----- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 83a629f6..ffdb02d7 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -282,13 +282,11 @@ if __name__ == "__main__": generate_progress_images(i + 1, flux, args) if (i + 1) % args.checkpoint_every == 0: - file_name = f"{i + 1:07d}_adapters.safetensors" - save_adapters(file_name, flux, args) + save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args) if (i + 1) % 10 == 0: losses = [] tic = time.time() - final_adapter = "final_adapter.safetensors" - save_adapters(final_adapter, flux, args) + save_adapters("final_adapters.safetensors", flux, args) print(f"Training successful. Saved final weights to {args.adapter_file}.") diff --git a/flux/flux/utils.py b/flux/flux/utils.py index ffd99176..2437f21f 100644 --- a/flux/flux/utils.py +++ b/flux/flux/utils.py @@ -222,12 +222,9 @@ def save_config( config (dict): The model configuration. config_path (Union[str, Path]): Model configuration file path. """ - # Clean unused keys - config.pop("_name_or_path", None) - - # sort the config for better readability + # Sort the config for better readability config = dict(sorted(config.items())) - # write the updated config to the config_path (if provided) + # Write the config to the provided file with open(config_path, "w") as fid: json.dump(config, fid, indent=4)