Add gradient accumulation and data parallelism

Angelos Katharopoulos
2024-10-03 18:03:45 -07:00
parent 7cffcdcaff
commit e7751e4c29

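For context, the pattern this change introduces is: sum gradients over several micro-batches, average the sum across data-parallel workers, then apply a single optimizer update. Below is a minimal illustrative sketch of that pattern, not part of the patch; it uses a toy nn.Linear as a stand-in for flux.flow and made-up hyperparameters (accumulate=4, batch shape (8, 16)).

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.nn.utils import average_gradients
from mlx.utils import tree_map

model = nn.Linear(16, 16)                   # toy stand-in for flux.flow
optimizer = optim.Adam(learning_rate=1e-6)

def loss_fn(model, x, y):
    return mx.mean((model(x) - y) ** 2)

accumulate = 4                              # plays the role of --grad-accumulate
acc_grads = None
for i in range(100):
    x, y = mx.random.normal((8, 16)), mx.random.normal((8, 16))  # toy micro-batch
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    # Gradient accumulation: keep a running sum of gradients across micro-batches.
    acc_grads = grads if acc_grads is None else tree_map(
        lambda a, b: a + b, acc_grads, grads
    )
    if (i + 1) % accumulate == 0:
        # Data parallelism: all-reduce average of the summed gradients across
        # workers (effectively a no-op when running on a single process).
        acc_grads = average_gradients(acc_grads)
        optimizer.update(model, acc_grads)
        mx.eval(model.parameters(), optimizer.state)
        acc_grads = None

Under MLX's distributed backend, data parallelism then comes from launching one such process per device (e.g. via mpirun) so that average_gradients reduces across ranks; the diff below additionally compiles the per-step functions and offsets the progress-image seed by the rank.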

@@ -7,7 +7,8 @@ import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
 import numpy as np
-from mlx.utils import tree_reduce, tree_unflatten
+from mlx.nn.utils import average_gradients
+from mlx.utils import tree_map, tree_reduce, tree_unflatten
 from PIL import Image
 from tqdm import tqdm
@@ -15,13 +16,14 @@ from flux import FluxPipeline
 from flux.lora import LoRALinear
-def linear_to_lora_layers(flux):
+def linear_to_lora_layers(flux, args):
     lora_layers = []
+    rank = args.lora_rank
     for name, mod in flux.flow.named_modules():
-        if ".img_attn" not in name:
+        if ".img_attn" not in name and ".txt_attn" not in name:
             continue
         if ".qkv" in name or ".proj" in name:
-            lora_layers.append((name, LoRALinear.from_base(mod, r=32)))
+            lora_layers.append((name, LoRALinear.from_base(mod, r=rank)))
     flux.flow.update_modules(tree_unflatten(lora_layers))
@@ -60,6 +62,15 @@ def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
     return x_t
+def iterate_batches(t5_tokens, clip_tokens, x, batch_size):
+    while True:
+        indices = mx.random.randint(0, len(x), (batch_size,))
+        t5_i = t5_tokens[indices]
+        clip_i = clip_tokens[indices]
+        x_i = x[indices]
+        yield t5_i, clip_i, x_i
 def generate_progress_images(iteration, flux, args):
     out_dir = Path(args.output_dir)
     out_dir.mkdir(parents=True, exist_ok=True)
@@ -72,7 +83,7 @@ def generate_progress_images(iteration, flux, args):
         n_images,
         args.progress_prompt,
         args.progress_steps,
-        seed=42,
+        seed=42 + mx.distributed.init().rank(),
     )
     # Arrange them on a grid
@@ -82,6 +93,7 @@ def generate_progress_images(iteration, flux, args):
     B, H, W, C = x.shape
     x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
     x = x.reshape(n_rows * H, B // n_rows * W, C)
+    x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
     x = (x * 255).astype(mx.uint8)
     # Save them to disc
@@ -137,9 +149,18 @@ if __name__ == "__main__":
         default=50,
         help="Save the model every CHECKPOINT_EVERY steps",
     )
+    parser.add_argument(
+        "--lora-rank", type=int, default=32, help="LoRA rank for finetuning"
+    )
     parser.add_argument(
         "--learning-rate", type=float, default="1e-6", help="Learning rate for training"
     )
+    parser.add_argument(
+        "--grad-accumulate",
+        type=int,
+        default=1,
+        help="Accumulate gradients for that many iterations before applying them",
+    )
     parser.add_argument(
         "--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
     )
@@ -154,7 +175,7 @@ if __name__ == "__main__":
     flux = FluxPipeline("flux-" + args.model)
     flux.ensure_models_are_loaded()
     flux.flow.freeze()
-    linear_to_lora_layers(flux)
+    linear_to_lora_layers(flux, args)
     trainable_params = tree_reduce(
         lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
@@ -165,14 +186,61 @@ if __name__ == "__main__":
     state = [flux.flow.state, optimizer.state, mx.random.state]
     @partial(mx.compile, inputs=state, outputs=state)
-    def step(t5_tokens, clip_tokens, x, guidance):
+    def single_step(t5_tokens, clip_tokens, x, guidance):
         loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
             t5_tokens, clip_tokens, x, guidance
         )
+        grads = average_gradients(grads)
         optimizer.update(flux.flow, grads)
         return loss
+    @partial(mx.compile, inputs=state, outputs=state)
+    def compute_loss_and_grads(t5_tokens, clip_tokens, x, guidance):
+        return nn.value_and_grad(flux.flow, flux.training_loss)(
+            t5_tokens, clip_tokens, x, guidance
+        )
+    @partial(mx.compile, inputs=state, outputs=state)
+    def compute_loss_and_accumulate_grads(
+        t5_tokens, clip_tokens, x, guidance, prev_grads
+    ):
+        loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
+            t5_tokens, clip_tokens, x, guidance
+        )
+        grads = tree_map(lambda a, b: a + b, prev_grads, grads)
+        return loss, grads
+    @partial(mx.compile, inputs=state, outputs=state)
+    def grad_accumulate_and_step(t5_tokens, clip_tokens, x, guidance, prev_grads):
+        loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
+            t5_tokens, clip_tokens, x, guidance
+        )
+        grads = tree_map(lambda a, b: a + b, prev_grads, grads)
+        grads = average_gradients(grads)
+        optimizer.update(flux.flow, grads)
+        return loss
+    def step(t5_tokens, clip_tokens, x, guidance, prev_grads, perform_step):
+        if prev_grads is None:
+            if perform_step:
+                return single_step(t5_tokens, clip_tokens, x, guidance), None
+            else:
+                return compute_loss_and_grads(t5_tokens, clip_tokens, x, guidance)
+        else:
+            if perform_step:
+                return (
+                    grad_accumulate_and_step(
+                        t5_tokens, clip_tokens, x, guidance, prev_grads
+                    ),
+                    None,
+                )
+            else:
+                return compute_loss_and_accumulate_grads(
+                    t5_tokens, clip_tokens, x, guidance, prev_grads
+                )
     print("Encoding training images to latent space")
     x = extract_latent_vectors(flux, args.image_folder)
     t5_tokens, clip_tokens = flux.tokenize([args.prompt] * len(x))
@@ -181,14 +249,13 @@ if __name__ == "__main__":
     # An initial generation to compare
     generate_progress_images(0, flux, args)
+    grads = None
     losses = []
    tic = time.time()
-    for i in range(args.iterations):
-        indices = (mx.random.uniform(shape=(args.batch_size,)) * len(x)).astype(
-            mx.uint32
-        )
-        loss = step(t5_tokens[indices], clip_tokens[indices], x[indices], guidance)
-        mx.eval(loss, state)
+    batches = iterate_batches(t5_tokens, clip_tokens, x, args.batch_size)
+    for i, batch in zip(range(args.iterations), batches):
+        loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
+        mx.eval(loss, grads, state)
         losses.append(loss.item())
         if (i + 1) % 10 == 0: