From eff6690952847386aa3cc375b4ac83decc886868 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 9 Apr 2024 06:06:41 -0700 Subject: [PATCH] Fix CFG for SDXL (#667) --- stable_diffusion/stable_diffusion/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stable_diffusion/stable_diffusion/__init__.py b/stable_diffusion/stable_diffusion/__init__.py index a4beffb4..cc9dd9a8 100644 --- a/stable_diffusion/stable_diffusion/__init__.py +++ b/stable_diffusion/stable_diffusion/__init__.py @@ -224,6 +224,7 @@ class StableDiffusionXL(StableDiffusion): if n_images > 1: conditioning = mx.repeat(conditioning, n_images, axis=0) + pooled_conditioning = mx.repeat(pooled_conditioning, n_images, axis=0) return conditioning, pooled_conditioning