Start memory-efficient flux finetuning branch

Angelos Katharopoulos 2024-10-25 09:46:47 -07:00
parent 4971462bf0
commit 67607a8e13
4 changed files with 101 additions and 6 deletions

View File

@@ -16,6 +16,10 @@ from PIL import Image
from flux import FluxPipeline, Trainer, load_dataset
def quantization_predicate(name, m):
return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
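A minimal sketch (not part of this diff) of how nn.quantize consumes the predicate above: it visits each leaf module with its dotted path and only quantizes those for which the predicate returns True, i.e. layers that expose to_quantized and whose input dimension is a multiple of 512. The toy model below is an assumption for illustration.

import mlx.nn as nn

model = nn.Sequential(nn.Linear(512, 256), nn.Linear(10, 256))
nn.quantize(model, class_predicate=quantization_predicate)
# model.layers[0] (input dim 512) becomes nn.QuantizedLinear;
# model.layers[1] (input dim 10) is left as a regular nn.Linear.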
def generate_progress_images(iteration, flux, args):
"""Generate images to monitor the progress of the finetuning."""
out_dir = Path(args.output_dir)
@@ -24,11 +28,10 @@ def generate_progress_images(iteration, flux, args):
print(f"Generating {str(out_file)}", flush=True)
# Generate some images and arrange them in a grid
n_rows = 2
n_images = 4
n_rows = 2 if args.progress_num_images % 2 == 0 else 1
x = flux.generate_images(
args.progress_prompt,
n_images,
args.progress_num_images,
args.progress_steps,
)
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
@@ -42,6 +45,16 @@ def generate_progress_images(iteration, flux, args):
im = Image.fromarray(np.array(x))
im.save(out_file)
# generate_images reloads the text encoders in order to remove them from
# RAM. In memory-pressured environments this will swap the flow transformer
# to disk and back to RAM during generation.
#
# However, we have to requantize the text encoders for the next time we
# want to use them.
if args.quantize:
nn.quantize(flux.t5, class_predicate=quantization_predicate)
nn.quantize(flux.clip, class_predicate=quantization_predicate)
def save_adapters(iteration, flux, args):
out_dir = Path(args.output_dir)
@@ -74,6 +87,17 @@ def setup_arg_parser():
],
help="Which flux model to train",
)
parser.add_argument(
"--quantize",
"-q",
action="store_true",
help="Quantize the models to reduce the memory required for training",
)
parser.add_argument(
"--gradient-checkpointing",
action="store_true",
help="Enable gradient checkpointing to reduce the memory required for training",
)
parser.add_argument(
"--guidance", type=float, default=4.0, help="The guidance factor to use."
)
@@ -118,6 +142,12 @@ def setup_arg_parser():
default=50,
help="Generate images every PROGRESS_EVERY steps",
)
parser.add_argument(
"--progress-num-images",
type=int,
default=4,
help="How many progress images to generate",
)
parser.add_argument(
"--checkpoint-every",
type=int,
@@ -162,6 +192,14 @@ if __name__ == "__main__":
# initial weights.
mx.random.seed(0x0F0F0F0F)
flux = FluxPipeline("flux-" + args.model)
if args.quantize:
nn.quantize(flux.flow, class_predicate=quantization_predicate)
nn.quantize(flux.t5, class_predicate=quantization_predicate)
nn.quantize(flux.clip, class_predicate=quantization_predicate)
if args.gradient_checkpointing:
flux.gradient_checkpointing()
flux.flow.freeze()
flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
@@ -254,8 +292,12 @@ if __name__ == "__main__":
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
# An initial generation to compare
generate_progress_images(0, flux, args)
# generate_progress_images(0, flux, args)
flux.reload_text_encoders()
del flux.t5
del flux.clip
mx.metal.reset_peak_memory()
grads = None
losses = []
tic = time.time()
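Since the initial progress generation is skipped, the text encoders are deleted, and the peak-memory counter is reset right before the loop, later readings reflect the training loop alone. A hedged sketch (not in the diff) of reporting it, for example next to the loss logging:

peak_gb = mx.metal.get_peak_memory() / 2**30
print(f"Peak memory: {peak_gb:.3f} GB", flush=True)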

View File

@@ -7,6 +7,12 @@ import mlx.nn as nn
from mlx.utils import tree_unflatten
from tqdm import tqdm
from .layers import (
DoubleStreamBlock,
SingleStreamBlock,
disable_gradient_checkpointing,
enable_gradient_checkpointing,
)
from .lora import LoRALinear
from .sampler import FluxSampler
from .utils import (
@@ -234,7 +240,7 @@ class FluxPipeline:
for i, block in zip(range(num_blocks), all_blocks):
loras = []
for name, module in block.named_modules():
if isinstance(module, nn.Linear):
if isinstance(module, (nn.Linear, nn.QuantizedLinear)):
loras.append((name, LoRALinear.from_base(module, r=rank)))
block.update_modules(tree_unflatten(loras))
@@ -244,3 +250,13 @@ class FluxPipeline:
if isinstance(module, LoRALinear):
fused_layers.append((name, module.fuse()))
self.flow.update_modules(tree_unflatten(fused_layers))
def gradient_checkpointing(self, enable: bool = True):
"""Replace the call function of SingleStreamBlock and DoubleStreamBlock
to a checkpointing one."""
if enable:
enable_gradient_checkpointing(SingleStreamBlock)
enable_gradient_checkpointing(DoubleStreamBlock)
else:
disable_gradient_checkpointing(SingleStreamBlock)
disable_gradient_checkpointing(DoubleStreamBlock)
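A short usage sketch (not from the diff): the helpers patch __call__ on the block classes themselves, so enabling checkpointing affects every Single/DoubleStreamBlock instance in the process, and disabling it restores the stock forward pass.

flux.gradient_checkpointing()       # blocks now recompute activations during backprop
# ... run training steps ...
flux.gradient_checkpointing(False)  # restore the original __call__ on both block classes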

View File

@@ -9,6 +9,37 @@ import mlx.core as mx
import mlx.nn as nn
def enable_gradient_checkpointing(module_class):
if hasattr(module_class, "_original_call"):
raise ValueError(
f"Gradient checkpointing is already enabled for {module_class.__name__}"
)
fn = module_class.__call__
module_class._original_call = fn
def checkpointed_fn(module_instance, *args, **kwargs):
def inner_fn(params, *args, **kwargs):
module_instance.update(params)
return fn(module_instance, *args, **kwargs)
return mx.checkpoint(inner_fn)(
module_instance.trainable_parameters(), *args, **kwargs
)
module_class.__call__ = checkpointed_fn
def disable_gradient_checkpointing(module_class):
if not hasattr(module_class, "_original_call"):
raise ValueError(
f"Gradient checkpointing is not enabled for {module_class.__name__}"
)
module_class.__call__ = module_class._original_call
delattr(module_class, "_original_call")
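A self-contained sketch (not from this commit) exercising the two helpers on a toy module; TinyBlock, x, and loss_fn are made up for illustration, and the helpers are assumed importable from flux.layers as in the previous file.

import mlx.core as mx
import mlx.nn as nn

from flux.layers import disable_gradient_checkpointing, enable_gradient_checkpointing


class TinyBlock(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.fc1 = nn.Linear(dims, dims)
        self.fc2 = nn.Linear(dims, dims)

    def __call__(self, x):
        return self.fc2(nn.relu(self.fc1(x)))


enable_gradient_checkpointing(TinyBlock)  # forward activations are recomputed during the backward pass

block = TinyBlock(32)
x = mx.random.normal((4, 32))

def loss_fn(model, x):
    return (model(x) ** 2).mean()

loss, grads = nn.value_and_grad(block, loss_fn)(block, x)
mx.eval(loss, grads)

disable_gradient_checkpointing(TinyBlock)  # back to the original __call__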
def _rope(pos: mx.array, dim: int, theta: float):
scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
omega = 1.0 / (theta**scale)

View File

@@ -9,12 +9,15 @@ import mlx.nn as nn
class LoRALinear(nn.Module):
@staticmethod
def from_base(
linear: nn.Linear,
linear: nn.Module,
r: int = 8,
dropout: float = 0.0,
scale: float = 1.0,
):
output_dims, input_dims = linear.weight.shape
if isinstance(linear, nn.QuantizedLinear):
input_dims *= 32 // linear.bits
lora_lin = LoRALinear(
input_dims=input_dims,
output_dims=output_dims,
@@ -26,6 +29,9 @@ class LoRALinear(nn.Module):
return lora_lin
def fuse(self):
if isinstance(self.linear, nn.QuantizedLinear):
raise NotImplementedError("Cannot fuse QLoRA layers yet.")
linear = self.linear
bias = "bias" in linear
weight = linear.weight
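A small sketch (not part of the diff) of the packing arithmetic behind the input_dims adjustment above: a QuantizedLinear stores its weight packed into 32-bit words along the input axis, so a 4-bit layer holds 32 // 4 = 8 input values per stored column. The concrete sizes below are an illustrative assumption.

import mlx.nn as nn

qlin = nn.QuantizedLinear(4096, 1024, bits=4)  # weight is packed, 8 values per uint32
output_dims, packed_in = qlin.weight.shape     # (1024, 512)
input_dims = packed_in * (32 // qlin.bits)     # 512 * 8 == 4096
assert input_dims == 4096

This is also why from_base can size the adapter correctly for quantized base layers, while fuse still rejects them above (merging the low-rank update would presumably need a dequantize/requantize round trip).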