Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-31 20:04:38 +08:00.
Commit: partial implementation of SD XL, including CLIP with text projection; does not yet produce good output.
This commit is contained in:
@@ -95,3 +95,80 @@ class StableDiffusion:
|
||||
x = self.autoencoder.decode(x_t / self.autoencoder.scaling_factor)
|
||||
x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
|
||||
return x
|
||||
|
||||
|
||||
class StableDiffusionXL(StableDiffusion):
    """Stable Diffusion XL pipeline.

    Extends the base ``StableDiffusion`` pipeline with SD XL's second
    text encoder/tokenizer pair. The sequence features of both encoders
    are concatenated along the feature (last) axis and used as the UNet
    cross-attention conditioning.

    NOTE(review): full SD XL also conditions the UNet on pooled text
    embeddings and micro-conditioning (time_ids); this partial
    implementation omits them — presumably why output quality is poor.
    """

    def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
        super().__init__(model, float16)
        # SD XL uses two text encoders; load the second pair here.
        self.text_encoder_2 = load_text_encoder(model, float16, name="text_encoder_2")
        self.tokenizer_2 = load_tokenizer(model, name="tokenizer_2")

    @staticmethod
    def _tokenize(tokenizer, text: str, negative_text=None):
        """Tokenize ``text`` (and ``negative_text`` when given), zero-pad
        all sequences to the longest one, and return an ``mx.array`` of
        shape (1 or 2, N)."""
        tokens = [tokenizer.tokenize(text)]
        if negative_text is not None:
            tokens.append(tokenizer.tokenize(negative_text))
        longest = max(len(t) for t in tokens)
        tokens = [t + [0] * (longest - len(t)) for t in tokens]
        return mx.array(tokens)

    def generate_latents(
        self,
        text: str,
        text_2: str = None,
        n_images: int = 1,
        num_steps: int = 50,
        cfg_weight: float = 7.5,
        negative_text: str = "",
        negative_text_2: str = "",
        latent_size: Tuple[int] = (64, 64),
        seed=None,
    ):
        """Generate image latents, yielding the latent tensor after every
        denoising step.

        Args:
            text: Prompt for the first text encoder.
            text_2: Prompt for the second text encoder; falls back to
                ``text`` (and ``negative_text``) when not provided.
            n_images: Number of images to generate in parallel.
            num_steps: Number of sampler timesteps.
            cfg_weight: Classifier-free guidance scale; guidance is only
                applied when > 1.
            negative_text / negative_text_2: Negative prompts for the two
                encoders.
            latent_size: Spatial size of the latents (H, W).
            seed: PRNG seed; current time when None.
        """
        # Set the PRNG state. Use an explicit None check (not `or`) so
        # that a caller-supplied seed of 0 is honored.
        seed = int(time.time()) if seed is None else seed
        mx.random.seed(seed)

        # Tokenize for the first encoder; the negative prompt is only
        # needed when classifier-free guidance is active.
        use_cfg = cfg_weight > 1
        tokens = self._tokenize(
            self.tokenizer, text, negative_text if use_cfg else None
        )

        # Tokenize for the second encoder, falling back to the primary
        # prompts when the *_2 variants are not provided.
        # BUGFIX: the second encoder's negative prompt must be tokenized
        # with tokenizer_2 as well — the original used self.tokenizer,
        # feeding tokens from the wrong vocabulary to text_encoder_2.
        negative_2 = negative_text_2 if text_2 else negative_text
        tokens_2 = self._tokenize(
            self.tokenizer_2,
            text_2 if text_2 else text,
            negative_2 if use_cfg else None,
        )

        # Compute the features of both encoders and concatenate them
        # along the feature (last) axis.
        conditioning = mx.concatenate(
            [self.text_encoder(tokens), self.text_encoder_2(tokens_2)], axis=2
        )

        # Repeat the conditioning for each of the generated images.
        if n_images > 1:
            conditioning = _repeat(conditioning, n_images, axis=0)

        # Create the initial latent variables from the sampler's prior.
        x_t = self.sampler.sample_prior(
            (n_images, *latent_size, self.autoencoder.latent_channels),
            dtype=self.dtype,
        )

        # Perform the denoising loop.
        for t, t_prev in self.sampler.timesteps(num_steps, dtype=self.dtype):
            # With CFG, duplicate the batch so the conditional and
            # unconditional predictions share a single UNet call.
            x_t_unet = mx.concatenate([x_t] * 2, axis=0) if use_cfg else x_t
            t_unet = mx.broadcast_to(t, [len(x_t_unet)])
            eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)

            if use_cfg:
                eps_text, eps_neg = eps_pred.split(2)
                eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)

            x_t = self.sampler.step(eps_pred, x_t, t, t_prev)
            yield x_t
Reference in New Issue
Block a user