diff --git a/flux/dreambooth.py b/flux/dreambooth.py
index f82178b9..8739fda1 100644
--- a/flux/dreambooth.py
+++ b/flux/dreambooth.py
@@ -4,6 +4,8 @@
 import argparse
+import json
+import os
 import time
 from functools import partial
 from pathlib import Path
 import mlx.core as mx
 import mlx.nn as nn
@@ -13,8 +15,111 @@
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten, tree_map, tree_reduce
 from PIL import Image
+from tqdm import tqdm
 
+from huggingface_hub import HfApi, interpreter_login
+from huggingface_hub.utils import HfFolder
+
 from flux import FluxPipeline, Trainer, load_dataset, save_config
 
+
+class FinetuningDataset:
+    def __init__(self, flux, args):
+        self.args = args
+        self.flux = flux
+        self.dataset_base = Path(args.dataset)
+        dataset_index = self.dataset_base / "index.json"
+        if not dataset_index.exists():
+            raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
+        with open(dataset_index, "r") as f:
+            self.index = json.load(f)
+
+        self.latents = []
+        self.t5_features = []
+        self.clip_features = []
+
+    def _random_crop_resize(self, img):
+        resolution = self.args.resolution
+        width, height = img.size
+
+        a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
+
+        # Random crop the input image between 0.8 to 1.0 of its original dimensions
+        crop_size = (
+            max((0.8 + 0.2 * a) * width, resolution[0]),
+            max((0.8 + 0.2 * a) * height, resolution[1]),
+        )
+        pan = (width - crop_size[0], height - crop_size[1])
+        img = img.crop(
+            (
+                pan[0] * b,
+                pan[1] * c,
+                crop_size[0] + pan[0] * b,
+                crop_size[1] + pan[1] * c,
+            )
+        )
+
+        # Fit the largest rectangle with the ratio of resolution in the image
+        # rectangle.
+        width, height = crop_size
+        ratio = resolution[0] / resolution[1]
+        r1 = (height * ratio, height)
+        r2 = (width, width / ratio)
+        r = r1 if r1[0] <= width else r2
+        img = img.crop(
+            (
+                (width - r[0]) / 2,
+                (height - r[1]) / 2,
+                (width + r[0]) / 2,
+                (height + r[1]) / 2,
+            )
+        )
+
+        # Finally resize the image to resolution
+        img = img.resize(resolution, Image.LANCZOS)
+
+        return mx.array(np.array(img))
+
+    def encode_images(self):
+        """Encode the images in the latent space to prepare for training."""
+        self.flux.ae.eval()
+        for sample in tqdm(self.index["data"]):
+            input_img = Image.open(self.dataset_base / sample["image"])
+            for i in range(self.args.num_augmentations):
+                img = self._random_crop_resize(input_img)
+                img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
+                x_0 = self.flux.ae.encode(img[None])
+                x_0 = x_0.astype(self.flux.dtype)
+                mx.eval(x_0)
+                self.latents.append(x_0)
+
+    def encode_prompts(self):
+        """Pre-encode the prompts so that we don't recompute them during
+        training (doesn't allow finetuning the text encoders)."""
+        for sample in tqdm(self.index["data"]):
+            t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
+            t5_feat = self.flux.t5(t5_tok)
+            clip_feat = self.flux.clip(clip_tok).pooled_output
+            mx.eval(t5_feat, clip_feat)
+            self.t5_features.append(t5_feat)
+            self.clip_features.append(clip_feat)
+
+    def iterate(self, batch_size):
+        xs = mx.concatenate(self.latents)
+        t5 = mx.concatenate(self.t5_features)
+        clip = mx.concatenate(self.clip_features)
+        mx.eval(xs, t5, clip)
+        n_aug = self.args.num_augmentations
+        while True:
+            x_indices = mx.random.permutation(len(self.latents))
+            # Each image contributes n_aug consecutive latents, so integer
+            # division maps a latent back to its prompt features.
+            c_indices = x_indices // n_aug
+            for i in range(0, len(self.latents), batch_size):
+                x_i = x_indices[i : i + batch_size]
+                c_i = c_indices[i : i + batch_size]
+                yield xs[x_i], t5[c_i], clip[c_i]
+
+
 def generate_progress_images(iteration, flux, args):
"""Generate images to monitor the progress of the finetuning.""" @@ -58,6 +157,108 @@ def save_adapters(adapter_name, flux, args): }, ) +def push_to_hub(args): + if args.hf_token is None: + interpreter_login(new_session=False, write_permission=True) + else: + HfFolder.save_token(args.hf_token) + + repo_id = args.hf_repo_id or f"{HfFolder.get_token_username()}/{args.output_dir}" + + readme_content = generate_readme(args, repo_id) + readme_path = os.path.join(args.output_dir, "README.md") + with open(readme_path, "w", encoding="utf-8") as f: + f.write(readme_content) + + api = HfApi() + + api.create_repo( + repo_id, + private=args.hf_private, + exist_ok=True + ) + + api.upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + ignore_patterns=["*.yaml", "*.pt"], + repo_type="model", + ) + +def generate_readme(args, repo_id): + import yaml + import re + base_model = f"flux-{args.model}" + tags = [ + "text-to-image", + "flux", + "lora", + "diffusers", + "template:sd-lora", + "mlx", + "mlx-trainer" + ] + + widgets = [] + sample_image_paths = [] + # Look for progress images directly in the output directory + for filename in os.listdir(args.output_dir): + match = re.search(r"(\d+)_progress\.png$", filename) + if match: + iteration = int(match.group(1)) + sample_image_paths.append((iteration, filename)) + + sample_image_paths.sort(key=lambda x: x[0], reverse=True) + + if sample_image_paths: + widgets.append( + { + "text": args.progress_prompt, + "output": { + "url": sample_image_paths[0][1] + }, + } + ) + + readme_content = f"""--- +tags: +{yaml.dump(tags, indent=4).strip()} +{"widget:" if sample_image_paths else ""} +{yaml.dump(widgets, indent=4).strip() if widgets else ""} +base_model: {base_model} +license: other +--- + +# {os.path.basename(args.output_dir)} +Model trained with the MLX Flux Dreambooth script + + + +## Use it with [MLX](https://github.com/ml-explore/mlx-examples) +```py +from flux import FluxPipeline +import mlx.core as mx +flux = FluxPipeline("flux-{args.model}") +flux.linear_to_lora_layers({args.lora_rank}, {args.lora_blocks}) +flux.flow.load_weights("{repo_id}") +image = flux.generate_images("{args.progress_prompt}", n_images=1, num_steps={args.progress_steps}) +image.save("my_image.png") +``` + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) +```py +from diffusers import AutoPipelineForText2Image +import torch + +pipeline = AutoPipelineForText2Image.from_pretrained('black-forest-labs/{args.model}', torch_dtype=torch.bfloat16).to('cuda') +pipeline.load_lora_weights('{repo_id}') +image = pipeline({args.progress_prompt}').images[0] +image.save("my_image.png") +``` + +For more details on using Flux, check the [Flux documentation](https://github.com/black-forest-labs/flux). 
+""" + return readme_content def setup_arg_parser(): """Set up and return the argument parser.""" @@ -148,7 +349,28 @@ def setup_arg_parser(): parser.add_argument( "--output-dir", default="mlx_output", help="Folder to save the checkpoints in" ) - + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Push the model to Hugging Face Hub after training", + ) + parser.add_argument( + "--hf_token", + type=str, + default=None, + help="Hugging Face token for pushing to Hub", + ) + parser.add_argument( + "--hf_repo_id", + type=str, + default=None, + help="Hugging Face repository ID for pushing to Hub", + ) + parser.add_argument( + "--hf_private", + action="store_true", + help="Make the Hugging Face repository private", + ) parser.add_argument("dataset") return parser @@ -287,6 +509,9 @@ if __name__ == "__main__": if (i + 1) % 10 == 0: losses = [] tic = time.time() + + if args.push_to_hub: + push_to_hub(args) save_adapters("final_adapters.safetensors", flux, args) print("Training successful.")