Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 01:17:28 +08:00
Start memory-efficient flux finetuning branch
This commit is contained in:
parent 4971462bf0
commit 67607a8e13
@@ -16,6 +16,10 @@ from PIL import Image
 from flux import FluxPipeline, Trainer, load_dataset
 
 
+def quantization_predicate(name, m):
+    return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
+
+
 def generate_progress_images(iteration, flux, args):
     """Generate images to monitor the progress of the finetuning."""
     out_dir = Path(args.output_dir)
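The predicate added above is meant for `nn.quantize`'s `class_predicate` hook, which MLX calls with each submodule's path and the module itself; only modules for which it returns True are converted to their quantized counterparts. A minimal sketch of the behaviour on a toy model (the layer sizes below are made up for illustration):

import mlx.nn as nn

def quantization_predicate(name, m):
    # Quantize only modules that support it and whose input dimension is a
    # multiple of 512.
    return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0

model = nn.Sequential(
    nn.Linear(1024, 1024),  # quantized: input dim 1024 is a multiple of 512
    nn.Linear(1024, 100),   # quantized: its input dim is still 1024
    nn.Linear(100, 10),     # skipped: 100 is not a multiple of 512
)
nn.quantize(model, class_predicate=quantization_predicate)
print(model)  # the first two layers are now QuantizedLinear

The same predicate is reused later in the diff whenever the text encoders are reloaded and need to be requantized.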
@@ -24,11 +28,10 @@ def generate_progress_images(iteration, flux, args):
     print(f"Generating {str(out_file)}", flush=True)
 
     # Generate some images and arrange them in a grid
-    n_rows = 2
-    n_images = 4
+    n_rows = 2 if args.progress_num_images % 2 == 0 else 1
     x = flux.generate_images(
         args.progress_prompt,
-        n_images,
+        args.progress_num_images,
         args.progress_steps,
     )
     x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
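The padded batch is then tiled into a grid before being saved; a rough sketch of that step (the shapes and the reshape/transpose order are assumptions, since the grid code itself is outside this hunk):

import mlx.core as mx

# Toy stand-in for a batch of generated images: 4 RGB images of 8x8 pixels.
x = mx.zeros((4, 8, 8, 3))
n_rows = 2

# Pad each image with a 4 pixel border, as in the hunk above.
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])

# Tile the batch into an n_rows x n_cols grid of images.
B, H, W, C = x.shape
n_cols = B // n_rows
x = x.reshape(n_rows, n_cols, H, W, C)
x = x.transpose(0, 2, 1, 3, 4).reshape(n_rows * H, n_cols * W, C)
print(x.shape)  # (32, 32, 3)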
@@ -42,6 +45,16 @@ def generate_progress_images(iteration, flux, args):
     im = Image.fromarray(np.array(x))
     im.save(out_file)
 
+    # generate_images reloads the text encoders in order to remove them from
+    # RAM. In memory pressured environments this will swap the flow transformer
+    # to disk and back to RAM during generation.
+    #
+    # However, we have to requantize the text encoders for the next time we
+    # want to use them.
+    if args.quantize:
+        nn.quantize(flux.t5, class_predicate=quantization_predicate)
+        nn.quantize(flux.clip, class_predicate=quantization_predicate)
+
 
 def save_adapters(iteration, flux, args):
     out_dir = Path(args.output_dir)
@@ -74,6 +87,17 @@ def setup_arg_parser():
         ],
         help="Which flux model to train",
     )
+    parser.add_argument(
+        "--quantize",
+        "-q",
+        action="store_true",
+        help="Quantize the models to reduce the memory required for training",
+    )
+    parser.add_argument(
+        "--gradient-checkpointing",
+        action="store_true",
+        help="Enable gradient checkpointing to reduce the memory required for training",
+    )
     parser.add_argument(
         "--guidance", type=float, default=4.0, help="The guidance factor to use."
     )
@@ -118,6 +142,12 @@ def setup_arg_parser():
         default=50,
         help="Generate images every PROGRESS_EVERY steps",
     )
+    parser.add_argument(
+        "--progress-num-images",
+        type=int,
+        default=4,
+        help="How many progress images to generate",
+    )
     parser.add_argument(
         "--checkpoint-every",
         type=int,
@@ -162,6 +192,14 @@ if __name__ == "__main__":
     # initial weights.
     mx.random.seed(0x0F0F0F0F)
     flux = FluxPipeline("flux-" + args.model)
+    if args.quantize:
+        nn.quantize(flux.flow, class_predicate=quantization_predicate)
+        nn.quantize(flux.t5, class_predicate=quantization_predicate)
+        nn.quantize(flux.clip, class_predicate=quantization_predicate)
+
+    if args.gradient_checkpointing:
+        flux.gradient_checkpointing()
+
     flux.flow.freeze()
     flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
 
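Freezing the flow transformer before converting its linear layers means only the LoRA matrices stay trainable, so gradients and optimizer state are tiny compared to the full model. A hedged sketch of the pattern with a stand-in adapter (this is not the repository's `LoRALinear`, just an illustration of freeze plus adapter):

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

class ToyLoRA(nn.Module):
    """Stand-in adapter: y = base(x) + scale * (x @ A.T) @ B.T"""

    def __init__(self, base, r=8, scale=1.0):
        super().__init__()
        out_dims, in_dims = base.weight.shape
        self.base = base
        self.scale = scale
        self.lora_a = mx.random.normal((r, in_dims)) * 0.01
        self.lora_b = mx.zeros((out_dims, r))

    def __call__(self, x):
        return self.base(x) + self.scale * (x @ self.lora_a.T) @ self.lora_b.T

layer = nn.Linear(64, 64)
layer.freeze()  # the base weights no longer produce gradients
lora = ToyLoRA(layer, r=4)

n_trainable = sum(v.size for _, v in tree_flatten(lora.trainable_parameters()))
print(n_trainable)  # only lora_a and lora_b: 4 * 64 + 64 * 4 = 512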
@@ -254,8 +292,12 @@ if __name__ == "__main__":
     guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
 
     # An initial generation to compare
-    generate_progress_images(0, flux, args)
+    # generate_progress_images(0, flux, args)
+    flux.reload_text_encoders()
+    del flux.t5
+    del flux.clip
 
+    mx.metal.reset_peak_memory()
     grads = None
     losses = []
     tic = time.time()
@@ -7,6 +7,12 @@ import mlx.nn as nn
 from mlx.utils import tree_unflatten
 from tqdm import tqdm
 
+from .layers import (
+    DoubleStreamBlock,
+    SingleStreamBlock,
+    disable_gradient_checkpointing,
+    enable_gradient_checkpointing,
+)
 from .lora import LoRALinear
 from .sampler import FluxSampler
 from .utils import (
@@ -234,7 +240,7 @@ class FluxPipeline:
         for i, block in zip(range(num_blocks), all_blocks):
             loras = []
             for name, module in block.named_modules():
-                if isinstance(module, nn.Linear):
+                if isinstance(module, (nn.Linear, nn.QuantizedLinear)):
                     loras.append((name, LoRALinear.from_base(module, r=rank)))
             block.update_modules(tree_unflatten(loras))
 
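Extending the `isinstance` check to `nn.QuantizedLinear` is what lets LoRA adapters be attached to an already quantized flow transformer, QLoRA-style. The surrounding code uses MLX's usual module-surgery pattern: collect `(dotted.path, replacement)` pairs from `named_modules()`, rebuild the nesting with `tree_unflatten`, and apply it with `update_modules`. A minimal sketch on a toy module, with a plain `nn.Linear` standing in for `LoRALinear.from_base`:

import mlx.nn as nn
from mlx.utils import tree_unflatten

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = nn.Linear(64, 128)
        self.down = nn.Linear(128, 64)

    def __call__(self, x):
        return self.down(nn.gelu(self.up(x)))

model = Block()
replacements = []
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        out_dims, in_dims = module.weight.shape
        # Placeholder; the flux code builds a LoRALinear around `module` here.
        replacements.append((name, nn.Linear(in_dims, out_dims)))

model.update_modules(tree_unflatten(replacements))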
@@ -244,3 +250,13 @@ class FluxPipeline:
             if isinstance(module, LoRALinear):
                 fused_layers.append((name, module.fuse()))
         self.flow.update_modules(tree_unflatten(fused_layers))
+
+    def gradient_checkpointing(self, enable: bool = True):
+        """Replace the call function of SingleStreamBlock and DoubleStreamBlock
+        with a checkpointing one."""
+        if enable:
+            enable_gradient_checkpointing(SingleStreamBlock)
+            enable_gradient_checkpointing(DoubleStreamBlock)
+        else:
+            disable_gradient_checkpointing(SingleStreamBlock)
+            disable_gradient_checkpointing(DoubleStreamBlock)
@@ -9,6 +9,37 @@ import mlx.core as mx
 import mlx.nn as nn
 
 
+def enable_gradient_checkpointing(module_class):
+    if hasattr(module_class, "_original_call"):
+        raise ValueError(
+            f"Gradient checkpointing is already enabled for {module_class.__name__}"
+        )
+
+    fn = module_class.__call__
+    module_class._original_call = fn
+
+    def checkpointed_fn(module_instance, *args, **kwargs):
+        def inner_fn(params, *args, **kwargs):
+            module_instance.update(params)
+            return fn(module_instance, *args, **kwargs)
+
+        return mx.checkpoint(inner_fn)(
+            module_instance.trainable_parameters(), *args, **kwargs
+        )
+
+    module_class.__call__ = checkpointed_fn
+
+
+def disable_gradient_checkpointing(module_class):
+    if not hasattr(module_class, "_original_call"):
+        raise ValueError(
+            f"Gradient checkpointing is not enabled for {module_class.__name__}"
+        )
+
+    module_class.__call__ = module_class._original_call
+    delattr(module_class, "_original_call")
+
+
 def _rope(pos: mx.array, dim: int, theta: float):
     scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
     omega = 1.0 / (theta**scale)
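`mx.checkpoint` wraps a function so that its intermediate activations are not kept for the backward pass but recomputed when gradients are taken, trading extra compute for a much smaller activation footprint. The patched `__call__` above passes the module's trainable parameters in explicitly so the recomputation runs against the current weights. A minimal, self-contained illustration on a plain function (not tied to the flux blocks):

import mlx.core as mx

def block(w, x):
    # The tanh activation would normally be stored for the backward pass.
    h = mx.tanh(x @ w)
    return (h * h).sum()

w = mx.random.normal((256, 256))
x = mx.random.normal((32, 256))

grad_fn = mx.grad(block)                      # keeps intermediates alive
ckpt_grad_fn = mx.grad(mx.checkpoint(block))  # recomputes them during backward

g1 = grad_fn(w, x)
g2 = ckpt_grad_fn(w, x)
print(mx.allclose(g1, g2))  # True: identical gradients, lower peak memory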
@@ -9,12 +9,15 @@ import mlx.nn as nn
 class LoRALinear(nn.Module):
     @staticmethod
     def from_base(
-        linear: nn.Linear,
+        linear: nn.Module,
         r: int = 8,
         dropout: float = 0.0,
         scale: float = 1.0,
     ):
         output_dims, input_dims = linear.weight.shape
+        if isinstance(linear, nn.QuantizedLinear):
+            input_dims *= 32 // linear.bits
+
         lora_lin = LoRALinear(
             input_dims=input_dims,
             output_dims=output_dims,
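The `32 // linear.bits` factor undoes the weight packing: a `QuantizedLinear` stores 32 / bits quantized values per uint32 element, so its `weight.shape[1]` is the true input width divided by that factor. A quick check of the arithmetic (the layer sizes are arbitrary):

import mlx.nn as nn

linear = nn.Linear(3072, 768, bias=False)
qlinear = linear.to_quantized(group_size=64, bits=4)

out_dims, packed_in = qlinear.weight.shape
print(out_dims, packed_in)               # 768 384  (3072 * 4 / 32)
print(packed_in * (32 // qlinear.bits))  # 3072, the original input_dims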
@@ -26,6 +29,9 @@ class LoRALinear(nn.Module):
         return lora_lin
 
     def fuse(self):
+        if isinstance(self.linear, nn.QuantizedLinear):
+            raise NotImplementedError("Cannot fuse QLoRA layers yet.")
+
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight
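For the unquantized case, fusing just folds the low-rank update into the base weight so the adapter can be dropped at inference time. Fusing a quantized base would additionally require dequantizing, adding the update, and requantizing, which is why the quantized path raises for now. A small sketch of the unquantized fold (the `lora_a`/`lora_b` shapes and the transposition convention below are assumptions made for the illustration):

import mlx.core as mx

# Assumed convention: y = x @ W.T + scale * (x @ lora_a) @ lora_b
in_dims, out_dims, r, scale = 64, 32, 8, 1.0
W = mx.random.normal((out_dims, in_dims))
lora_a = mx.random.normal((in_dims, r)) * 0.01
lora_b = mx.random.normal((r, out_dims)) * 0.01

# Fold the low-rank update into the base weight.
W_fused = W + scale * (lora_b.T @ lora_a.T)

x = mx.random.normal((4, in_dims))
y_adapter = x @ W.T + scale * (x @ lora_a) @ lora_b
y_fused = x @ W_fused.T
print(mx.allclose(y_adapter, y_fused, atol=1e-5))  # True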