Further refactoring

This commit is contained in:
Angelos Katharopoulos 2024-10-10 00:35:44 -07:00
parent f538394eec
commit ecd8828e33
3 changed files with 59 additions and 55 deletions

View File

@ -87,63 +87,21 @@ class FinetuningDataset:
yield xs[indices], t5[indices], clip[indices]
def linear_to_lora_layers(flux, args):
    """Replace the linear layers of the transformer blocks with LoRA layers.

    The double and single blocks of ``flux.flow`` are visited starting from
    the last one. ``args.lora_blocks`` limits how many blocks are converted;
    a non-positive value converts every block. ``args.lora_rank`` is the rank
    handed to each LoRA layer.
    """
    lora_rank = args.lora_rank

    # Concatenating builds a fresh list, so reversing it does not touch the
    # block lists stored on the model.
    ordered_blocks = list(
        reversed(flux.flow.double_blocks + flux.flow.single_blocks)
    )
    limit = args.lora_blocks if args.lora_blocks > 0 else len(ordered_blocks)

    for block in ordered_blocks[:limit]:
        replacements = [
            (name, LoRALinear.from_base(module, r=lora_rank))
            for name, module in block.named_modules()
            if isinstance(module, nn.Linear)
        ]
        block.update_modules(tree_unflatten(replacements))
def generate_progress_images(iteration, flux, args):
"""Generate images to monitor the progress of the finetuning."""
def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
    """Run the latent generator to completion and return the last latents.

    Args:
        flux: Pipeline object exposing ``generate_latents``.
        n_images: Number of images to generate latents for.
        prompt: Text prompt forwarded to ``flux.generate_latents``.
        steps: Number of sampling steps; also the progress bar length.
        seed: Value passed to ``random_state`` around the sampling loop.
        leave: Forwarded to ``tqdm`` (keep the bar after completion).

    Returns:
        The last item yielded by ``flux.generate_latents``.
    """
    with random_state(seed):
        latents = flux.generate_latents(
            prompt,
            n_images=n_images,
            num_steps=steps,
        )
        # Size the progress bar from this helper's own ``steps`` argument
        # rather than reaching into the enclosing scope for
        # ``args.progress_steps``; the two only agree by coincidence of the
        # single call site.
        # NOTE(review): the loop is kept inside ``random_state`` because the
        # generator is consumed lazily here - confirm it samples noise lazily.
        for x_t in tqdm(latents, total=steps, leave=leave):
            mx.eval(x_t)
        return x_t
def decode_latents(flux, x):
    """Decode ``x`` one item at a time and concatenate the results on axis 0."""
    images = []
    for start in tqdm(range(len(x))):
        image = flux.decode(x[start : start + 1])
        images.append(image)
        # Evaluate each decoded image before moving on to the next one.
        mx.eval(image)
    return mx.concatenate(images, axis=0)
out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / f"{iteration:07d}_progress.png"
print(f"Generating {str(out_file)}", flush=True)
# Generate the latent vectors using diffusion
n_images = 4
latents = generate_latents(
flux,
n_images,
args.progress_prompt,
args.progress_steps,
seed=42 + mx.distributed.init().rank(),
)
# Reload the text encoders to reduce the memory use during training
flux.reload_text_encoders()
# Arrange them on a grid
# Generate some images and arrange them in a grid
n_rows = 2
x = decode_latents(flux, latents)
n_images = 4
x = flux.generate_images(
args.progress_prompt,
n_images,
args.progress_steps,
)
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
B, H, W, C = x.shape
x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
@ -166,8 +124,8 @@ def save_adapters(iteration, flux, args):
str(out_file),
dict(tree_flatten(flux.flow.trainable_parameters())),
metadata={
"lora_rank": args.lora_rank,
"lora_blocks": args.lora_blocks,
"lora_rank": str(args.lora_rank),
"lora_blocks": str(args.lora_blocks),
},
)
@ -269,7 +227,7 @@ if __name__ == "__main__":
flux = FluxPipeline("flux-" + args.model)
flux.flow.freeze()
with random_state(0x0F0F0F0F):
linear_to_lora_layers(flux, args)
flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
# Report how many parameters we are training
trainable_params = tree_reduce(

View File

@ -3,8 +3,11 @@ import time
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from tqdm import tqdm
from .lora import LoRALinear
from .sampler import FluxSampler
from .utils import (
load_ae,
@ -38,7 +41,7 @@ class FluxPipeline:
def reload_text_encoders(self):
    """Load the T5 and CLIP text encoders again and rebind them on ``self``."""
    self.t5 = load_t5(self.name)
    # The model name lives on the instance; a bare ``name`` is not defined
    # in this scope and would raise NameError.
    self.clip = load_clip(self.name)
def tokenize(self, text):
t5_tokens = self.t5_tokenizer.encode(text)
@ -156,6 +159,37 @@ class FluxPipeline:
x = self.ae.decode(x)
return mx.clip(x + 1, 0, 2) * 0.5
def generate_images(
    self,
    text: str,
    n_images: int = 1,
    num_steps: int = 35,
    guidance: float = 4.0,
    latent_size: Tuple[int, int] = (64, 64),
    seed=None,
    reload_text_encoders: bool = True,
    progress: bool = True,
):
    """Generate decoded images for a text prompt.

    Args:
        text: Prompt forwarded to ``self.generate_latents``.
        n_images: Number of images to generate.
        num_steps: Number of sampling steps; also the progress bar length.
        guidance: Guidance value forwarded to ``self.generate_latents``.
        latent_size: Latent size forwarded to ``self.generate_latents``.
        seed: Seed forwarded to ``self.generate_latents``.
        reload_text_encoders: When true, call ``self.reload_text_encoders()``
            after the first item of the latent generator has been evaluated.
        progress: When false, the tqdm progress bars are disabled.

    Returns:
        The decoded images concatenated along axis 0.

    Raises:
        ValueError: If the latent generator yields nothing after its first
            item, so there are no latents to decode.
    """
    latents = self.generate_latents(
        text, n_images, num_steps, guidance, latent_size, seed
    )
    # NOTE(review): presumably the first yielded item is the text
    # conditioning, which is why the encoders can be reloaded right after
    # it is evaluated - confirm against generate_latents.
    mx.eval(next(latents))

    if reload_text_encoders:
        self.reload_text_encoders()

    x_t = None
    for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
        mx.eval(x_t)
    if x_t is None:
        # Without this guard an empty sampling loop surfaces as an opaque
        # UnboundLocalError on the ``len(x_t)`` below.
        raise ValueError(
            f"No latents were generated (num_steps={num_steps}); "
            "nothing to decode."
        )

    # Decode one latent at a time and evaluate each result immediately.
    images = []
    for i in tqdm(range(len(x_t)), disable=not progress):
        images.append(self.decode(x_t[i : i + 1]))
        mx.eval(images[-1])
    images = mx.concatenate(images, axis=0)
    mx.eval(images)

    return images
def training_loss(
self,
x_0: mx.array,
@ -171,7 +205,7 @@ class FluxPipeline:
# Prepare the latent input
x_0, x_ids = self._prepare_latent_images(x_0)
# Forward process (we use rf/lognorm(0, 1))
# Forward process
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
x_t = self.sampler.add_noise(x_0, t, noise=eps)
@ -189,3 +223,15 @@ class FluxPipeline:
)
return (pred + x_0 - eps).square().mean()
def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
    """Replace the linear layers of the transformer blocks with LoRA layers.

    The double and single blocks of ``self.flow`` are visited starting from
    the last one. ``num_blocks`` limits how many blocks are converted; a
    non-positive value converts every block. ``rank`` is the rank handed to
    each LoRA layer.
    """
    # Concatenating builds a fresh list, so reversing it does not touch the
    # block lists stored on the model.
    ordered_blocks = list(
        reversed(self.flow.double_blocks + self.flow.single_blocks)
    )
    if not num_blocks > 0:
        num_blocks = len(ordered_blocks)

    for block in ordered_blocks[:num_blocks]:
        replacements = [
            (name, LoRALinear.from_base(module, r=rank))
            for name, module in block.named_modules()
            if isinstance(module, nn.Linear)
        ]
        block.update_modules(tree_unflatten(replacements))

View File

@ -23,7 +23,7 @@ class LoRALinear(nn.Module):
lora_lin.linear = linear
return lora_lin
def fuse(self, de_quantize: bool = False):
def fuse(self):
linear = self.linear
bias = "bias" in linear
weight = linear.weight