diff --git a/stable_diffusion/stable_diffusion/__init__.py b/stable_diffusion/stable_diffusion/__init__.py index a4beffb4..cc9dd9a8 100644 --- a/stable_diffusion/stable_diffusion/__init__.py +++ b/stable_diffusion/stable_diffusion/__init__.py @@ -224,6 +224,7 @@ class StableDiffusionXL(StableDiffusion): if n_images > 1: conditioning = mx.repeat(conditioning, n_images, axis=0) + pooled_conditioning = mx.repeat(pooled_conditioning, n_images, axis=0) return conditioning, pooled_conditioning