mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
FLUX: save train config to json
This commit is contained in:
parent
bbd2003047
commit
130cdae48e
@ -15,7 +15,7 @@ from mlx.utils import tree_flatten, tree_map, tree_reduce
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from flux import FluxPipeline
|
from flux import FluxPipeline, save_config
|
||||||
|
|
||||||
|
|
||||||
class FinetuningDataset:
|
class FinetuningDataset:
|
||||||
@ -250,6 +250,10 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
output_path = Path(args.output_dir)
|
||||||
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
save_config(vars(args), output_path / "adapter_config.json")
|
||||||
|
|
||||||
# Load the model and set it up for LoRA training. We use the same random
|
# Load the model and set it up for LoRA training. We use the same random
|
||||||
# state when creating the LoRA layers so all workers will have the same
|
# state when creating the LoRA layers so all workers will have the same
|
||||||
# initial weights.
|
# initial weights.
|
||||||
|
@ -3,7 +3,8 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
@ -207,3 +208,26 @@ def load_clip_tokenizer(name: str):
|
|||||||
def load_t5_tokenizer(name: str, pad: bool = True):
|
def load_t5_tokenizer(name: str, pad: bool = True):
|
||||||
model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
|
model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
|
||||||
return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
|
return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
|
||||||
|
|
||||||
|
|
||||||
|
def save_config(
|
||||||
|
config: dict,
|
||||||
|
config_path: Union[str, Path],
|
||||||
|
) -> None:
|
||||||
|
"""Save the model configuration to the ``config_path``.
|
||||||
|
|
||||||
|
The final configuration will be sorted before saving for better readability.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (dict): The model configuration.
|
||||||
|
config_path (Union[str, Path]): Model configuration file path.
|
||||||
|
"""
|
||||||
|
# Clean unused keys
|
||||||
|
config.pop("_name_or_path", None)
|
||||||
|
|
||||||
|
# sort the config for better readability
|
||||||
|
config = dict(sorted(config.items()))
|
||||||
|
|
||||||
|
# write the updated config to the config_path (if provided)
|
||||||
|
with open(config_path, "w") as fid:
|
||||||
|
json.dump(config, fid, indent=4)
|
Loading…
Reference in New Issue
Block a user