Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)
Start memory-efficient flux finetuning branch
This commit is contained in:
parent 4971462bf0
commit 67607a8e13
@@ -16,6 +16,10 @@ from PIL import Image
 from flux import FluxPipeline, Trainer, load_dataset
 
 
+def quantization_predicate(name, m):
+    return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
+
+
 def generate_progress_images(iteration, flux, args):
     """Generate images to monitor the progress of the finetuning."""
     out_dir = Path(args.output_dir)
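The new quantization_predicate is handed to nn.quantize as a class_predicate: MLX walks the module tree and quantizes only the submodules for which the predicate returns True, so only layers that support quantization and whose input dimension is a multiple of 512 are converted. A minimal, self-contained sketch of that pattern (the toy layer sizes are made up for illustration and are not from this commit):

```python
import mlx.nn as nn


def quantization_predicate(name, m):
    # Quantize only layers that support it and whose input dimension is a
    # multiple of 512; everything else stays in full precision.
    return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0


# Hypothetical toy model; the training script passes flux.flow / flux.t5 / flux.clip.
model = nn.Sequential(nn.Linear(512, 100), nn.Linear(100, 10))
nn.quantize(model, class_predicate=quantization_predicate)
# The first Linear (weight shape (100, 512)) is quantized; the second
# (weight shape (10, 100)) is skipped because 100 % 512 != 0.
```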
@@ -24,11 +28,10 @@ def generate_progress_images(iteration, flux, args):
     print(f"Generating {str(out_file)}", flush=True)
 
     # Generate some images and arrange them in a grid
-    n_rows = 2
-    n_images = 4
+    n_rows = 2 if args.progress_num_images % 2 == 0 else 1
     x = flux.generate_images(
         args.progress_prompt,
-        n_images,
+        args.progress_num_images,
         args.progress_steps,
     )
     x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
@@ -42,6 +45,16 @@ def generate_progress_images(iteration, flux, args):
     im = Image.fromarray(np.array(x))
     im.save(out_file)
 
+    # generate_images reloads the text encoders in order to remove them from
+    # RAM. In memory pressured environments this will swap the flow transformer
+    # to disk and back to RAM during generation.
+    #
+    # However, we have to requantize the text encoders for the next time we
+    # want to use them.
+    if args.quantize:
+        nn.quantize(flux.t5, class_predicate=quantization_predicate)
+        nn.quantize(flux.clip, class_predicate=quantization_predicate)
+
 
 def save_adapters(iteration, flux, args):
     out_dir = Path(args.output_dir)
@@ -74,6 +87,17 @@ def setup_arg_parser():
         ],
         help="Which flux model to train",
     )
+    parser.add_argument(
+        "--quantize",
+        "-q",
+        action="store_true",
+        help="Quantize the models to reduce the memory required for training",
+    )
+    parser.add_argument(
+        "--gradient-checkpointing",
+        action="store_true",
+        help="Enable gradient checkpointing to reduce the memory required for training",
+    )
     parser.add_argument(
         "--guidance", type=float, default=4.0, help="The guidance factor to use."
     )
@@ -118,6 +142,12 @@ def setup_arg_parser():
         default=50,
         help="Generate images every PROGRESS_EVERY steps",
     )
+    parser.add_argument(
+        "--progress-num-images",
+        type=int,
+        default=4,
+        help="How many progress images to generate",
+    )
     parser.add_argument(
         "--checkpoint-every",
         type=int,
@@ -162,6 +192,14 @@ if __name__ == "__main__":
     # initial weights.
     mx.random.seed(0x0F0F0F0F)
     flux = FluxPipeline("flux-" + args.model)
+    if args.quantize:
+        nn.quantize(flux.flow, class_predicate=quantization_predicate)
+        nn.quantize(flux.t5, class_predicate=quantization_predicate)
+        nn.quantize(flux.clip, class_predicate=quantization_predicate)
+
+    if args.gradient_checkpointing:
+        flux.gradient_checkpointing()
+
     flux.flow.freeze()
     flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
 
@@ -254,8 +292,12 @@ if __name__ == "__main__":
     guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
 
     # An initial generation to compare
-    generate_progress_images(0, flux, args)
+    # generate_progress_images(0, flux, args)
+    flux.reload_text_encoders()
+    del flux.t5
+    del flux.clip
 
+    mx.metal.reset_peak_memory()
     grads = None
     losses = []
     tic = time.time()
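The mx.metal.reset_peak_memory() call zeroes the peak-memory counter so that any later reporting reflects only the training loop rather than model loading and setup. A hedged sketch of the measurement pattern this enables (the reporting code is illustrative, and it assumes mx.metal.get_peak_memory() is available in the MLX version in use):

```python
import mlx.core as mx

mx.metal.reset_peak_memory()                 # start counting from here
# ... run one or more training steps ...
peak_gb = mx.metal.get_peak_memory() / 1e9   # peak bytes since the reset
print(f"Peak memory during training: {peak_gb:.3f} GB")
```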
@@ -7,6 +7,12 @@ import mlx.nn as nn
 from mlx.utils import tree_unflatten
 from tqdm import tqdm
 
+from .layers import (
+    DoubleStreamBlock,
+    SingleStreamBlock,
+    disable_gradient_checkpointing,
+    enable_gradient_checkpointing,
+)
 from .lora import LoRALinear
 from .sampler import FluxSampler
 from .utils import (
@@ -234,7 +240,7 @@ class FluxPipeline:
         for i, block in zip(range(num_blocks), all_blocks):
             loras = []
             for name, module in block.named_modules():
-                if isinstance(module, nn.Linear):
+                if isinstance(module, (nn.Linear, nn.QuantizedLinear)):
                     loras.append((name, LoRALinear.from_base(module, r=rank)))
             block.update_modules(tree_unflatten(loras))
 
@@ -244,3 +250,13 @@ class FluxPipeline:
             if isinstance(module, LoRALinear):
                 fused_layers.append((name, module.fuse()))
         self.flow.update_modules(tree_unflatten(fused_layers))
+
+    def gradient_checkpointing(self, enable: bool = True):
+        """Replace the call function of SingleStreamBlock and DoubleStreamBlock
+        to a checkpointing one."""
+        if enable:
+            enable_gradient_checkpointing(SingleStreamBlock)
+            enable_gradient_checkpointing(DoubleStreamBlock)
+        else:
+            disable_gradient_checkpointing(SingleStreamBlock)
+            disable_gradient_checkpointing(DoubleStreamBlock)
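Note that this method patches the SingleStreamBlock and DoubleStreamBlock classes rather than this pipeline instance, so enabling checkpointing affects every pipeline in the process until it is disabled again. An illustrative usage sketch (the model name is a placeholder, not taken from this commit):

```python
flux = FluxPipeline("flux-dev")            # placeholder model name
flux.gradient_checkpointing()              # patch the block classes for training
# ... training loop ...
flux.gradient_checkpointing(enable=False)  # restore the original __call__
```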
@@ -9,6 +9,37 @@ import mlx.core as mx
 import mlx.nn as nn
 
 
+def enable_gradient_checkpointing(module_class):
+    if hasattr(module_class, "_original_call"):
+        raise ValueError(
+            f"Gradient checkpointing is already enabled for {module_class.__name__}"
+        )
+
+    fn = module_class.__call__
+    module_class._original_call = fn
+
+    def checkpointed_fn(module_instance, *args, **kwargs):
+        def inner_fn(params, *args, **kwargs):
+            module_instance.update(params)
+            return fn(module_instance, *args, **kwargs)
+
+        return mx.checkpoint(inner_fn)(
+            module_instance.trainable_parameters(), *args, **kwargs
+        )
+
+    module_class.__call__ = checkpointed_fn
+
+
+def disable_gradient_checkpointing(module_class):
+    if not hasattr(module_class, "_original_call"):
+        raise ValueError(
+            f"Gradient checkpointing is not enabled for {module_class.__name__}"
+        )
+
+    module_class.__call__ = module_class._original_call
+    delattr(module_class, "_original_call")
+
+
 def _rope(pos: mx.array, dim: int, theta: float):
     scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
     omega = 1.0 / (theta**scale)
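enable_gradient_checkpointing swaps a block class's __call__ for one that runs the original forward under mx.checkpoint, so intermediate activations are recomputed during the backward pass instead of being kept in memory. A minimal sketch with a toy module (illustrative only; it assumes the enable_gradient_checkpointing defined above is in scope):

```python
import mlx.core as mx
import mlx.nn as nn


class Block(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.l1 = nn.Linear(dims, dims)
        self.l2 = nn.Linear(dims, dims)

    def __call__(self, x):
        return self.l2(nn.relu(self.l1(x)))


enable_gradient_checkpointing(Block)   # every Block now checkpoints its forward

block = Block(16)
x = mx.random.normal((4, 16))


def loss_fn(params, x):
    block.update(params)
    return block(x).sum()


# Gradients match the unpatched version; only peak memory use changes.
loss, grads = mx.value_and_grad(loss_fn)(block.trainable_parameters(), x)
```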
@@ -9,12 +9,15 @@ import mlx.nn as nn
 class LoRALinear(nn.Module):
     @staticmethod
     def from_base(
-        linear: nn.Linear,
+        linear: nn.Module,
         r: int = 8,
         dropout: float = 0.0,
         scale: float = 1.0,
     ):
         output_dims, input_dims = linear.weight.shape
+        if isinstance(linear, nn.QuantizedLinear):
+            input_dims *= 32 // linear.bits
+
         lora_lin = LoRALinear(
             input_dims=input_dims,
             output_dims=output_dims,
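The new branch accounts for the fact that nn.QuantizedLinear stores its weight packed into 32-bit words, so weight.shape[1] is the packed width rather than the real input dimension; multiplying by 32 // bits undoes that packing before the LoRA matrices are sized. A small worked example (the sizes are illustrative, not from this commit):

```python
bits = 4                               # bits per quantized value
packed_out, packed_in = 3072, 128      # shape of a hypothetical quantized weight
values_per_word = 32 // bits           # 8 four-bit values per uint32
input_dims = packed_in * values_per_word
print(input_dims)                      # 1024: the width the LoRA adapter must match
```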
@@ -26,6 +29,9 @@ class LoRALinear(nn.Module):
         return lora_lin
 
     def fuse(self):
+        if isinstance(self.linear, nn.QuantizedLinear):
+            raise NotImplementedError("Cannot fuse QLoRA layers yet.")
+
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight