mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Stable Diffusion: Input image downsampling (#276)
This commit is contained in:
@@ -28,7 +28,15 @@ if __name__ == "__main__":
|
||||
sd = StableDiffusion()
|
||||
|
||||
# Read the image
|
||||
img = mx.array(np.array(Image.open(args.image)))
|
||||
img = Image.open(args.image)
|
||||
|
||||
# Make sure image shape is divisible by 64
|
||||
W, H = (dim - dim % 64 for dim in (img.width, img.height))
|
||||
if W != img.width or H != img.height:
|
||||
print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
|
||||
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
|
||||
|
||||
img = mx.array(np.array(img))
|
||||
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
|
||||
|
||||
# Noise and denoise the latents produced by encoding img.
|
||||
|
||||
Reference in New Issue
Block a user