diff --git a/flux/dreambooth.py b/flux/dreambooth.py
index 48dcad47..4ae7249c 100644
--- a/flux/dreambooth.py
+++ b/flux/dreambooth.py
@@ -16,6 +16,10 @@ from PIL import Image
 from flux import FluxPipeline, Trainer, load_dataset
 
 
+def quantization_predicate(name, m):
+    return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
+
+
 def generate_progress_images(iteration, flux, args):
     """Generate images to monitor the progress of the finetuning."""
     out_dir = Path(args.output_dir)
@@ -24,11 +28,10 @@ def generate_progress_images(iteration, flux, args):
     print(f"Generating {str(out_file)}", flush=True)
 
     # Generate some images and arrange them in a grid
-    n_rows = 2
-    n_images = 4
+    n_rows = 2 if args.progress_num_images % 2 == 0 else 1
     x = flux.generate_images(
         args.progress_prompt,
-        n_images,
+        args.progress_num_images,
         args.progress_steps,
     )
     x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
@@ -42,6 +45,16 @@
     im = Image.fromarray(np.array(x))
     im.save(out_file)
 
+    # generate_images removes the text encoders from RAM while generating and
+    # reloads them afterwards. In memory pressured environments this can swap
+    # the flow transformer to disk and back to RAM during generation.
+    #
+    # The reloaded text encoders are not quantized, so we have to requantize
+    # them before the next time we use them.
+    if args.quantize:
+        nn.quantize(flux.t5, class_predicate=quantization_predicate)
+        nn.quantize(flux.clip, class_predicate=quantization_predicate)
+
 
 def save_adapters(iteration, flux, args):
     out_dir = Path(args.output_dir)
@@ -74,6 +87,17 @@ def setup_arg_parser():
         ],
         help="Which flux model to train",
     )
+    parser.add_argument(
+        "--quantize",
+        "-q",
+        action="store_true",
+        help="Quantize the models to reduce the memory required for training",
+    )
+    parser.add_argument(
+        "--gradient-checkpointing",
+        action="store_true",
+        help="Enable gradient checkpointing to reduce the memory required for training",
+    )
     parser.add_argument(
         "--guidance", type=float, default=4.0, help="The guidance factor to use."
     )
@@ -118,6 +142,12 @@ def setup_arg_parser():
         default=50,
         help="Generate images every PROGRESS_EVERY steps",
     )
+    parser.add_argument(
+        "--progress-num-images",
+        type=int,
+        default=4,
+        help="How many progress images to generate",
+    )
     parser.add_argument(
         "--checkpoint-every",
         type=int,
@@ -162,6 +192,14 @@ if __name__ == "__main__":
     # initial weights.
     mx.random.seed(0x0F0F0F0F)
     flux = FluxPipeline("flux-" + args.model)
+    if args.quantize:
+        nn.quantize(flux.flow, class_predicate=quantization_predicate)
+        nn.quantize(flux.t5, class_predicate=quantization_predicate)
+        nn.quantize(flux.clip, class_predicate=quantization_predicate)
+
+    if args.gradient_checkpointing:
+        flux.gradient_checkpointing()
+
     flux.flow.freeze()
     flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
 
diff --git a/flux/flux/flux.py b/flux/flux/flux.py
index 3fd044ac..b3c18230 100644
--- a/flux/flux/flux.py
+++ b/flux/flux/flux.py
@@ -7,6 +7,12 @@ import mlx.nn as nn
 from mlx.utils import tree_unflatten
 from tqdm import tqdm
 
+from .layers import (
+    DoubleStreamBlock,
+    SingleStreamBlock,
+    disable_gradient_checkpointing,
+    enable_gradient_checkpointing,
+)
 from .lora import LoRALinear
 from .sampler import FluxSampler
 from .utils import (
@@ -234,7 +240,7 @@ class FluxPipeline:
         for i, block in zip(range(num_blocks), all_blocks):
             loras = []
             for name, module in block.named_modules():
-                if isinstance(module, nn.Linear):
+                if isinstance(module, (nn.Linear, nn.QuantizedLinear)):
                     loras.append((name, LoRALinear.from_base(module, r=rank)))
             block.update_modules(tree_unflatten(loras))
 
@@ -244,3 +250,13 @@
             if isinstance(module, LoRALinear):
                 fused_layers.append((name, module.fuse()))
         self.flow.update_modules(tree_unflatten(fused_layers))
+
+    def gradient_checkpointing(self, enable: bool = True):
+        """Replace the __call__ of SingleStreamBlock and DoubleStreamBlock
+        with a gradient checkpointing version."""
+        if enable:
+            enable_gradient_checkpointing(SingleStreamBlock)
+            enable_gradient_checkpointing(DoubleStreamBlock)
+        else:
+            disable_gradient_checkpointing(SingleStreamBlock)
+            disable_gradient_checkpointing(DoubleStreamBlock)
diff --git a/flux/flux/layers.py b/flux/flux/layers.py
index 12397904..eabd54ab 100644
--- a/flux/flux/layers.py
+++ b/flux/flux/layers.py
@@ -9,6 +9,37 @@ import mlx.core as mx
 import mlx.nn as nn
 
 
+def enable_gradient_checkpointing(module_class):
+    if hasattr(module_class, "_original_call"):
+        raise ValueError(
+            f"Gradient checkpointing is already enabled for {module_class.__name__}"
+        )
+
+    fn = module_class.__call__
+    module_class._original_call = fn
+
+    def checkpointed_fn(module_instance, *args, **kwargs):
+        def inner_fn(params, *args, **kwargs):
+            module_instance.update(params)
+            return fn(module_instance, *args, **kwargs)
+
+        return mx.checkpoint(inner_fn)(
+            module_instance.trainable_parameters(), *args, **kwargs
+        )
+
+    module_class.__call__ = checkpointed_fn
+
+
+def disable_gradient_checkpointing(module_class):
+    if not hasattr(module_class, "_original_call"):
+        raise ValueError(
+            f"Gradient checkpointing is not enabled for {module_class.__name__}"
+        )
+
+    module_class.__call__ = module_class._original_call
+    delattr(module_class, "_original_call")
+
+
 def _rope(pos: mx.array, dim: int, theta: float):
     scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
     omega = 1.0 / (theta**scale)
diff --git a/flux/flux/lora.py b/flux/flux/lora.py
index b0c8ae56..4c44b9c3 100644
--- a/flux/flux/lora.py
+++ b/flux/flux/lora.py
@@ -9,12 +9,15 @@ import mlx.nn as nn
 class LoRALinear(nn.Module):
     @staticmethod
     def from_base(
-        linear: nn.Linear,
+        linear: nn.Module,
         r: int = 8,
         dropout: float = 0.0,
         scale: float = 1.0,
     ):
         output_dims, input_dims = linear.weight.shape
+        if isinstance(linear, nn.QuantizedLinear):
+            input_dims *= 32 // linear.bits
+
         lora_lin = LoRALinear(
             input_dims=input_dims,
             output_dims=output_dims,
@@ -26,6 +29,9 @@ class LoRALinear(nn.Module):
         return lora_lin
 
     def fuse(self):
+        if isinstance(self.linear, nn.QuantizedLinear):
+            raise NotImplementedError("Cannot fuse QLoRA layers yet.")
+
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight
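
The quantization_predicate added to dreambooth.py only converts layers that expose to_quantized and whose input dimension is a multiple of 512; everything else stays in full precision. Below is a minimal, hypothetical sketch of that behaviour (the toy Model and its layer sizes are made up for illustration and are not part of the patch):

import mlx.nn as nn


def quantization_predicate(name, m):
    # Same predicate as in the patch: quantize only layers that support it and
    # whose input dimension is a multiple of 512.
    return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0


class Model(nn.Module):
    # Hypothetical toy model, not part of the patch.
    def __init__(self):
        super().__init__()
        self.big = nn.Linear(1024, 1024)   # 1024 % 512 == 0, gets quantized
        self.small = nn.Linear(192, 1024)  # 192 % 512 != 0, stays a plain Linear


model = Model()
nn.quantize(model, group_size=64, bits=4, class_predicate=quantization_predicate)

print(type(model.big).__name__)    # QuantizedLinear
print(type(model.small).__name__)  # Linear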
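
The gradient checkpointing support works by monkey-patching the block classes: enable_gradient_checkpointing stores the original __call__ and installs a wrapper that routes the forward pass through mx.checkpoint, so activations inside a block are recomputed during the backward pass instead of being kept alive. A minimal, self-contained sketch of the same technique follows; TinyBlock and loss_fn are hypothetical stand-ins for the flux blocks and training loss, not code from the patch:

import mlx.core as mx
import mlx.nn as nn


class TinyBlock(nn.Module):
    # Hypothetical stand-in for SingleStreamBlock / DoubleStreamBlock.
    def __init__(self, dims: int):
        super().__init__()
        self.fc1 = nn.Linear(dims, dims)
        self.fc2 = nn.Linear(dims, dims)

    def __call__(self, x):
        return self.fc2(nn.gelu(self.fc1(x)))


def enable_gradient_checkpointing(module_class):
    # Stash the original __call__ and replace it with a checkpointed version,
    # mirroring the helper added to flux/layers.py.
    fn = module_class.__call__
    module_class._original_call = fn

    def checkpointed_fn(module_instance, *args, **kwargs):
        def inner_fn(params, *args, **kwargs):
            module_instance.update(params)
            return fn(module_instance, *args, **kwargs)

        return mx.checkpoint(inner_fn)(
            module_instance.trainable_parameters(), *args, **kwargs
        )

    module_class.__call__ = checkpointed_fn


enable_gradient_checkpointing(TinyBlock)

block = TinyBlock(32)
x = mx.random.normal((4, 32))


def loss_fn(model, x):
    # Toy scalar loss so gradients flow through the checkpointed call.
    return model(x).sum()


# The intermediate activations of the block are recomputed in the backward
# pass rather than stored, trading compute for memory.
loss, grads = nn.value_and_grad(block, loss_fn)(block, x)
mx.eval(loss, grads)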
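
The 32 // linear.bits factor in LoRALinear.from_base is there because nn.QuantizedLinear packs several low-bit weights into each uint32 element, so weight.shape[1] no longer equals the layer's input dimension. A small illustrative sketch, assuming a 1024-wide layer quantized to 4 bits (the sizes are arbitrary and not taken from the patch):

import mlx.nn as nn

# A float linear layer and its 4-bit quantized counterpart.
linear = nn.Linear(1024, 1024, bias=False)
qlinear = nn.QuantizedLinear.from_linear(linear, group_size=64, bits=4)

output_dims, packed_dims = qlinear.weight.shape
input_dims = packed_dims * 32 // qlinear.bits  # each uint32 holds 32 // bits weights

print(linear.weight.shape)   # 1024 x 1024 float weights
print(qlinear.weight.shape)  # 1024 x 128 packed uint32 values (8 four-bit weights each)
print(input_dims)            # 1024, the input size the LoRA matrices must match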