diff --git a/stable_diffusion/dreambooth.py b/stable_diffusion/dreambooth.py
deleted file mode 100644
index 7be12aa1..00000000
--- a/stable_diffusion/dreambooth.py
+++ /dev/null
@@ -1,317 +0,0 @@
-# Copyright © 2024 Apple Inc.
-
-import argparse
-import time
-from functools import partial
-from pathlib import Path
-
-import mlx.core as mx
-import mlx.optimizers as optim
-import numpy as np
-from PIL import Image
-from tqdm import tqdm
-
-from stable_diffusion import StableDiffusion
-
-
-def extract_latent_vectors(sd, image_folder):
-    latents = []
-    for image in tqdm(Path(image_folder).iterdir()):
-        img = Image.open(image)
-        img = mx.array(np.array(img))
-        img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
-        x_0, _ = sd.autoencoder.encode(img[None])
-        mx.eval(x_0)
-        latents.append(x_0)
-    return mx.concatenate(latents)
-
-
-def generate_latents(sd, n_images, prompt, steps, cfg_weight, seed=None, leave=True):
-    latents = sd.generate_latents(
-        prompt,
-        n_images=n_images,
-        cfg_weight=cfg_weight,
-        num_steps=steps,
-        seed=seed,
-        negative_text="",
-    )
-    for x_t in tqdm(latents, total=args.progress_steps, leave=leave):
-        mx.eval(x_t)
-
-    return x_t
-
-
-def decode_latents(sd, x):
-    decoded = []
-    for i in tqdm(range(len(x))):
-        decoded.append(sd.decode(x[i : i + 1]))
-        mx.eval(decoded[-1])
-    return mx.concatenate(decoded, axis=0)
-
-
-def generate_progress_images(iteration, sd, args):
-    out_dir = Path(args.output_dir)
-    out_dir.mkdir(parents=True, exist_ok=True)
-    out_file = out_dir / f"out_{iteration:03d}.png"
-    print(f"Generating {str(out_file)}")
-    # Generate the latent vectors using diffusion
-    n_images = 4
-    latents = generate_latents(
-        sd,
-        n_images,
-        args.progress_prompt,
-        args.progress_steps,
-        args.progress_cfg,
-        seed=42,
-    )
-
-    # Arrange them on a grid
-    n_rows = 2
-    x = decode_latents(sd, latents)
-    x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
-    B, H, W, C = x.shape
-    x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
-    x = x.reshape(n_rows * H, B // n_rows * W, C)
-    x = (x * 255).astype(mx.uint8)
-
-    # Save them to disc
-    im = Image.fromarray(np.array(x))
-    im.save(out_file)
-
-
-def save_checkpoints(iteration, sd, args):
-    out_dir = Path(args.output_dir)
-    out_dir.mkdir(parents=True, exist_ok=True)
-    unet_file = str(out_dir / f"unet_{iteration:03d}.safetensors")
-    print(f"Saving {unet_file}")
-    sd.unet.save_weights(unet_file)
-    if args.train_text_encoder:
-        te_file = str(out_dir / f"text_encoder_{iteration:03d}.safetensors")
-        print(f"Saving {te_file}")
-        sd.text_encoder.save_weights(te_file)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Finetune SD to generate images with a specific subject"
-    )
-
-    parser.add_argument(
-        "--model",
-        default="CompVis/stable-diffusion-v1-4",
-        choices=[
-            "stabilityai/stable-diffusion-2-1-base",
-            "CompVis/stable-diffusion-v1-4",
-        ],
-        help="Which stable diffusion model to train",
-    )
-    parser.add_argument(
-        "--train-text-encoder",
-        action="store_true",
-        help="Train the text encoder as well as the UNet",
-    )
-    parser.add_argument(
-        "--iterations",
-        type=int,
-        default=400,
-        help="How many iterations to train for",
-    )
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=4,
-        help="The batch size to use when training the stable diffusion model",
-    )
-    parser.add_argument(
-        "--progress-prompt",
-        help="Use this prompt when generating images for evaluation",
-    )
-    parser.add_argument(
-        "--progress-cfg",
-        type=float,
-        default=7.5,
-        help="Use this classifier free guidance weight when generating images for evaluation",
-    )
-    parser.add_argument(
-        "--progress-steps",
-        type=int,
-        default=50,
-        help="Use this many steps when generating images for evaluation",
-    )
-    parser.add_argument(
-        "--progress-every",
-        type=int,
-        default=50,
-        help="Generate images every PROGRESS_EVERY steps",
-    )
-    parser.add_argument(
-        "--checkpoint-every",
-        type=int,
-        default=50,
-        help="Save the model every CHECKPOINT_EVERY steps",
-    )
-    parser.add_argument(
-        "--learning-rate", type=float, default="1e-6", help="Learning rate for training"
-    )
-    parser.add_argument(
-        "--predict-x0",
-        action="store_false",
-        dest="predict_noise",
-        help="Compute the loss on x0 instead of the noise",
-    )
-    parser.add_argument(
-        "--prior-preservation-weight",
-        type=float,
-        default=0,
-        help="The loss weight for the prior preservation batches",
-    )
-    parser.add_argument(
-        "--prior-preservation-images",
-        type=int,
-        default=100,
-        help="How many prior preservation images to use",
-    )
-    parser.add_argument(
-        "--prior-preservation-prompt", help="The prompt to use for prior preservation"
-    )
-    parser.add_argument(
-        "--prior-preservation-steps",
-        default=50,
-        type=int,
-        help="How many steps to use to generate prior images",
-    )
-    parser.add_argument(
-        "--prior-preservation-cfg",
-        default=7.5,
-        type=float,
-        help="The CFG weight to use to generate prior images",
-    )
-    parser.add_argument(
-        "--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
-    )
-
-    parser.add_argument("prompt")
-    parser.add_argument("image_folder")
-
-    args = parser.parse_args()
-
-    args.progress_prompt = args.progress_prompt or args.prompt
-    args.prior_preservation_prompt = args.prior_preservation_prompt or args.prompt
-
-    sd = StableDiffusion(args.model)
-    sd.ensure_models_are_loaded()
-
-    optimizer = optim.Adam(learning_rate=args.learning_rate)
-
-    def loss_fn(params, text, x, weights):
-        sd.unet.update(params["unet"])
-        if "text_encoder" in params:
-            sd.text_encoder.update(params["text_encoder"])
-        loss = sd.training_loss(text, x, pred_noise=args.predict_noise)
-        loss = loss * weights
-        return loss.mean()
-
-    state = [sd.unet.state, optimizer.state, mx.random.state]
-    if args.train_text_encoder:
-        state.append(sd.text_encoder.state)
-
-    @partial(mx.compile, inputs=state, outputs=state)
-    def step(text, x, prior_text=None, prior_x=None, prior_weight=None):
-        # Get the parameters we are calculating gradients for
-        params = {"unet": sd.unet.trainable_parameters()}
-        if args.train_text_encoder:
-            params["text_encoder"] = sd.text_encoder.trainable_parameters()
-
-        # Combine the prior preservation if needed
-        if prior_weight is None:
-            weights = mx.ones(len(x))
-        else:
-            weights = mx.array([1] * len(x) + [prior_weight] * len(prior_x))
-            x = mx.concatenate([x, prior_x])
-            text = mx.concatenate([text, prior_text])
-
-        # Calculate the loss and new parameters
-        loss, grads = mx.value_and_grad(loss_fn)(params, text, x, weights)
-        params = optimizer.apply_gradients(grads, params)
-
-        # Update the models
-        sd.unet.update(params["unet"])
-        if "text_encoder" in params:
-            sd.text_encoder.update(params["text_encoder"])
-
-        return loss
-
-    print("Encoding training images to latent space")
-    x = extract_latent_vectors(sd, args.image_folder)
-    text = sd._tokenize(sd.tokenizer, args.prompt, None)
-    text = mx.repeat(text, len(x), axis=0)
-    prior_x = None
-    prior_text = None
-
-    if args.prior_preservation_weight > 0:
-        print("Generating prior images")
-        batch_size = 4
-        prior_x = mx.zeros(
-            (
-                batch_size
-                * (args.prior_preservation_images + batch_size - 1)
-                // batch_size,
-                *x.shape[1:],
-            ),
-            dtype=x.dtype,
-        )
-        prior_text = sd._tokenize(sd.tokenizer, args.prior_preservation_prompt, None)
-        prior_text = mx.repeat(prior_text, len(prior_x), axis=0)
-        for i in tqdm(range(0, args.prior_preservation_images, batch_size)):
-            prior_batch = generate_latents(
-                sd,
-                batch_size,
-                args.prior_preservation_prompt,
-                args.prior_preservation_steps,
-                args.prior_preservation_cfg,
-                leave=False,
-            )
-            prior_x[i : i + batch_size] = prior_batch
-            mx.async_eval(prior_x)
-
-    # An initial generation to compare
-    generate_progress_images(0, sd, args)
-
-    losses = []
-    tic = time.time()
-    for i in range(args.iterations):
-        indices = (mx.random.uniform(shape=(args.batch_size,)) * len(x)).astype(
-            mx.uint32
-        )
-        if args.prior_preservation_weight > 0.0:
-            prior_indices = (
-                mx.random.uniform(shape=(args.batch_size,)) * len(prior_x)
-            ).astype(mx.uint32)
-            loss = step(
-                text[indices],
-                x[indices],
-                prior_text[prior_indices],
-                prior_x[prior_indices],
-                args.prior_preservation_weight,
-            )
-        else:
-            loss = step(text[indices], x[indices])
-        mx.eval(loss, state)
-        losses.append(loss.item())
-
-        if (i + 1) % 10 == 0:
-            toc = time.time()
-            print(
-                f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} "
-                f"It/s: {10 / (toc - tic):.3f}"
-            )
-
-        if (i + 1) % args.progress_every == 0:
-            generate_progress_images(i + 1, sd, args)
-
-        if (i + 1) % args.checkpoint_every == 0:
-            save_checkpoints(i + 1, sd, args)
-
-        if (i + 1) % 10 == 0:
-            losses = []
-            tic = time.time()
diff --git a/stable_diffusion/stable_diffusion/__init__.py b/stable_diffusion/stable_diffusion/__init__.py
index c28f275f..cc9dd9a8 100644
--- a/stable_diffusion/stable_diffusion/__init__.py
+++ b/stable_diffusion/stable_diffusion/__init__.py
@@ -168,25 +168,6 @@ class StableDiffusion:
         x = mx.clip(x / 2 + 0.5, 0, 1)
         return x
 
-    def training_loss(self, text: mx.array, x_0: mx.array, pred_noise: bool = True):
-        # Get the text conditioning
-        conditioning = self.text_encoder(text).last_hidden_state
-
-        # Get the samples to be denoised
-        t = mx.random.uniform(shape=(len(x_0),)) * self.sampler.max_time
-        eps = mx.random.normal(x_0.shape)
-        x_t = self.sampler.add_noise(x_0, t, noise=eps)
-        x_t = mx.stop_gradient(x_t)
-
-        # Do the denoising
-        eps_pred = self.unet(x_t, t, encoder_x=conditioning, text_time=None)
-
-        if pred_noise:
-            return (eps_pred - eps).square().mean((1, 2, 3))
-        else:
-            x_0_pred = self.sampler.step(eps_pred, x_t, t, mx.zeros_like(t))
-            return (x_0_pred - x_0).square().mean((1, 2, 3))
-
 
 class StableDiffusionXL(StableDiffusion):
     def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
diff --git a/stable_diffusion/stable_diffusion/model_io.py b/stable_diffusion/stable_diffusion/model_io.py
index 2389cecf..2c2227db 100644
--- a/stable_diffusion/stable_diffusion/model_io.py
+++ b/stable_diffusion/stable_diffusion/model_io.py
@@ -43,17 +43,6 @@ _MODELS = {
         "tokenizer_vocab": "tokenizer/vocab.json",
         "tokenizer_merges": "tokenizer/merges.txt",
     },
-    "CompVis/stable-diffusion-v1-4": {
-        "unet_config": "unet/config.json",
-        "unet": "unet/diffusion_pytorch_model.safetensors",
-        "text_encoder_config": "text_encoder/config.json",
-        "text_encoder": "text_encoder/model.safetensors",
-        "vae_config": "vae/config.json",
-        "vae": "vae/diffusion_pytorch_model.safetensors",
-        "diffusion_config": "scheduler/scheduler_config.json",
-        "tokenizer_vocab": "tokenizer/vocab.json",
-        "tokenizer_merges": "tokenizer/merges.txt",
-    },
 }
 
 
@@ -292,7 +281,7 @@ def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
             latent_channels_in=config["latent_channels"],
             block_out_channels=config["block_out_channels"],
             layers_per_block=config["layers_per_block"],
-            norm_num_groups=config.get("norm_num_groups", 32),
+            norm_num_groups=config["norm_num_groups"],
             scaling_factor=config.get("scaling_factor", 0.18215),
         )
     )
diff --git a/stable_diffusion/stable_diffusion/sampler.py b/stable_diffusion/stable_diffusion/sampler.py
index 265b8748..ff4433d0 100644
--- a/stable_diffusion/stable_diffusion/sampler.py
+++ b/stable_diffusion/stable_diffusion/sampler.py
@@ -59,17 +59,13 @@ class SimpleEulerSampler:
             noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
         ).astype(dtype)
 
-    def add_noise(self, x, t, noise=None, key=None):
-        noise = noise if noise is not None else mx.random.normal(x.shape, key=key)
+    def add_noise(self, x, t, key=None):
+        noise = mx.random.normal(x.shape, key=key)
         s = self.sigmas(t)
         return (x + noise * s) * (s.square() + 1).rsqrt()
 
     def sigmas(self, t):
-        s = _interp(self._sigmas, t)
-        if t.ndim == 0:
-            return s
-        else:
-            return s[:, None, None, None]
+        return _interp(self._sigmas, t)
 
     def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
         start_time = start_time or (len(self._sigmas) - 1)
diff --git a/stable_diffusion/txt2image.py b/stable_diffusion/txt2image.py
index e2c88a3d..26c757f8 100644
--- a/stable_diffusion/txt2image.py
+++ b/stable_diffusion/txt2image.py
@@ -15,7 +15,7 @@ if __name__ == "__main__":
         description="Generate images from a textual prompt using stable diffusion"
     )
     parser.add_argument("prompt")
-    parser.add_argument("--model", choices=["sd", "sd-compvis", "sdxl"], default="sdxl")
+    parser.add_argument("--model", choices=["sd", "sdxl"], default="sdxl")
    parser.add_argument("--n_images", type=int, default=4)
     parser.add_argument("--steps", type=int)
     parser.add_argument("--cfg", type=float)
@@ -28,8 +28,6 @@ if __name__ == "__main__":
     parser.add_argument("--output", default="out.png")
     parser.add_argument("--seed", type=int)
     parser.add_argument("--verbose", "-v", action="store_true")
-    parser.add_argument("--unet-path")
-    parser.add_argument("--text-encoder-path")
     args = parser.parse_args()
 
     # Load the models
@@ -46,23 +44,14 @@ if __name__ == "__main__":
         args.cfg = args.cfg or 0.0
         args.steps = args.steps or 2
     else:
-        model_path = "stabilityai/stable-diffusion-2-1-base"
-        if args.model == "sd-compvis":
-            model_path = "CompVis/stable-diffusion-v1-4"
-        sd = StableDiffusion(model_path, float16=args.float16)
-
-        # Load the custom unet and text encoder if requested
-        if args.unet_path:
-            sd.unet.load_weights(args.unet_path)
-        if args.text_encoder_path:
-            sd.text_encoder.load_weights(args.text_encoder_path)
-
+        sd = StableDiffusion(
+            "stabilityai/stable-diffusion-2-1-base", float16=args.float16
+        )
         if args.quantize:
             nn.quantize(
                 sd.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
             )
             nn.quantize(sd.unet, group_size=32, bits=8)
-
         args.cfg = args.cfg or 7.5
         args.steps = args.steps or 50