Add an image2image example in the stable diffusion (#198)

This commit is contained in:
Angelos Katharopoulos
2023-12-28 18:31:45 -08:00
committed by GitHub
parent 09566c7257
commit 37fd2464dc
6 changed files with 177 additions and 27 deletions

View File

@@ -67,7 +67,7 @@ class EncoderDecoderBlock2D(nn.Module):
# Add an optional downsampling layer
if add_downsample:
self.downsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=1
out_channels, out_channels, kernel_size=3, stride=2, padding=0
)
# or upsampling layer
@@ -81,6 +81,7 @@ class EncoderDecoderBlock2D(nn.Module):
x = resnet(x)
if "downsample" in self:
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
x = self.downsample(x)
if "upsample" in self:
@@ -253,16 +254,21 @@ class Autoencoder(nn.Module):
)
def decode(self, z):
z = z / self.scaling_factor
return self.decoder(self.post_quant_proj(z))
def __call__(self, x, key=None):
def encode(self, x):
x = self.encoder(x)
x = self.quant_proj(x)
mean, logvar = x.split(2, axis=-1)
std = mx.exp(0.5 * logvar)
z = mx.random.normal(mean.shape, key=key) * std + mean
mean = mean * self.scaling_factor
logvar = logvar + 2 * math.log(self.scaling_factor)
return mean, logvar
def __call__(self, x, key=None):
mean, logvar = self.encode(x)
z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
x_hat = self.decode(z)
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)