Fix Qwen2 and SD (#441)

* fix qwen2

* version bump

* fix list shape
This commit is contained in:
Awni Hannun
2024-02-14 13:43:12 -08:00
committed by GitHub
parent e446598f62
commit 06ddb8414d
3 changed files with 2 additions and 3 deletions

View File

@@ -130,7 +130,7 @@ class StableDiffusion:
# Get the latents from the input image and add noise according to the
# start time.
x_0, _ = self.autoencoder.encode(image[None])
x_0 = mx.broadcast_to(x_0, [n_images] + x_0.shape[1:])
x_0 = mx.broadcast_to(x_0, (n_images,) + x_0.shape[1:])
x_T = self.sampler.add_noise(x_0, mx.array(start_step))
# Perform the denoising loop