From fe2291710fbc0dd688dcbd25c0b64f3f61e18cfc Mon Sep 17 00:00:00 2001
From: Pawel Kowalski
Date: Tue, 19 Dec 2023 13:21:34 +0100
Subject: [PATCH] partial implementation of SD XL, incl. CLIP with projection, but doesn't produce good output

---
 stable_diffusion/stable_diffusion/__init__.py | 77 +++++++++++++++++++
 1 file changed, 77 insertions(+)

diff --git a/stable_diffusion/stable_diffusion/__init__.py b/stable_diffusion/stable_diffusion/__init__.py
index 899336e8..aad4636e 100644
--- a/stable_diffusion/stable_diffusion/__init__.py
+++ b/stable_diffusion/stable_diffusion/__init__.py
@@ -95,3 +95,80 @@ class StableDiffusion:
         x = self.autoencoder.decode(x_t / self.autoencoder.scaling_factor)
         x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
         return x
+
+
+class StableDiffusionXL(StableDiffusion):
+    def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
+        super().__init__(model, float16)
+        self.text_encoder_2 = load_text_encoder(model, float16, name="text_encoder_2")
+        self.tokenizer_2 = load_tokenizer(model, name="tokenizer_2")
+
+    def generate_latents(
+        self,
+        text: str,
+        text_2: str = None,
+        n_images: int = 1,
+        num_steps: int = 50,
+        cfg_weight: float = 7.5,
+        negative_text: str = "",
+        negative_text_2: str = "",
+        latent_size: Tuple[int] = (64, 64),
+        seed=None,
+    ):
+        # Set the PRNG state
+        seed = seed or int(time.time())
+        mx.random.seed(seed)
+
+        # Tokenize the text
+        tokens = [self.tokenizer.tokenize(text)]
+        if cfg_weight > 1:
+            tokens += [self.tokenizer.tokenize(negative_text)]
+        lengths = [len(t) for t in tokens]
+        N = max(lengths)
+        tokens = [t + [0] * (N - len(t)) for t in tokens]
+        tokens = mx.array(tokens)
+
+        # Tokenize text_2 if provided, otherwise tokenize the original text with tokenizer_2
+        tokens2 = (
+            [self.tokenizer_2.tokenize(text_2)]
+            if text_2
+            else [self.tokenizer_2.tokenize(text)]
+        )
+        if cfg_weight > 1:
+            tokens2 += (
+                [self.tokenizer_2.tokenize(negative_text_2)]
+                if text_2
+                else [self.tokenizer_2.tokenize(negative_text)]
+            )
+        lengths2 = [len(t) for t in tokens2]
+        N = max(lengths2)
+        tokens2 = [t + [0] * (N - len(t)) for t in tokens2]
+        tokens2 = mx.array(tokens2)
+
+        # Compute the features from both text encoders and concatenate along the feature axis
+        conditioning1 = self.text_encoder(tokens)
+        conditioning2 = self.text_encoder_2(tokens2)
+        conditioning = mx.concatenate([conditioning1, conditioning2], axis=2)
+        # Repeat the conditioning for each of the generated images
+        if n_images > 1:
+            conditioning = _repeat(conditioning, n_images, axis=0)
+
+        # Create the latent variables
+        x_T = self.sampler.sample_prior(
+            (n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
+        )
+
+        # Perform the denoising loop
+        x_t = x_T
+        for t, t_prev in self.sampler.timesteps(num_steps, dtype=self.dtype):
+            x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
+            t_unet = mx.broadcast_to(t, [len(x_t_unet)])
+            eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
+
+            if cfg_weight > 1:
+                eps_text, eps_neg = eps_pred.split(2)
+                eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
+
+            x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
+            x_t = x_t_prev
+            yield x_t
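
Not part of the patch itself: below is a minimal usage sketch for the new StableDiffusionXL class, assuming the package is importable as stable_diffusion, that the SDXL weights are loadable from the Hugging Face repo id shown, and that the decode() method inherited from StableDiffusion is used for the VAE step. Treat it as illustrative only; as the commit subject notes, the current output is not yet good.

# Hypothetical usage sketch (not part of the patch).
# Assumptions: model id, import path, and the inherited decode() helper.
import mlx.core as mx
from stable_diffusion import StableDiffusionXL

sdxl = StableDiffusionXL("stabilityai/stable-diffusion-xl-base-1.0", float16=True)

latents = None
for x_t in sdxl.generate_latents(
    "a photo of an astronaut riding a horse",
    n_images=1,
    num_steps=50,
    cfg_weight=7.5,
    latent_size=(128, 128),  # SDXL targets 1024x1024 images, i.e. 128x128 latents
    seed=42,
):
    mx.eval(x_t)  # force evaluation of each denoising step
    latents = x_t

# Decode the final latents through the VAE (inherited from StableDiffusion)
image = sdxl.decode(latents)
mx.eval(image)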