Fix CFG for SDXL (#667)

This commit is contained in:
Angelos Katharopoulos 2024-04-09 06:06:41 -07:00 committed by GitHub
parent 1278994b56
commit eff6690952
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -224,6 +224,7 @@ class StableDiffusionXL(StableDiffusion):
if n_images > 1: if n_images > 1:
conditioning = mx.repeat(conditioning, n_images, axis=0) conditioning = mx.repeat(conditioning, n_images, axis=0)
pooled_conditioning = mx.repeat(pooled_conditioning, n_images, axis=0)
return conditioning, pooled_conditioning return conditioning, pooled_conditioning