FLUX: save train config (#1049)

This commit is contained in:
madroid
2024-11-09 09:15:19 +08:00
committed by GitHub
parent 657b4cc0aa
commit 1e07660184
4 changed files with 35 additions and 6 deletions

View File

@@ -12,4 +12,5 @@ from .utils import (
load_flow_model,
load_t5,
load_t5_tokenizer,
save_config,
)

View File

@@ -3,7 +3,8 @@
import json
import os
from dataclasses import dataclass
from typing import Optional
from pathlib import Path
from typing import Optional, Union
import mlx.core as mx
from huggingface_hub import hf_hub_download
@@ -207,3 +208,23 @@ def load_clip_tokenizer(name: str):
def load_t5_tokenizer(name: str, pad: bool = True):
model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
def save_config(
config: dict,
config_path: Union[str, Path],
) -> None:
"""Save the model configuration to the ``config_path``.
The final configuration will be sorted before saving for better readability.
Args:
config (dict): The model configuration.
config_path (Union[str, Path]): Model configuration file path.
"""
# Sort the config for better readability
config = dict(sorted(config.items()))
# Write the config to the provided file
with open(config_path, "w") as fid:
json.dump(config, fid, indent=4)