mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Add an image2image example in the stable diffusion (#198)
This commit is contained in:
committed by
GitHub
parent
09566c7257
commit
37fd2464dc
@@ -67,7 +67,7 @@ class EncoderDecoderBlock2D(nn.Module):
|
||||
# Add an optional downsampling layer
|
||||
if add_downsample:
|
||||
self.downsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||
out_channels, out_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
|
||||
# or upsampling layer
|
||||
@@ -81,6 +81,7 @@ class EncoderDecoderBlock2D(nn.Module):
|
||||
x = resnet(x)
|
||||
|
||||
if "downsample" in self:
|
||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
||||
x = self.downsample(x)
|
||||
|
||||
if "upsample" in self:
|
||||
@@ -253,16 +254,21 @@ class Autoencoder(nn.Module):
|
||||
)
|
||||
|
||||
def decode(self, z):
|
||||
z = z / self.scaling_factor
|
||||
return self.decoder(self.post_quant_proj(z))
|
||||
|
||||
def __call__(self, x, key=None):
|
||||
def encode(self, x):
|
||||
x = self.encoder(x)
|
||||
x = self.quant_proj(x)
|
||||
|
||||
mean, logvar = x.split(2, axis=-1)
|
||||
std = mx.exp(0.5 * logvar)
|
||||
z = mx.random.normal(mean.shape, key=key) * std + mean
|
||||
mean = mean * self.scaling_factor
|
||||
logvar = logvar + 2 * math.log(self.scaling_factor)
|
||||
|
||||
return mean, logvar
|
||||
|
||||
def __call__(self, x, key=None):
|
||||
mean, logvar = self.encode(x)
|
||||
z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
|
||||
x_hat = self.decode(z)
|
||||
|
||||
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
|
||||
|
||||
Reference in New Issue
Block a user