Add push to hub

This commit is contained in:
multimodalart 2024-10-14 21:59:56 +07:00
parent 1e0cda68c6
commit 2dd903b0bf

View File

@ -5,6 +5,7 @@ import json
import time import time
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
import os
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@ -15,6 +16,9 @@ from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from huggingface_hub import HfApi, interpreter_login
from huggingface_hub.utils import HfFolder
from flux import FluxPipeline from flux import FluxPipeline
@ -156,6 +160,108 @@ def save_adapters(iteration, flux, args):
}, },
) )
def push_to_hub(args):
if args.hf_token is None:
interpreter_login(new_session=False, write_permission=True)
else:
HfFolder.save_token(args.hf_token)
repo_id = args.hf_repo_id or f"{HfFolder.get_token_username()}/{args.output_dir}"
readme_content = generate_readme(args, repo_id)
readme_path = os.path.join(args.output_dir, "README.md")
with open(readme_path, "w", encoding="utf-8") as f:
f.write(readme_content)
api = HfApi()
api.create_repo(
repo_id,
private=args.hf_private,
exist_ok=True
)
api.upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
ignore_patterns=["*.yaml", "*.pt"],
repo_type="model",
)
def generate_readme(args, repo_id):
import yaml
import re
base_model = f"flux-{args.model}"
tags = [
"text-to-image",
"flux",
"lora",
"diffusers",
"template:sd-lora",
"mlx",
"mlx-trainer"
]
widgets = []
sample_image_paths = []
# Look for progress images directly in the output directory
for filename in os.listdir(args.output_dir):
match = re.search(r"(\d+)_progress\.png$", filename)
if match:
iteration = int(match.group(1))
sample_image_paths.append((iteration, filename))
sample_image_paths.sort(key=lambda x: x[0], reverse=True)
if sample_image_paths:
widgets.append(
{
"text": args.progress_prompt,
"output": {
"url": sample_image_paths[0][1]
},
}
)
readme_content = f"""---
tags:
{yaml.dump(tags, indent=4).strip()}
{"widget:" if sample_image_paths else ""}
{yaml.dump(widgets, indent=4).strip() if widgets else ""}
base_model: {base_model}
license: other
---
# {os.path.basename(args.output_dir)}
Model trained with the MLX Flux Dreambooth script
<Gallery />
## Use it with [MLX](https://github.com/ml-explore/mlx-examples)
```py
from flux import FluxPipeline
import mlx.core as mx
flux = FluxPipeline("flux-{args.model}")
flux.linear_to_lora_layers({args.lora_rank}, {args.lora_blocks})
flux.flow.load_weights("{repo_id}")
image = flux.generate_images("{args.progress_prompt}", n_images=1, num_steps={args.progress_steps})
image.save("my_image.png")
```
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained('black-forest-labs/{args.model}', torch_dtype=torch.bfloat16).to('cuda')
pipeline.load_lora_weights('{repo_id}')
image = pipeline({args.progress_prompt}').images[0]
image.save("my_image.png")
```
For more details on using Flux, check the [Flux documentation](https://github.com/black-forest-labs/flux).
"""
return readme_content
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -245,7 +351,28 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in" "--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
) )
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Push the model to Hugging Face Hub after training",
)
parser.add_argument(
"--hf_token",
type=str,
default=None,
help="Hugging Face token for pushing to Hub",
)
parser.add_argument(
"--hf_repo_id",
type=str,
default=None,
help="Hugging Face repository ID for pushing to Hub",
)
parser.add_argument(
"--hf_private",
action="store_true",
help="Make the Hugging Face repository private",
)
parser.add_argument("dataset") parser.add_argument("dataset")
args = parser.parse_args() args = parser.parse_args()
@ -376,3 +503,6 @@ if __name__ == "__main__":
if (i + 1) % 10 == 0: if (i + 1) % 10 == 0:
losses = [] losses = []
tic = time.time() tic = time.time()
if args.push_to_hub:
push_to_hub(args)