From 2dd903b0bf18f2affd16e310d70bc3934b34c223 Mon Sep 17 00:00:00 2001 From: multimodalart Date: Mon, 14 Oct 2024 21:59:56 +0700 Subject: [PATCH] Add push to hub --- flux/dreambooth.py | 132 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 131 insertions(+), 1 deletion(-) diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 4a4dbb08..2cf06582 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -5,6 +5,7 @@ import json import time from functools import partial from pathlib import Path +import os import mlx.core as mx import mlx.nn as nn @@ -15,6 +16,9 @@ from mlx.utils import tree_flatten, tree_map, tree_reduce from PIL import Image from tqdm import tqdm +from huggingface_hub import HfApi, interpreter_login +from huggingface_hub.utils import HfFolder + from flux import FluxPipeline @@ -156,6 +160,108 @@ def save_adapters(iteration, flux, args): }, ) +def push_to_hub(args): + if args.hf_token is None: + interpreter_login(new_session=False, write_permission=True) + else: + HfFolder.save_token(args.hf_token) + + repo_id = args.hf_repo_id or f"{HfFolder.get_token_username()}/{args.output_dir}" + + readme_content = generate_readme(args, repo_id) + readme_path = os.path.join(args.output_dir, "README.md") + with open(readme_path, "w", encoding="utf-8") as f: + f.write(readme_content) + + api = HfApi() + + api.create_repo( + repo_id, + private=args.hf_private, + exist_ok=True + ) + + api.upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + ignore_patterns=["*.yaml", "*.pt"], + repo_type="model", + ) + +def generate_readme(args, repo_id): + import yaml + import re + base_model = f"flux-{args.model}" + tags = [ + "text-to-image", + "flux", + "lora", + "diffusers", + "template:sd-lora", + "mlx", + "mlx-trainer" + ] + + widgets = [] + sample_image_paths = [] + # Look for progress images directly in the output directory + for filename in os.listdir(args.output_dir): + match = re.search(r"(\d+)_progress\.png$", filename) + if match: + iteration = int(match.group(1)) + sample_image_paths.append((iteration, filename)) + + sample_image_paths.sort(key=lambda x: x[0], reverse=True) + + if sample_image_paths: + widgets.append( + { + "text": args.progress_prompt, + "output": { + "url": sample_image_paths[0][1] + }, + } + ) + + readme_content = f"""--- +tags: +{yaml.dump(tags, indent=4).strip()} +{"widget:" if sample_image_paths else ""} +{yaml.dump(widgets, indent=4).strip() if widgets else ""} +base_model: {base_model} +license: other +--- + +# {os.path.basename(args.output_dir)} +Model trained with the MLX Flux Dreambooth script + + + +## Use it with [MLX](https://github.com/ml-explore/mlx-examples) +```py +from flux import FluxPipeline +import mlx.core as mx +flux = FluxPipeline("flux-{args.model}") +flux.linear_to_lora_layers({args.lora_rank}, {args.lora_blocks}) +flux.flow.load_weights("{repo_id}") +image = flux.generate_images("{args.progress_prompt}", n_images=1, num_steps={args.progress_steps}) +image.save("my_image.png") +``` + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) +```py +from diffusers import AutoPipelineForText2Image +import torch + +pipeline = AutoPipelineForText2Image.from_pretrained('black-forest-labs/{args.model}', torch_dtype=torch.bfloat16).to('cuda') +pipeline.load_lora_weights('{repo_id}') +image = pipeline({args.progress_prompt}').images[0] +image.save("my_image.png") +``` + +For more details on using Flux, check the [Flux documentation](https://github.com/black-forest-labs/flux). +""" + return readme_content if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -245,7 +351,28 @@ if __name__ == "__main__": parser.add_argument( "--output-dir", default="mlx_output", help="Folder to save the checkpoints in" ) - + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Push the model to Hugging Face Hub after training", + ) + parser.add_argument( + "--hf_token", + type=str, + default=None, + help="Hugging Face token for pushing to Hub", + ) + parser.add_argument( + "--hf_repo_id", + type=str, + default=None, + help="Hugging Face repository ID for pushing to Hub", + ) + parser.add_argument( + "--hf_private", + action="store_true", + help="Make the Hugging Face repository private", + ) parser.add_argument("dataset") args = parser.parse_args() @@ -376,3 +503,6 @@ if __name__ == "__main__": if (i + 1) % 10 == 0: losses = [] tic = time.time() + + if args.push_to_hub: + push_to_hub(args)