FLUX: save train config (#1049)

This commit is contained in:
madroid 2024-11-09 09:15:19 +08:00 committed by GitHub
parent 657b4cc0aa
commit 1e07660184
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 35 additions and 6 deletions

View File

@ -188,7 +188,7 @@ The adapters are saved in `mlx_output` and can be used directly by the
```shell ```shell
python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \ python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
--adapter mlx_output/0001200_adapters.safetensors \ --adapter mlx_output/final_adapters.safetensors \
--fuse-adapter \ --fuse-adapter \
--no-t5-padding \ --no-t5-padding \
'A photo of an sks dog lying on the sand at a beach in Greece' 'A photo of an sks dog lying on the sand at a beach in Greece'

View File

@ -13,7 +13,7 @@ from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image from PIL import Image
from flux import FluxPipeline, Trainer, load_dataset from flux import FluxPipeline, Trainer, load_dataset, save_config
def generate_progress_images(iteration, flux, args): def generate_progress_images(iteration, flux, args):
@ -43,10 +43,10 @@ def generate_progress_images(iteration, flux, args):
im.save(out_file) im.save(out_file)
def save_adapters(iteration, flux, args): def save_adapters(adapter_name, flux, args):
out_dir = Path(args.output_dir) out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True) out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / f"{iteration:07d}_adapters.safetensors" out_file = out_dir / adapter_name
print(f"Saving {str(out_file)}") print(f"Saving {str(out_file)}")
mx.save_safetensors( mx.save_safetensors(
@ -157,6 +157,10 @@ if __name__ == "__main__":
parser = setup_arg_parser() parser = setup_arg_parser()
args = parser.parse_args() args = parser.parse_args()
output_path = Path(args.output_dir)
output_path.mkdir(parents=True, exist_ok=True)
save_config(vars(args), output_path / "adapter_config.json")
# Load the model and set it up for LoRA training. We use the same random # Load the model and set it up for LoRA training. We use the same random
# state when creating the LoRA layers so all workers will have the same # state when creating the LoRA layers so all workers will have the same
# initial weights. # initial weights.
@ -278,8 +282,11 @@ if __name__ == "__main__":
generate_progress_images(i + 1, flux, args) generate_progress_images(i + 1, flux, args)
if (i + 1) % args.checkpoint_every == 0: if (i + 1) % args.checkpoint_every == 0:
save_adapters(i + 1, flux, args) save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
if (i + 1) % 10 == 0: if (i + 1) % 10 == 0:
losses = [] losses = []
tic = time.time() tic = time.time()
save_adapters("final_adapters.safetensors", flux, args)
print(f"Training successful. Saved final weights to {args.adapter_file}.")

View File

@ -12,4 +12,5 @@ from .utils import (
load_flow_model, load_flow_model,
load_t5, load_t5,
load_t5_tokenizer, load_t5_tokenizer,
save_config,
) )

View File

@ -3,7 +3,8 @@
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from pathlib import Path
from typing import Optional, Union
import mlx.core as mx import mlx.core as mx
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
@ -207,3 +208,23 @@ def load_clip_tokenizer(name: str):
def load_t5_tokenizer(name: str, pad: bool = True): def load_t5_tokenizer(name: str, pad: bool = True):
model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model") model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
return T5Tokenizer(model_file, 256 if "schnell" in name else 512) return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
def save_config(
    config: dict,
    config_path: Union[str, Path],
) -> None:
    """Write the model configuration to ``config_path`` as JSON.

    Keys are emitted in sorted order so the saved file is deterministic
    and easy to scan by eye.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Model configuration file path.
    """
    # Build a sorted copy first so the on-disk ordering is stable.
    ordered = {key: config[key] for key in sorted(config)}
    with open(config_path, "w") as fid:
        json.dump(ordered, fid, indent=4)