From fe2291710fbc0dd688dcbd25c0b64f3f61e18cfc Mon Sep 17 00:00:00 2001
From: Pawel Kowalski
Date: Tue, 19 Dec 2023 13:21:34 +0100
Subject: [PATCH] Partial implementation of SDXL, incl. CLIP with projection;
does not yet produce good output
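
A minimal usage sketch of the new class (the model id, prompt and settings
below are illustrative; it assumes the checkpoint exposes a second text
encoder/tokenizer pair to the existing load_text_encoder / load_tokenizer
helpers and that the base class provides the decode step shown in the hunk
context):

    import mlx.core as mx
    from stable_diffusion import StableDiffusionXL

    sdxl = StableDiffusionXL(
        "stabilityai/stable-diffusion-xl-base-1.0", float16=True
    )

    latents = None
    for x_t in sdxl.generate_latents(
        "a photo of an astronaut riding a horse on mars",
        n_images=1,
        num_steps=50,
        cfg_weight=7.5,
        latent_size=(128, 128),  # 128x128 latents -> 1024x1024 pixels with the 8x VAE
        seed=42,
    ):
        mx.eval(x_t)   # materialize each denoising step
        latents = x_t  # keep the final latents

    # Decode through the base-class autoencoder helper (see hunk context)
    images = sdxl.decode(latents)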
---
stable_diffusion/stable_diffusion/__init__.py | 77 +++++++++++++++++++
1 file changed, 77 insertions(+)
diff --git a/stable_diffusion/stable_diffusion/__init__.py b/stable_diffusion/stable_diffusion/__init__.py
index 899336e8..aad4636e 100644
--- a/stable_diffusion/stable_diffusion/__init__.py
+++ b/stable_diffusion/stable_diffusion/__init__.py
@@ -95,3 +95,80 @@ class StableDiffusion:
x = self.autoencoder.decode(x_t / self.autoencoder.scaling_factor)
x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
return x
+
+
+class StableDiffusionXL(StableDiffusion):
+ def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
+ super().__init__(model, float16)
+ self.text_encoder_2 = load_text_encoder(model, float16, name="text_encoder_2")
+ self.tokenizer_2 = load_tokenizer(model, name="tokenizer_2")
+
+ def generate_latents(
+ self,
+ text: str,
+ text_2: str = None,
+ n_images: int = 1,
+ num_steps: int = 50,
+ cfg_weight: float = 7.5,
+ negative_text: str = "",
+ negative_text_2: str = "",
+ latent_size: Tuple[int] = (64, 64),
+ seed=None,
+ ):
+ # Set the PRNG state
+ seed = seed or int(time.time())
+ mx.random.seed(seed)
+
+        # Tokenize the text (and the negative text when doing CFG)
+        tokens = [self.tokenizer.tokenize(text)]
+        if cfg_weight > 1:
+            tokens += [self.tokenizer.tokenize(negative_text)]
+
+        # Tokenize text_2 with tokenizer_2 if provided, otherwise tokenize
+        # the original text with tokenizer_2
+        tokens2 = (
+            [self.tokenizer_2.tokenize(text_2)]
+            if text_2
+            else [self.tokenizer_2.tokenize(text)]
+        )
+        if cfg_weight > 1:
+            tokens2 += (
+                [self.tokenizer_2.tokenize(negative_text_2)]
+                if text_2
+                else [self.tokenizer_2.tokenize(negative_text)]
+            )
+
+        # Pad both token sequences (positive and, for CFG, negative) to a
+        # common length so the two encoders' hidden states can be
+        # concatenated along the feature axis below
+        N = max(len(t) for t in tokens + tokens2)
+        tokens = mx.array([t + [0] * (N - len(t)) for t in tokens])
+        tokens2 = mx.array([t + [0] * (N - len(t)) for t in tokens2])
+
+ # Compute the features
+ conditioning1 = self.text_encoder(tokens)
+ conditioning2 = self.text_encoder_2(tokens2)
+ conditioning = mx.concatenate([conditioning1, conditioning2], axis=2)
+ # Repeat the conditioning for each of the generated images
+ if n_images > 1:
+ conditioning = _repeat(conditioning, n_images, axis=0)
+
+ # Create the latent variables
+ x_T = self.sampler.sample_prior(
+ (n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
+ )
+
+ # Perform the denoising loop
+ x_t = x_T
+ for t, t_prev in self.sampler.timesteps(num_steps, dtype=self.dtype):
+ x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
+ t_unet = mx.broadcast_to(t, [len(x_t_unet)])
+ eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
+
+ if cfg_weight > 1:
+ eps_text, eps_neg = eps_pred.split(2)
+ eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
+
+ x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
+ x_t = x_t_prev
+ yield x_t