Stable Diffusion: Input image downsampling (#276)

This commit is contained in:
LeonEricsson
2024-01-16 22:45:00 +01:00
committed by GitHub
parent 2ba5d3db14
commit b4c20cc7f7
2 changed files with 11 additions and 2 deletions

View File

@@ -28,7 +28,15 @@ if __name__ == "__main__":
sd = StableDiffusion()
# Read the image
img = mx.array(np.array(Image.open(args.image)))
img = Image.open(args.image)
# Make sure image shape is divisible by 64
W, H = (dim - dim % 64 for dim in (img.width, img.height))
if W != img.width or H != img.height:
print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
img = mx.array(np.array(img))
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
# Noise and denoise the latents produced by encoding img.