Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-31 11:54:37 +08:00)

Further refactoring

commit ecd8828e33 (parent f538394eec)
@@ -87,63 +87,21 @@ class FinetuningDataset:
             yield xs[indices], t5[indices], clip[indices]
 
 
-def linear_to_lora_layers(flux, args):
-    """Swap the linear layers in the transformer blocks with LoRA layers."""
-    rank = args.lora_rank
-    all_blocks = flux.flow.double_blocks + flux.flow.single_blocks
-    all_blocks.reverse()
-    num_blocks = args.lora_blocks if args.lora_blocks > 0 else len(all_blocks)
-    for i, block in zip(range(num_blocks), all_blocks):
-        loras = []
-        for name, module in block.named_modules():
-            if isinstance(module, nn.Linear):
-                loras.append((name, LoRALinear.from_base(module, r=rank)))
-        block.update_modules(tree_unflatten(loras))
-
-
 def generate_progress_images(iteration, flux, args):
     """Generate images to monitor the progress of the finetuning."""
-
-    def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
-        with random_state(seed):
-            latents = flux.generate_latents(
-                prompt,
-                n_images=n_images,
-                num_steps=steps,
-            )
-            for x_t in tqdm(latents, total=args.progress_steps, leave=leave):
-                mx.eval(x_t)
-
-            return x_t
-
-    def decode_latents(flux, x):
-        decoded = []
-        for i in tqdm(range(len(x))):
-            decoded.append(flux.decode(x[i : i + 1]))
-            mx.eval(decoded[-1])
-        return mx.concatenate(decoded, axis=0)
-
     out_dir = Path(args.output_dir)
     out_dir.mkdir(parents=True, exist_ok=True)
     out_file = out_dir / f"{iteration:07d}_progress.png"
     print(f"Generating {str(out_file)}", flush=True)
 
-    # Generate the latent vectors using diffusion
-    n_images = 4
-    latents = generate_latents(
-        flux,
-        n_images,
-        args.progress_prompt,
-        args.progress_steps,
-        seed=42 + mx.distributed.init().rank(),
-    )
-
-    # Reload the text encoders to reduce the memory use during training
-    flux.reload_text_encoders()
-
-    # Arrange them on a grid
+    # Generate some images and arrange them in a grid
     n_rows = 2
-    x = decode_latents(flux, latents)
+    n_images = 4
+    x = flux.generate_images(
+        args.progress_prompt,
+        n_images,
+        args.progress_steps,
+    )
     x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
     B, H, W, C = x.shape
     x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
@@ -166,8 +124,8 @@ def save_adapters(iteration, flux, args):
         str(out_file),
         dict(tree_flatten(flux.flow.trainable_parameters())),
         metadata={
-            "lora_rank": args.lora_rank,
-            "lora_blocks": args.lora_blocks,
+            "lora_rank": str(args.lora_rank),
+            "lora_blocks": str(args.lora_blocks),
         },
     )
 
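The LoRA settings are stored as safetensors metadata, which is string-valued, hence the str(...) casts above. A minimal sketch of reading them back when resuming or fusing adapters; the file name and the return_metadata flag of mx.load are assumptions, not part of this diff:

    import mlx.core as mx

    # Load the adapter weights together with the metadata written by save_adapters.
    weights, metadata = mx.load("0001000_adapters.safetensors", return_metadata=True)
    rank = int(metadata["lora_rank"])
    num_blocks = int(metadata["lora_blocks"])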
@@ -269,7 +227,7 @@ if __name__ == "__main__":
     flux = FluxPipeline("flux-" + args.model)
     flux.flow.freeze()
     with random_state(0x0F0F0F0F):
-        linear_to_lora_layers(flux, args)
+        flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
 
     # Report how many parameters we are training
     trainable_params = tree_reduce(
@@ -3,8 +3,11 @@ import time
 from typing import Tuple
 
 import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_unflatten
 from tqdm import tqdm
 
+from .lora import LoRALinear
 from .sampler import FluxSampler
 from .utils import (
     load_ae,
@@ -38,7 +41,7 @@ class FluxPipeline:
 
     def reload_text_encoders(self):
         self.t5 = load_t5(self.name)
-        self.clip = load_clip(name)
+        self.clip = load_clip(self.name)
 
     def tokenize(self, text):
         t5_tokens = self.t5_tokenizer.encode(text)
@@ -156,6 +159,37 @@ class FluxPipeline:
         x = self.ae.decode(x)
         return mx.clip(x + 1, 0, 2) * 0.5
 
+    def generate_images(
+        self,
+        text: str,
+        n_images: int = 1,
+        num_steps: int = 35,
+        guidance: float = 4.0,
+        latent_size: Tuple[int, int] = (64, 64),
+        seed=None,
+        reload_text_encoders: bool = True,
+        progress: bool = True,
+    ):
+        latents = self.generate_latents(
+            text, n_images, num_steps, guidance, latent_size, seed
+        )
+        mx.eval(next(latents))
+
+        if reload_text_encoders:
+            self.reload_text_encoders()
+
+        for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
+            mx.eval(x_t)
+
+        images = []
+        for i in tqdm(range(len(x_t)), disable=not progress):
+            images.append(self.decode(x_t[i : i + 1]))
+            mx.eval(images[-1])
+        images = mx.concatenate(images, axis=0)
+        mx.eval(images)
+
+        return images
+
     def training_loss(
         self,
         x_0: mx.array,
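The new generate_images method above wraps latent sampling, the text-encoder reload, and per-image VAE decoding into one call. A minimal usage sketch, assuming the example's flux package exports FluxPipeline; the prompt, model name, and step count below are illustrative, not taken from this diff:

    from flux import FluxPipeline

    flux = FluxPipeline("flux-schnell")
    images = flux.generate_images(
        "a photo of a red vintage car",  # hypothetical prompt
        n_images=4,
        num_steps=2,
    )
    # images has shape (n_images, H, W, C) with values in [0, 1],
    # since decode() rescales the autoencoder output via mx.clip(x + 1, 0, 2) * 0.5.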
@@ -171,7 +205,7 @@ class FluxPipeline:
         # Prepare the latent input
         x_0, x_ids = self._prepare_latent_images(x_0)
 
-        # Forward process (we use rf/lognorm(0, 1))
+        # Forward process
         t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
         eps = mx.random.normal(x_0.shape, dtype=self.dtype)
         x_t = self.sampler.add_noise(x_0, t, noise=eps)
@@ -189,3 +223,15 @@ class FluxPipeline:
         )
 
         return (pred + x_0 - eps).square().mean()
+
+    def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
+        """Swap the linear layers in the transformer blocks with LoRA layers."""
+        all_blocks = self.flow.double_blocks + self.flow.single_blocks
+        all_blocks.reverse()
+        num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
+        for i, block in zip(range(num_blocks), all_blocks):
+            loras = []
+            for name, module in block.named_modules():
+                if isinstance(module, nn.Linear):
+                    loras.append((name, LoRALinear.from_base(module, r=rank)))
+            block.update_modules(tree_unflatten(loras))
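With linear_to_lora_layers now a FluxPipeline method, the training script freezes the flow transformer and then swaps its nn.Linear layers in place, as in the __main__ hunk further up. A short sketch under the same assumptions; the argument values are illustrative:

    flux = FluxPipeline("flux-dev")
    flux.flow.freeze()                                  # keep the base weights fixed
    flux.linear_to_lora_layers(rank=8, num_blocks=-1)   # -1 converts every block
    # Only the LoRA adapters are left trainable:
    adapters = flux.flow.trainable_parameters()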
@@ -23,7 +23,7 @@ class LoRALinear(nn.Module):
         lora_lin.linear = linear
         return lora_lin
 
-    def fuse(self, de_quantize: bool = False):
+    def fuse(self):
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight