FLUX: save train config (#1049)

This commit is contained in:
madroid 2024-11-09 09:15:19 +08:00 committed by GitHub
parent 657b4cc0aa
commit 1e07660184
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 35 additions and 6 deletions

View File

@ -188,7 +188,7 @@ The adapters are saved in `mlx_output` and can be used directly by the
```shell
python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
--adapter mlx_output/0001200_adapters.safetensors \
--adapter mlx_output/final_adapters.safetensors \
--fuse-adapter \
--no-t5-padding \
'A photo of an sks dog lying on the sand at a beach in Greece'

View File

@ -13,7 +13,7 @@ from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image
from flux import FluxPipeline, Trainer, load_dataset
from flux import FluxPipeline, Trainer, load_dataset, save_config
def generate_progress_images(iteration, flux, args):
@ -43,10 +43,10 @@ def generate_progress_images(iteration, flux, args):
im.save(out_file)
def save_adapters(iteration, flux, args):
def save_adapters(adapter_name, flux, args):
out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / f"{iteration:07d}_adapters.safetensors"
out_file = out_dir / adapter_name
print(f"Saving {str(out_file)}")
mx.save_safetensors(
@ -157,6 +157,10 @@ if __name__ == "__main__":
parser = setup_arg_parser()
args = parser.parse_args()
output_path = Path(args.output_dir)
output_path.mkdir(parents=True, exist_ok=True)
save_config(vars(args), output_path / "adapter_config.json")
# Load the model and set it up for LoRA training. We use the same random
# state when creating the LoRA layers so all workers will have the same
# initial weights.
@ -278,8 +282,11 @@ if __name__ == "__main__":
generate_progress_images(i + 1, flux, args)
if (i + 1) % args.checkpoint_every == 0:
save_adapters(i + 1, flux, args)
save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
if (i + 1) % 10 == 0:
losses = []
tic = time.time()
save_adapters("final_adapters.safetensors", flux, args)
print(f"Training successful. Saved final weights to {args.adapter_file}.")

View File

@ -12,4 +12,5 @@ from .utils import (
load_flow_model,
load_t5,
load_t5_tokenizer,
save_config,
)

View File

@ -3,7 +3,8 @@
import json
import os
from dataclasses import dataclass
from typing import Optional
from pathlib import Path
from typing import Optional, Union
import mlx.core as mx
from huggingface_hub import hf_hub_download
@ -207,3 +208,23 @@ def load_clip_tokenizer(name: str):
def load_t5_tokenizer(name: str, pad: bool = True):
model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
def save_config(
config: dict,
config_path: Union[str, Path],
) -> None:
"""Save the model configuration to the ``config_path``.
The final configuration will be sorted before saving for better readability.
Args:
config (dict): The model configuration.
config_path (Union[str, Path]): Model configuration file path.
"""
# Sort the config for better readability
config = dict(sorted(config.items()))
# Write the config to the provided file
with open(config_path, "w") as fid:
json.dump(config, fid, indent=4)