diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 4a4dbb08..42250a3f 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -15,7 +15,7 @@ from mlx.utils import tree_flatten, tree_map, tree_reduce from PIL import Image from tqdm import tqdm -from flux import FluxPipeline +from flux import FluxPipeline, save_config class FinetuningDataset: @@ -250,6 +250,10 @@ if __name__ == "__main__": args = parser.parse_args() + output_path = Path(args.output_dir) + output_path.mkdir(parents=True, exist_ok=True) + save_config(vars(args), output_path / "adapter_config.json") + # Load the model and set it up for LoRA training. We use the same random # state when creating the LoRA layers so all workers will have the same # initial weights. diff --git a/flux/flux/utils.py b/flux/flux/utils.py index 21db17d3..43239f44 100644 --- a/flux/flux/utils.py +++ b/flux/flux/utils.py @@ -3,7 +3,8 @@ import json import os from dataclasses import dataclass -from typing import Optional +from pathlib import Path +from typing import Optional, Union import mlx.core as mx from huggingface_hub import hf_hub_download @@ -207,3 +208,26 @@ def load_clip_tokenizer(name: str): def load_t5_tokenizer(name: str, pad: bool = True): model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model") return T5Tokenizer(model_file, 256 if "schnell" in name else 512) + + +def save_config( + config: dict, + config_path: Union[str, Path], +) -> None: + """Save the model configuration to the ``config_path``. + + The final configuration will be sorted before saving for better readability. + + Args: + config (dict): The model configuration. + config_path (Union[str, Path]): Model configuration file path. + """ + # Clean unused keys + config.pop("_name_or_path", None) + + # sort the config for better readability + config = dict(sorted(config.items())) + + # write the updated config to the config_path (if provided) + with open(config_path, "w") as fid: + json.dump(config, fid, indent=4) \ No newline at end of file