This commit is contained in:
apolinário 2025-05-05 09:43:58 -04:00 committed by GitHub
commit 33c6ff9d8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,6 +4,7 @@ import argparse
import time import time
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
import os
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@ -13,8 +14,106 @@ from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image from PIL import Image
from huggingface_hub import HfApi, interpreter_login
from huggingface_hub.utils import HfFolder
from flux import FluxPipeline, Trainer, load_dataset, save_config from flux import FluxPipeline, Trainer, load_dataset, save_config
class FinetuningDataset:
def __init__(self, flux, args):
self.args = args
self.flux = flux
self.dataset_base = Path(args.dataset)
dataset_index = self.dataset_base / "index.json"
if not dataset_index.exists():
raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
with open(dataset_index, "r") as f:
self.index = json.load(f)
self.latents = []
self.t5_features = []
self.clip_features = []
def _random_crop_resize(self, img):
resolution = self.args.resolution
width, height = img.size
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
# Random crop the input image between 0.8 to 1.0 of its original dimensions
crop_size = (
max((0.8 + 0.2 * a) * width, resolution[0]),
max((0.8 + 0.2 * a) * height, resolution[1]),
)
pan = (width - crop_size[0], height - crop_size[1])
img = img.crop(
(
pan[0] * b,
pan[1] * c,
crop_size[0] + pan[0] * b,
crop_size[1] + pan[1] * c,
)
)
# Fit the largest rectangle with the ratio of resolution in the image
# rectangle.
width, height = crop_size
ratio = resolution[0] / resolution[1]
r1 = (height * ratio, height)
r2 = (width, width / ratio)
r = r1 if r1[0] <= width else r2
img = img.crop(
(
(width - r[0]) / 2,
(height - r[1]) / 2,
(width + r[0]) / 2,
(height + r[1]) / 2,
)
)
# Finally resize the image to resolution
img = img.resize(resolution, Image.LANCZOS)
return mx.array(np.array(img))
def encode_images(self):
"""Encode the images in the latent space to prepare for training."""
self.flux.ae.eval()
for sample in tqdm(self.index["data"]):
input_img = Image.open(self.dataset_base / sample["image"])
for i in range(self.args.num_augmentations):
img = self._random_crop_resize(input_img)
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
x_0 = self.flux.ae.encode(img[None])
x_0 = x_0.astype(self.flux.dtype)
mx.eval(x_0)
self.latents.append(x_0)
def encode_prompts(self):
"""Pre-encode the prompts so that we don't recompute them during
training (doesn't allow finetuning the text encoders)."""
for sample in tqdm(self.index["data"]):
t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
t5_feat = self.flux.t5(t5_tok)
clip_feat = self.flux.clip(clip_tok).pooled_output
mx.eval(t5_feat, clip_feat)
self.t5_features.append(t5_feat)
self.clip_features.append(clip_feat)
def iterate(self, batch_size):
xs = mx.concatenate(self.latents)
t5 = mx.concatenate(self.t5_features)
clip = mx.concatenate(self.clip_features)
mx.eval(xs, t5, clip)
n_aug = self.args.num_augmentations
while True:
x_indices = mx.random.permutation(len(self.latents))
c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i : i + batch_size]
c_i = c_indices[i : i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i]
def generate_progress_images(iteration, flux, args): def generate_progress_images(iteration, flux, args):
"""Generate images to monitor the progress of the finetuning.""" """Generate images to monitor the progress of the finetuning."""
@ -58,6 +157,108 @@ def save_adapters(adapter_name, flux, args):
}, },
) )
def push_to_hub(args):
if args.hf_token is None:
interpreter_login(new_session=False, write_permission=True)
else:
HfFolder.save_token(args.hf_token)
repo_id = args.hf_repo_id or f"{HfFolder.get_token_username()}/{args.output_dir}"
readme_content = generate_readme(args, repo_id)
readme_path = os.path.join(args.output_dir, "README.md")
with open(readme_path, "w", encoding="utf-8") as f:
f.write(readme_content)
api = HfApi()
api.create_repo(
repo_id,
private=args.hf_private,
exist_ok=True
)
api.upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
ignore_patterns=["*.yaml", "*.pt"],
repo_type="model",
)
def generate_readme(args, repo_id):
import yaml
import re
base_model = f"flux-{args.model}"
tags = [
"text-to-image",
"flux",
"lora",
"diffusers",
"template:sd-lora",
"mlx",
"mlx-trainer"
]
widgets = []
sample_image_paths = []
# Look for progress images directly in the output directory
for filename in os.listdir(args.output_dir):
match = re.search(r"(\d+)_progress\.png$", filename)
if match:
iteration = int(match.group(1))
sample_image_paths.append((iteration, filename))
sample_image_paths.sort(key=lambda x: x[0], reverse=True)
if sample_image_paths:
widgets.append(
{
"text": args.progress_prompt,
"output": {
"url": sample_image_paths[0][1]
},
}
)
readme_content = f"""---
tags:
{yaml.dump(tags, indent=4).strip()}
{"widget:" if sample_image_paths else ""}
{yaml.dump(widgets, indent=4).strip() if widgets else ""}
base_model: {base_model}
license: other
---
# {os.path.basename(args.output_dir)}
Model trained with the MLX Flux Dreambooth script
<Gallery />
## Use it with [MLX](https://github.com/ml-explore/mlx-examples)
```py
from flux import FluxPipeline
import mlx.core as mx
flux = FluxPipeline("flux-{args.model}")
flux.linear_to_lora_layers({args.lora_rank}, {args.lora_blocks})
flux.flow.load_weights("{repo_id}")
image = flux.generate_images("{args.progress_prompt}", n_images=1, num_steps={args.progress_steps})
image.save("my_image.png")
```
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained('black-forest-labs/{args.model}', torch_dtype=torch.bfloat16).to('cuda')
pipeline.load_lora_weights('{repo_id}')
image = pipeline({args.progress_prompt}').images[0]
image.save("my_image.png")
```
For more details on using Flux, check the [Flux documentation](https://github.com/black-forest-labs/flux).
"""
return readme_content
def setup_arg_parser(): def setup_arg_parser():
"""Set up and return the argument parser.""" """Set up and return the argument parser."""
@ -148,7 +349,28 @@ def setup_arg_parser():
parser.add_argument( parser.add_argument(
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in" "--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
) )
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Push the model to Hugging Face Hub after training",
)
parser.add_argument(
"--hf_token",
type=str,
default=None,
help="Hugging Face token for pushing to Hub",
)
parser.add_argument(
"--hf_repo_id",
type=str,
default=None,
help="Hugging Face repository ID for pushing to Hub",
)
parser.add_argument(
"--hf_private",
action="store_true",
help="Make the Hugging Face repository private",
)
parser.add_argument("dataset") parser.add_argument("dataset")
return parser return parser
@ -287,6 +509,9 @@ if __name__ == "__main__":
if (i + 1) % 10 == 0: if (i + 1) % 10 == 0:
losses = [] losses = []
tic = time.time() tic = time.time()
if args.push_to_hub:
push_to_hub(args)
save_adapters("final_adapters.safetensors", flux, args) save_adapters("final_adapters.safetensors", flux, args)
print("Training successful.") print("Training successful.")