mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Merge dd9f26e604
into 977cd30242
This commit is contained in:
commit
33c6ff9d8f
@ -4,6 +4,7 @@ import argparse
|
|||||||
import time
|
import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -13,8 +14,106 @@ from mlx.nn.utils import average_gradients
|
|||||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from huggingface_hub import HfApi, interpreter_login
|
||||||
|
from huggingface_hub.utils import HfFolder
|
||||||
|
|
||||||
from flux import FluxPipeline, Trainer, load_dataset, save_config
|
from flux import FluxPipeline, Trainer, load_dataset, save_config
|
||||||
|
|
||||||
|
class FinetuningDataset:
|
||||||
|
def __init__(self, flux, args):
|
||||||
|
self.args = args
|
||||||
|
self.flux = flux
|
||||||
|
self.dataset_base = Path(args.dataset)
|
||||||
|
dataset_index = self.dataset_base / "index.json"
|
||||||
|
if not dataset_index.exists():
|
||||||
|
raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
|
||||||
|
with open(dataset_index, "r") as f:
|
||||||
|
self.index = json.load(f)
|
||||||
|
|
||||||
|
self.latents = []
|
||||||
|
self.t5_features = []
|
||||||
|
self.clip_features = []
|
||||||
|
|
||||||
|
def _random_crop_resize(self, img):
|
||||||
|
resolution = self.args.resolution
|
||||||
|
width, height = img.size
|
||||||
|
|
||||||
|
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
|
||||||
|
|
||||||
|
# Random crop the input image between 0.8 to 1.0 of its original dimensions
|
||||||
|
crop_size = (
|
||||||
|
max((0.8 + 0.2 * a) * width, resolution[0]),
|
||||||
|
max((0.8 + 0.2 * a) * height, resolution[1]),
|
||||||
|
)
|
||||||
|
pan = (width - crop_size[0], height - crop_size[1])
|
||||||
|
img = img.crop(
|
||||||
|
(
|
||||||
|
pan[0] * b,
|
||||||
|
pan[1] * c,
|
||||||
|
crop_size[0] + pan[0] * b,
|
||||||
|
crop_size[1] + pan[1] * c,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fit the largest rectangle with the ratio of resolution in the image
|
||||||
|
# rectangle.
|
||||||
|
width, height = crop_size
|
||||||
|
ratio = resolution[0] / resolution[1]
|
||||||
|
r1 = (height * ratio, height)
|
||||||
|
r2 = (width, width / ratio)
|
||||||
|
r = r1 if r1[0] <= width else r2
|
||||||
|
img = img.crop(
|
||||||
|
(
|
||||||
|
(width - r[0]) / 2,
|
||||||
|
(height - r[1]) / 2,
|
||||||
|
(width + r[0]) / 2,
|
||||||
|
(height + r[1]) / 2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Finally resize the image to resolution
|
||||||
|
img = img.resize(resolution, Image.LANCZOS)
|
||||||
|
|
||||||
|
return mx.array(np.array(img))
|
||||||
|
|
||||||
|
def encode_images(self):
|
||||||
|
"""Encode the images in the latent space to prepare for training."""
|
||||||
|
self.flux.ae.eval()
|
||||||
|
for sample in tqdm(self.index["data"]):
|
||||||
|
input_img = Image.open(self.dataset_base / sample["image"])
|
||||||
|
for i in range(self.args.num_augmentations):
|
||||||
|
img = self._random_crop_resize(input_img)
|
||||||
|
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
|
||||||
|
x_0 = self.flux.ae.encode(img[None])
|
||||||
|
x_0 = x_0.astype(self.flux.dtype)
|
||||||
|
mx.eval(x_0)
|
||||||
|
self.latents.append(x_0)
|
||||||
|
|
||||||
|
def encode_prompts(self):
|
||||||
|
"""Pre-encode the prompts so that we don't recompute them during
|
||||||
|
training (doesn't allow finetuning the text encoders)."""
|
||||||
|
for sample in tqdm(self.index["data"]):
|
||||||
|
t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
|
||||||
|
t5_feat = self.flux.t5(t5_tok)
|
||||||
|
clip_feat = self.flux.clip(clip_tok).pooled_output
|
||||||
|
mx.eval(t5_feat, clip_feat)
|
||||||
|
self.t5_features.append(t5_feat)
|
||||||
|
self.clip_features.append(clip_feat)
|
||||||
|
|
||||||
|
def iterate(self, batch_size):
|
||||||
|
xs = mx.concatenate(self.latents)
|
||||||
|
t5 = mx.concatenate(self.t5_features)
|
||||||
|
clip = mx.concatenate(self.clip_features)
|
||||||
|
mx.eval(xs, t5, clip)
|
||||||
|
n_aug = self.args.num_augmentations
|
||||||
|
while True:
|
||||||
|
x_indices = mx.random.permutation(len(self.latents))
|
||||||
|
c_indices = x_indices // n_aug
|
||||||
|
for i in range(0, len(self.latents), batch_size):
|
||||||
|
x_i = x_indices[i : i + batch_size]
|
||||||
|
c_i = c_indices[i : i + batch_size]
|
||||||
|
yield xs[x_i], t5[c_i], clip[c_i]
|
||||||
|
|
||||||
|
|
||||||
def generate_progress_images(iteration, flux, args):
|
def generate_progress_images(iteration, flux, args):
|
||||||
"""Generate images to monitor the progress of the finetuning."""
|
"""Generate images to monitor the progress of the finetuning."""
|
||||||
@ -58,6 +157,108 @@ def save_adapters(adapter_name, flux, args):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def push_to_hub(args):
|
||||||
|
if args.hf_token is None:
|
||||||
|
interpreter_login(new_session=False, write_permission=True)
|
||||||
|
else:
|
||||||
|
HfFolder.save_token(args.hf_token)
|
||||||
|
|
||||||
|
repo_id = args.hf_repo_id or f"{HfFolder.get_token_username()}/{args.output_dir}"
|
||||||
|
|
||||||
|
readme_content = generate_readme(args, repo_id)
|
||||||
|
readme_path = os.path.join(args.output_dir, "README.md")
|
||||||
|
with open(readme_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(readme_content)
|
||||||
|
|
||||||
|
api = HfApi()
|
||||||
|
|
||||||
|
api.create_repo(
|
||||||
|
repo_id,
|
||||||
|
private=args.hf_private,
|
||||||
|
exist_ok=True
|
||||||
|
)
|
||||||
|
|
||||||
|
api.upload_folder(
|
||||||
|
repo_id=repo_id,
|
||||||
|
folder_path=args.output_dir,
|
||||||
|
ignore_patterns=["*.yaml", "*.pt"],
|
||||||
|
repo_type="model",
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_readme(args, repo_id):
|
||||||
|
import yaml
|
||||||
|
import re
|
||||||
|
base_model = f"flux-{args.model}"
|
||||||
|
tags = [
|
||||||
|
"text-to-image",
|
||||||
|
"flux",
|
||||||
|
"lora",
|
||||||
|
"diffusers",
|
||||||
|
"template:sd-lora",
|
||||||
|
"mlx",
|
||||||
|
"mlx-trainer"
|
||||||
|
]
|
||||||
|
|
||||||
|
widgets = []
|
||||||
|
sample_image_paths = []
|
||||||
|
# Look for progress images directly in the output directory
|
||||||
|
for filename in os.listdir(args.output_dir):
|
||||||
|
match = re.search(r"(\d+)_progress\.png$", filename)
|
||||||
|
if match:
|
||||||
|
iteration = int(match.group(1))
|
||||||
|
sample_image_paths.append((iteration, filename))
|
||||||
|
|
||||||
|
sample_image_paths.sort(key=lambda x: x[0], reverse=True)
|
||||||
|
|
||||||
|
if sample_image_paths:
|
||||||
|
widgets.append(
|
||||||
|
{
|
||||||
|
"text": args.progress_prompt,
|
||||||
|
"output": {
|
||||||
|
"url": sample_image_paths[0][1]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
readme_content = f"""---
|
||||||
|
tags:
|
||||||
|
{yaml.dump(tags, indent=4).strip()}
|
||||||
|
{"widget:" if sample_image_paths else ""}
|
||||||
|
{yaml.dump(widgets, indent=4).strip() if widgets else ""}
|
||||||
|
base_model: {base_model}
|
||||||
|
license: other
|
||||||
|
---
|
||||||
|
|
||||||
|
# {os.path.basename(args.output_dir)}
|
||||||
|
Model trained with the MLX Flux Dreambooth script
|
||||||
|
|
||||||
|
<Gallery />
|
||||||
|
|
||||||
|
## Use it with [MLX](https://github.com/ml-explore/mlx-examples)
|
||||||
|
```py
|
||||||
|
from flux import FluxPipeline
|
||||||
|
import mlx.core as mx
|
||||||
|
flux = FluxPipeline("flux-{args.model}")
|
||||||
|
flux.linear_to_lora_layers({args.lora_rank}, {args.lora_blocks})
|
||||||
|
flux.flow.load_weights("{repo_id}")
|
||||||
|
image = flux.generate_images("{args.progress_prompt}", n_images=1, num_steps={args.progress_steps})
|
||||||
|
image.save("my_image.png")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
|
||||||
|
```py
|
||||||
|
from diffusers import AutoPipelineForText2Image
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pipeline = AutoPipelineForText2Image.from_pretrained('black-forest-labs/{args.model}', torch_dtype=torch.bfloat16).to('cuda')
|
||||||
|
pipeline.load_lora_weights('{repo_id}')
|
||||||
|
image = pipeline({args.progress_prompt}').images[0]
|
||||||
|
image.save("my_image.png")
|
||||||
|
```
|
||||||
|
|
||||||
|
For more details on using Flux, check the [Flux documentation](https://github.com/black-forest-labs/flux).
|
||||||
|
"""
|
||||||
|
return readme_content
|
||||||
|
|
||||||
def setup_arg_parser():
|
def setup_arg_parser():
|
||||||
"""Set up and return the argument parser."""
|
"""Set up and return the argument parser."""
|
||||||
@ -148,7 +349,28 @@ def setup_arg_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
|
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--push_to_hub",
|
||||||
|
action="store_true",
|
||||||
|
help="Push the model to Hugging Face Hub after training",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf_token",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Hugging Face token for pushing to Hub",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf_repo_id",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Hugging Face repository ID for pushing to Hub",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf_private",
|
||||||
|
action="store_true",
|
||||||
|
help="Make the Hugging Face repository private",
|
||||||
|
)
|
||||||
parser.add_argument("dataset")
|
parser.add_argument("dataset")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -287,6 +509,9 @@ if __name__ == "__main__":
|
|||||||
if (i + 1) % 10 == 0:
|
if (i + 1) % 10 == 0:
|
||||||
losses = []
|
losses = []
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
|
|
||||||
|
if args.push_to_hub:
|
||||||
|
push_to_hub(args)
|
||||||
|
|
||||||
save_adapters("final_adapters.safetensors", flux, args)
|
save_adapters("final_adapters.safetensors", flux, args)
|
||||||
print("Training successful.")
|
print("Training successful.")
|
||||||
|
Loading…
Reference in New Issue
Block a user