From e7751e4c29287529d618a6274776482abe265c52 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 3 Oct 2024 18:03:45 -0700
Subject: [PATCH] Add gradient accumulation and data parallelism

---
 flux/dreambooth.py | 93 ++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 80 insertions(+), 13 deletions(-)

diff --git a/flux/dreambooth.py b/flux/dreambooth.py
index fe36a82b..2c83458d 100644
--- a/flux/dreambooth.py
+++ b/flux/dreambooth.py
@@ -7,7 +7,8 @@ import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
 import numpy as np
-from mlx.utils import tree_reduce, tree_unflatten
+from mlx.nn.utils import average_gradients
+from mlx.utils import tree_map, tree_reduce, tree_unflatten
 from PIL import Image
 from tqdm import tqdm
 
@@ -15,13 +16,14 @@ from flux import FluxPipeline
 from flux.lora import LoRALinear
 
 
-def linear_to_lora_layers(flux):
+def linear_to_lora_layers(flux, args):
     lora_layers = []
+    rank = args.lora_rank
     for name, mod in flux.flow.named_modules():
-        if ".img_attn" not in name:
+        if ".img_attn" not in name and ".txt_attn" not in name:
             continue
         if ".qkv" in name or ".proj" in name:
-            lora_layers.append((name, LoRALinear.from_base(mod, r=32)))
+            lora_layers.append((name, LoRALinear.from_base(mod, r=rank)))
     flux.flow.update_modules(tree_unflatten(lora_layers))
 
 
@@ -60,6 +62,15 @@ def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
     return x_t
 
 
+def iterate_batches(t5_tokens, clip_tokens, x, batch_size):
+    while True:
+        indices = mx.random.randint(0, len(x), (batch_size,))
+        t5_i = t5_tokens[indices]
+        clip_i = clip_tokens[indices]
+        x_i = x[indices]
+        yield t5_i, clip_i, x_i
+
+
 def generate_progress_images(iteration, flux, args):
     out_dir = Path(args.output_dir)
     out_dir.mkdir(parents=True, exist_ok=True)
@@ -72,7 +83,7 @@ def generate_progress_images(iteration, flux, args):
         n_images,
         args.progress_prompt,
         args.progress_steps,
-        seed=42,
+        seed=42 + mx.distributed.init().rank(),
     )
 
     # Arrange them on a grid
@@ -82,6 +93,7 @@ def generate_progress_images(iteration, flux, args):
     B, H, W, C = x.shape
     x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
     x = x.reshape(n_rows * H, B // n_rows * W, C)
+    x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
     x = (x * 255).astype(mx.uint8)
 
     # Save them to disc
@@ -137,9 +149,18 @@ if __name__ == "__main__":
         default=50,
         help="Save the model every CHECKPOINT_EVERY steps",
     )
+    parser.add_argument(
+        "--lora-rank", type=int, default=32, help="LoRA rank for finetuning"
+    )
     parser.add_argument(
         "--learning-rate", type=float, default="1e-6", help="Learning rate for training"
     )
+    parser.add_argument(
+        "--grad-accumulate",
+        type=int,
+        default=1,
+        help="Accumulate gradients for that many iterations before applying them",
+    )
     parser.add_argument(
         "--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
     )
@@ -154,7 +175,7 @@ if __name__ == "__main__":
     flux = FluxPipeline("flux-" + args.model)
     flux.ensure_models_are_loaded()
     flux.flow.freeze()
-    linear_to_lora_layers(flux)
+    linear_to_lora_layers(flux, args)
 
     trainable_params = tree_reduce(
         lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
@@ -165,14 +186,61 @@ if __name__ == "__main__":
     state = [flux.flow.state, optimizer.state, mx.random.state]
 
     @partial(mx.compile, inputs=state, outputs=state)
-    def step(t5_tokens, clip_tokens, x, guidance):
+    def single_step(t5_tokens, clip_tokens, x, guidance):
         loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
             t5_tokens, clip_tokens, x, guidance
         )
+        grads = average_gradients(grads)
         optimizer.update(flux.flow, grads)
 
         return loss
 
+    @partial(mx.compile, inputs=state, outputs=state)
+    def compute_loss_and_grads(t5_tokens, clip_tokens, x, guidance):
+        return nn.value_and_grad(flux.flow, flux.training_loss)(
+            t5_tokens, clip_tokens, x, guidance
+        )
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def compute_loss_and_accumulate_grads(
+        t5_tokens, clip_tokens, x, guidance, prev_grads
+    ):
+        loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
+            t5_tokens, clip_tokens, x, guidance
+        )
+        grads = tree_map(lambda a, b: a + b, prev_grads, grads)
+        return loss, grads
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def grad_accumulate_and_step(t5_tokens, clip_tokens, x, guidance, prev_grads):
+        loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
+            t5_tokens, clip_tokens, x, guidance
+        )
+        grads = tree_map(lambda a, b: a + b, prev_grads, grads)
+        grads = average_gradients(grads)
+        optimizer.update(flux.flow, grads)
+
+        return loss
+
+    def step(t5_tokens, clip_tokens, x, guidance, prev_grads, perform_step):
+        if prev_grads is None:
+            if perform_step:
+                return single_step(t5_tokens, clip_tokens, x, guidance), None
+            else:
+                return compute_loss_and_grads(t5_tokens, clip_tokens, x, guidance)
+        else:
+            if perform_step:
+                return (
+                    grad_accumulate_and_step(
+                        t5_tokens, clip_tokens, x, guidance, prev_grads
+                    ),
+                    None,
+                )
+            else:
+                return compute_loss_and_accumulate_grads(
+                    t5_tokens, clip_tokens, x, guidance, prev_grads
+                )
+
     print("Encoding training images to latent space")
     x = extract_latent_vectors(flux, args.image_folder)
     t5_tokens, clip_tokens = flux.tokenize([args.prompt] * len(x))
@@ -181,14 +249,13 @@ if __name__ == "__main__":
     # An initial generation to compare
     generate_progress_images(0, flux, args)
 
+    grads = None
     losses = []
     tic = time.time()
-    for i in range(args.iterations):
-        indices = (mx.random.uniform(shape=(args.batch_size,)) * len(x)).astype(
-            mx.uint32
-        )
-        loss = step(t5_tokens[indices], clip_tokens[indices], x[indices], guidance)
-        mx.eval(loss, state)
+    batches = iterate_batches(t5_tokens, clip_tokens, x, args.batch_size)
+    for i, batch in zip(range(args.iterations), batches):
+        loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
+        mx.eval(loss, grads, state)
        losses.append(loss.item())

        if (i + 1) % 10 == 0:
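Note (not part of the patch): below is a minimal, standalone sketch of the accumulate/average/update pattern the diff introduces, written against a toy model rather than the Flux pipeline. It assumes an MLX version that provides mlx.nn.utils.average_gradients (the function the patch imports); the model, data, loop length, and hyperparameters are illustrative only, and the mx.compile wrapping used in the patch is omitted for brevity. As in the diff, gradients are summed over the accumulation window and only averaged across data-parallel workers before the update.

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim
    from mlx.nn.utils import average_gradients
    from mlx.utils import tree_map

    model = nn.Linear(4, 1)
    optimizer = optim.SGD(learning_rate=1e-2)


    def loss_fn(model, x, y):
        return nn.losses.mse_loss(model(x), y)


    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    grad_accumulate = 4  # plays the role of the new --grad-accumulate argument
    accumulated = None   # plays the role of `grads` carried across loop iterations

    for i in range(16):
        x = mx.random.normal((8, 4))
        y = mx.random.normal((8, 1))
        loss, grads = loss_and_grad_fn(model, x, y)

        # Sum gradients over the accumulation window, leaf by leaf, as the patch does.
        if accumulated is None:
            accumulated = grads
        else:
            accumulated = tree_map(lambda a, b: a + b, accumulated, grads)

        if (i + 1) % grad_accumulate == 0:
            # Average across data-parallel workers (effectively a no-op when running
            # a single process), then apply the update and clear the accumulator.
            accumulated = average_gradients(accumulated)
            optimizer.update(model, accumulated)
            mx.eval(model.parameters(), optimizer.state, loss)
            accumulated = None

When the script is launched as several processes (e.g. via MPI, following the MLX distributed documentation), average_gradients all-reduces the accumulated gradients so every worker applies the same update; the patch also offsets the progress-image seed by mx.distributed.init().rank() so different ranks sample different previews.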