diff --git a/flux/README.md b/flux/README.md
index 1a17e386..b00a9621 100644
--- a/flux/README.md
+++ b/flux/README.md
@@ -188,7 +188,7 @@ The adapters are saved in `mlx_output` and can be used directly by the
 ```shell
 python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
-    --adapter mlx_output/0001200_adapters.safetensors \
+    --adapter mlx_output/final_adapters.safetensors \
     --fuse-adapter \
     --no-t5-padding \
     'A photo of an sks dog lying on the sand at a beach in Greece'
 ```
diff --git a/flux/dreambooth.py b/flux/dreambooth.py
index 48dcad47..ffdb02d7 100644
--- a/flux/dreambooth.py
+++ b/flux/dreambooth.py
@@ -13,7 +13,7 @@ from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten, tree_map, tree_reduce
 from PIL import Image
 
-from flux import FluxPipeline, Trainer, load_dataset
+from flux import FluxPipeline, Trainer, load_dataset, save_config
 
 
 def generate_progress_images(iteration, flux, args):
@@ -43,10 +43,10 @@ def generate_progress_images(iteration, flux, args):
         im.save(out_file)
 
 
-def save_adapters(iteration, flux, args):
+def save_adapters(adapter_name, flux, args):
     out_dir = Path(args.output_dir)
     out_dir.mkdir(parents=True, exist_ok=True)
-    out_file = out_dir / f"{iteration:07d}_adapters.safetensors"
+    out_file = out_dir / adapter_name
     print(f"Saving {str(out_file)}")
 
     mx.save_safetensors(
@@ -157,6 +157,10 @@ if __name__ == "__main__":
     parser = setup_arg_parser()
     args = parser.parse_args()
 
+    output_path = Path(args.output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+    save_config(vars(args), output_path / "adapter_config.json")
+
     # Load the model and set it up for LoRA training. We use the same random
     # state when creating the LoRA layers so all workers will have the same
     # initial weights.
@@ -278,8 +282,11 @@ if __name__ == "__main__":
             generate_progress_images(i + 1, flux, args)
 
         if (i + 1) % args.checkpoint_every == 0:
-            save_adapters(i + 1, flux, args)
+            save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
 
         if (i + 1) % 10 == 0:
             losses = []
             tic = time.time()
+
+    save_adapters("final_adapters.safetensors", flux, args)
+    print(f"Training successful. Saved final weights to {output_path / 'final_adapters.safetensors'}.")
diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py
index b1122d75..3dd423b7 100644
--- a/flux/flux/__init__.py
+++ b/flux/flux/__init__.py
@@ -12,4 +12,5 @@ from .utils import (
     load_flow_model,
     load_t5,
     load_t5_tokenizer,
+    save_config,
 )
diff --git a/flux/flux/utils.py b/flux/flux/utils.py
index 21db17d3..2437f21f 100644
--- a/flux/flux/utils.py
+++ b/flux/flux/utils.py
@@ -3,7 +3,8 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Optional
+from pathlib import Path
+from typing import Optional, Union
 
 import mlx.core as mx
 from huggingface_hub import hf_hub_download
@@ -207,3 +208,23 @@ def load_clip_tokenizer(name: str):
 def load_t5_tokenizer(name: str, pad: bool = True):
     model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
     return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
+
+
+def save_config(
+    config: dict,
+    config_path: Union[str, Path],
+) -> None:
+    """Save the model configuration to the ``config_path``.
+
+    The final configuration will be sorted before saving for better readability.
+
+    Args:
+        config (dict): The model configuration.
+        config_path (Union[str, Path]): Model configuration file path.
+ """ + # Sort the config for better readability + config = dict(sorted(config.items())) + + # Write the config to the provided file + with open(config_path, "w") as fid: + json.dump(config, fid, indent=4)