mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Add gradient accumulation and data parallelism
This commit is contained in:
@@ -7,7 +7,8 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
import numpy as np
|
||||
from mlx.utils import tree_reduce, tree_unflatten
|
||||
from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_map, tree_reduce, tree_unflatten
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -15,13 +16,14 @@ from flux import FluxPipeline
|
||||
from flux.lora import LoRALinear
|
||||
|
||||
|
||||
def linear_to_lora_layers(flux):
|
||||
def linear_to_lora_layers(flux, args):
|
||||
lora_layers = []
|
||||
rank = args.lora_rank
|
||||
for name, mod in flux.flow.named_modules():
|
||||
if ".img_attn" not in name:
|
||||
if ".img_attn" not in name and ".txt_attn" not in name:
|
||||
continue
|
||||
if ".qkv" in name or ".proj" in name:
|
||||
lora_layers.append((name, LoRALinear.from_base(mod, r=32)))
|
||||
lora_layers.append((name, LoRALinear.from_base(mod, r=rank)))
|
||||
flux.flow.update_modules(tree_unflatten(lora_layers))
|
||||
|
||||
|
||||
@@ -60,6 +62,15 @@ def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
|
||||
return x_t
|
||||
|
||||
|
||||
def iterate_batches(t5_tokens, clip_tokens, x, batch_size):
|
||||
while True:
|
||||
indices = mx.random.randint(0, len(x), (batch_size,))
|
||||
t5_i = t5_tokens[indices]
|
||||
clip_i = clip_tokens[indices]
|
||||
x_i = x[indices]
|
||||
yield t5_i, clip_i, x_i
|
||||
|
||||
|
||||
def generate_progress_images(iteration, flux, args):
|
||||
out_dir = Path(args.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -72,7 +83,7 @@ def generate_progress_images(iteration, flux, args):
|
||||
n_images,
|
||||
args.progress_prompt,
|
||||
args.progress_steps,
|
||||
seed=42,
|
||||
seed=42 + mx.distributed.init().rank(),
|
||||
)
|
||||
|
||||
# Arrange them on a grid
|
||||
@@ -82,6 +93,7 @@ def generate_progress_images(iteration, flux, args):
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
|
||||
x = x.reshape(n_rows * H, B // n_rows * W, C)
|
||||
x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
|
||||
x = (x * 255).astype(mx.uint8)
|
||||
|
||||
# Save them to disc
|
||||
@@ -137,9 +149,18 @@ if __name__ == "__main__":
|
||||
default=50,
|
||||
help="Save the model every CHECKPOINT_EVERY steps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-rank", type=int, default=32, help="LoRA rank for finetuning"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning-rate", type=float, default="1e-6", help="Learning rate for training"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grad-accumulate",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Accumulate gradients for that many iterations before applying them",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
|
||||
)
|
||||
@@ -154,7 +175,7 @@ if __name__ == "__main__":
|
||||
flux = FluxPipeline("flux-" + args.model)
|
||||
flux.ensure_models_are_loaded()
|
||||
flux.flow.freeze()
|
||||
linear_to_lora_layers(flux)
|
||||
linear_to_lora_layers(flux, args)
|
||||
|
||||
trainable_params = tree_reduce(
|
||||
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
|
||||
@@ -165,14 +186,61 @@ if __name__ == "__main__":
|
||||
state = [flux.flow.state, optimizer.state, mx.random.state]
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def step(t5_tokens, clip_tokens, x, guidance):
|
||||
def single_step(t5_tokens, clip_tokens, x, guidance):
|
||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
t5_tokens, clip_tokens, x, guidance
|
||||
)
|
||||
grads = average_gradients(grads)
|
||||
optimizer.update(flux.flow, grads)
|
||||
|
||||
return loss
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def compute_loss_and_grads(t5_tokens, clip_tokens, x, guidance):
|
||||
return nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
t5_tokens, clip_tokens, x, guidance
|
||||
)
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def compute_loss_and_accumulate_grads(
|
||||
t5_tokens, clip_tokens, x, guidance, prev_grads
|
||||
):
|
||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
t5_tokens, clip_tokens, x, guidance
|
||||
)
|
||||
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
||||
return loss, grads
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def grad_accumulate_and_step(t5_tokens, clip_tokens, x, guidance, prev_grads):
|
||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
t5_tokens, clip_tokens, x, guidance
|
||||
)
|
||||
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
||||
grads = average_gradients(grads)
|
||||
optimizer.update(flux.flow, grads)
|
||||
|
||||
return loss
|
||||
|
||||
def step(t5_tokens, clip_tokens, x, guidance, prev_grads, perform_step):
|
||||
if prev_grads is None:
|
||||
if perform_step:
|
||||
return single_step(t5_tokens, clip_tokens, x, guidance), None
|
||||
else:
|
||||
return compute_loss_and_grads(t5_tokens, clip_tokens, x, guidance)
|
||||
else:
|
||||
if perform_step:
|
||||
return (
|
||||
grad_accumulate_and_step(
|
||||
t5_tokens, clip_tokens, x, guidance, prev_grads
|
||||
),
|
||||
None,
|
||||
)
|
||||
else:
|
||||
return compute_loss_and_accumulate_grads(
|
||||
t5_tokens, clip_tokens, x, guidance, prev_grads
|
||||
)
|
||||
|
||||
print("Encoding training images to latent space")
|
||||
x = extract_latent_vectors(flux, args.image_folder)
|
||||
t5_tokens, clip_tokens = flux.tokenize([args.prompt] * len(x))
|
||||
@@ -181,14 +249,13 @@ if __name__ == "__main__":
|
||||
# An initial generation to compare
|
||||
generate_progress_images(0, flux, args)
|
||||
|
||||
grads = None
|
||||
losses = []
|
||||
tic = time.time()
|
||||
for i in range(args.iterations):
|
||||
indices = (mx.random.uniform(shape=(args.batch_size,)) * len(x)).astype(
|
||||
mx.uint32
|
||||
)
|
||||
loss = step(t5_tokens[indices], clip_tokens[indices], x[indices], guidance)
|
||||
mx.eval(loss, state)
|
||||
batches = iterate_batches(t5_tokens, clip_tokens, x, args.batch_size)
|
||||
for i, batch in zip(range(args.iterations), batches):
|
||||
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
||||
mx.eval(loss, grads, state)
|
||||
losses.append(loss.item())
|
||||
|
||||
if (i + 1) % 10 == 0:
|
||||
|
Reference in New Issue
Block a user