FLUX: save train config to json

This commit is contained in:
madroid 2024-10-15 13:35:12 +08:00
parent bbd2003047
commit 130cdae48e
2 changed files with 30 additions and 2 deletions

View File

@ -15,7 +15,7 @@ from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from flux import FluxPipeline from flux import FluxPipeline, save_config
class FinetuningDataset: class FinetuningDataset:
@ -250,6 +250,10 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
output_path = Path(args.output_dir)
output_path.mkdir(parents=True, exist_ok=True)
save_config(vars(args), output_path / "adapter_config.json")
# Load the model and set it up for LoRA training. We use the same random # Load the model and set it up for LoRA training. We use the same random
# state when creating the LoRA layers so all workers will have the same # state when creating the LoRA layers so all workers will have the same
# initial weights. # initial weights.

View File

@ -3,7 +3,8 @@
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from pathlib import Path
from typing import Optional, Union
import mlx.core as mx import mlx.core as mx
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
@ -207,3 +208,26 @@ def load_clip_tokenizer(name: str):
def load_t5_tokenizer(name: str, pad: bool = True): def load_t5_tokenizer(name: str, pad: bool = True):
model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model") model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
return T5Tokenizer(model_file, 256 if "schnell" in name else 512) return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
def save_config(
config: dict,
config_path: Union[str, Path],
) -> None:
"""Save the model configuration to the ``config_path``.
The final configuration will be sorted before saving for better readability.
Args:
config (dict): The model configuration.
config_path (Union[str, Path]): Model configuration file path.
"""
# Clean unused keys
config.pop("_name_or_path", None)
# sort the config for better readability
config = dict(sorted(config.items()))
# write the updated config to the config_path (if provided)
with open(config_path, "w") as fid:
json.dump(config, fid, indent=4)