mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-08 10:14:36 +08:00
Add gradient accumulation and data parallelism
This commit is contained in:
@@ -7,7 +7,8 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import mlx.optimizers as optim
|
import mlx.optimizers as optim
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mlx.utils import tree_reduce, tree_unflatten
|
from mlx.nn.utils import average_gradients
|
||||||
|
from mlx.utils import tree_map, tree_reduce, tree_unflatten
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@@ -15,13 +16,14 @@ from flux import FluxPipeline
|
|||||||
from flux.lora import LoRALinear
|
from flux.lora import LoRALinear
|
||||||
|
|
||||||
|
|
||||||
def linear_to_lora_layers(flux):
|
def linear_to_lora_layers(flux, args):
|
||||||
lora_layers = []
|
lora_layers = []
|
||||||
|
rank = args.lora_rank
|
||||||
for name, mod in flux.flow.named_modules():
|
for name, mod in flux.flow.named_modules():
|
||||||
if ".img_attn" not in name:
|
if ".img_attn" not in name and ".txt_attn" not in name:
|
||||||
continue
|
continue
|
||||||
if ".qkv" in name or ".proj" in name:
|
if ".qkv" in name or ".proj" in name:
|
||||||
lora_layers.append((name, LoRALinear.from_base(mod, r=32)))
|
lora_layers.append((name, LoRALinear.from_base(mod, r=rank)))
|
||||||
flux.flow.update_modules(tree_unflatten(lora_layers))
|
flux.flow.update_modules(tree_unflatten(lora_layers))
|
||||||
|
|
||||||
|
|
||||||
@@ -60,6 +62,15 @@ def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
|
|||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
|
|
||||||
|
def iterate_batches(t5_tokens, clip_tokens, x, batch_size):
|
||||||
|
while True:
|
||||||
|
indices = mx.random.randint(0, len(x), (batch_size,))
|
||||||
|
t5_i = t5_tokens[indices]
|
||||||
|
clip_i = clip_tokens[indices]
|
||||||
|
x_i = x[indices]
|
||||||
|
yield t5_i, clip_i, x_i
|
||||||
|
|
||||||
|
|
||||||
def generate_progress_images(iteration, flux, args):
|
def generate_progress_images(iteration, flux, args):
|
||||||
out_dir = Path(args.output_dir)
|
out_dir = Path(args.output_dir)
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -72,7 +83,7 @@ def generate_progress_images(iteration, flux, args):
|
|||||||
n_images,
|
n_images,
|
||||||
args.progress_prompt,
|
args.progress_prompt,
|
||||||
args.progress_steps,
|
args.progress_steps,
|
||||||
seed=42,
|
seed=42 + mx.distributed.init().rank(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Arrange them on a grid
|
# Arrange them on a grid
|
||||||
@@ -82,6 +93,7 @@ def generate_progress_images(iteration, flux, args):
|
|||||||
B, H, W, C = x.shape
|
B, H, W, C = x.shape
|
||||||
x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
|
x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
|
||||||
x = x.reshape(n_rows * H, B // n_rows * W, C)
|
x = x.reshape(n_rows * H, B // n_rows * W, C)
|
||||||
|
x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
|
||||||
x = (x * 255).astype(mx.uint8)
|
x = (x * 255).astype(mx.uint8)
|
||||||
|
|
||||||
# Save them to disc
|
# Save them to disc
|
||||||
@@ -137,9 +149,18 @@ if __name__ == "__main__":
|
|||||||
default=50,
|
default=50,
|
||||||
help="Save the model every CHECKPOINT_EVERY steps",
|
help="Save the model every CHECKPOINT_EVERY steps",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora-rank", type=int, default=32, help="LoRA rank for finetuning"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--learning-rate", type=float, default="1e-6", help="Learning rate for training"
|
"--learning-rate", type=float, default="1e-6", help="Learning rate for training"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--grad-accumulate",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Accumulate gradients for that many iterations before applying them",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
|
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
|
||||||
)
|
)
|
||||||
@@ -154,7 +175,7 @@ if __name__ == "__main__":
|
|||||||
flux = FluxPipeline("flux-" + args.model)
|
flux = FluxPipeline("flux-" + args.model)
|
||||||
flux.ensure_models_are_loaded()
|
flux.ensure_models_are_loaded()
|
||||||
flux.flow.freeze()
|
flux.flow.freeze()
|
||||||
linear_to_lora_layers(flux)
|
linear_to_lora_layers(flux, args)
|
||||||
|
|
||||||
trainable_params = tree_reduce(
|
trainable_params = tree_reduce(
|
||||||
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
|
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
|
||||||
@@ -165,14 +186,61 @@ if __name__ == "__main__":
|
|||||||
state = [flux.flow.state, optimizer.state, mx.random.state]
|
state = [flux.flow.state, optimizer.state, mx.random.state]
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
def step(t5_tokens, clip_tokens, x, guidance):
|
def single_step(t5_tokens, clip_tokens, x, guidance):
|
||||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
t5_tokens, clip_tokens, x, guidance
|
t5_tokens, clip_tokens, x, guidance
|
||||||
)
|
)
|
||||||
|
grads = average_gradients(grads)
|
||||||
optimizer.update(flux.flow, grads)
|
optimizer.update(flux.flow, grads)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
|
def compute_loss_and_grads(t5_tokens, clip_tokens, x, guidance):
|
||||||
|
return nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
|
t5_tokens, clip_tokens, x, guidance
|
||||||
|
)
|
||||||
|
|
||||||
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
|
def compute_loss_and_accumulate_grads(
|
||||||
|
t5_tokens, clip_tokens, x, guidance, prev_grads
|
||||||
|
):
|
||||||
|
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
|
t5_tokens, clip_tokens, x, guidance
|
||||||
|
)
|
||||||
|
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
||||||
|
return loss, grads
|
||||||
|
|
||||||
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
|
def grad_accumulate_and_step(t5_tokens, clip_tokens, x, guidance, prev_grads):
|
||||||
|
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||||
|
t5_tokens, clip_tokens, x, guidance
|
||||||
|
)
|
||||||
|
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
||||||
|
grads = average_gradients(grads)
|
||||||
|
optimizer.update(flux.flow, grads)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def step(t5_tokens, clip_tokens, x, guidance, prev_grads, perform_step):
|
||||||
|
if prev_grads is None:
|
||||||
|
if perform_step:
|
||||||
|
return single_step(t5_tokens, clip_tokens, x, guidance), None
|
||||||
|
else:
|
||||||
|
return compute_loss_and_grads(t5_tokens, clip_tokens, x, guidance)
|
||||||
|
else:
|
||||||
|
if perform_step:
|
||||||
|
return (
|
||||||
|
grad_accumulate_and_step(
|
||||||
|
t5_tokens, clip_tokens, x, guidance, prev_grads
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return compute_loss_and_accumulate_grads(
|
||||||
|
t5_tokens, clip_tokens, x, guidance, prev_grads
|
||||||
|
)
|
||||||
|
|
||||||
print("Encoding training images to latent space")
|
print("Encoding training images to latent space")
|
||||||
x = extract_latent_vectors(flux, args.image_folder)
|
x = extract_latent_vectors(flux, args.image_folder)
|
||||||
t5_tokens, clip_tokens = flux.tokenize([args.prompt] * len(x))
|
t5_tokens, clip_tokens = flux.tokenize([args.prompt] * len(x))
|
||||||
@@ -181,14 +249,13 @@ if __name__ == "__main__":
|
|||||||
# An initial generation to compare
|
# An initial generation to compare
|
||||||
generate_progress_images(0, flux, args)
|
generate_progress_images(0, flux, args)
|
||||||
|
|
||||||
|
grads = None
|
||||||
losses = []
|
losses = []
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
for i in range(args.iterations):
|
batches = iterate_batches(t5_tokens, clip_tokens, x, args.batch_size)
|
||||||
indices = (mx.random.uniform(shape=(args.batch_size,)) * len(x)).astype(
|
for i, batch in zip(range(args.iterations), batches):
|
||||||
mx.uint32
|
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
||||||
)
|
mx.eval(loss, grads, state)
|
||||||
loss = step(t5_tokens[indices], clip_tokens[indices], x[indices], guidance)
|
|
||||||
mx.eval(loss, state)
|
|
||||||
losses.append(loss.item())
|
losses.append(loss.item())
|
||||||
|
|
||||||
if (i + 1) % 10 == 0:
|
if (i + 1) % 10 == 0:
|
||||||
|
Reference in New Issue
Block a user