mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00

* update a few examples to use compile * update mnist * add compile to vae and rename some stuff for simplicity * update reqs * use state in eval * GCN example with RNG + dropout * add a bit of prefetching
173 lines
5.8 KiB
Python
173 lines
5.8 KiB
Python
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
import math
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
|
|
|
|
# from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
|
|
def upsample_nearest(x, scale: int = 2):
|
|
B, H, W, C = x.shape
|
|
x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
|
|
x = x.reshape(B, H * scale, W * scale, C)
|
|
return x
|
|
|
|
|
|
class UpsamplingConv2d(nn.Module):
|
|
"""
|
|
A convolutional layer that upsamples the input by a factor of 2. MLX does
|
|
not yet support transposed convolutions, so we approximate them with
|
|
nearest neighbor upsampling followed by a convolution. This is similar to
|
|
the approach used in the original U-Net.
|
|
"""
|
|
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
|
|
super().__init__()
|
|
self.conv = nn.Conv2d(
|
|
in_channels, out_channels, kernel_size, stride=stride, padding=padding
|
|
)
|
|
|
|
def __call__(self, x):
|
|
x = self.conv(upsample_nearest(x))
|
|
return x
|
|
|
|
|
|
class Encoder(nn.Module):
|
|
"""
|
|
A convolutional variational encoder.
|
|
Maps the input to a normal distribution in latent space and sample a latent
|
|
vector from that distribution.
|
|
"""
|
|
|
|
def __init__(self, num_latent_dims, image_shape, max_num_filters):
|
|
super().__init__()
|
|
|
|
# number of filters in the convolutional layers
|
|
num_filters_1 = max_num_filters // 4
|
|
num_filters_2 = max_num_filters // 2
|
|
num_filters_3 = max_num_filters
|
|
|
|
# Output (BHWC): B x 32 x 32 x num_filters_1
|
|
self.conv1 = nn.Conv2d(image_shape[-1], num_filters_1, 3, stride=2, padding=1)
|
|
# Output (BHWC): B x 16 x 16 x num_filters_2
|
|
self.conv2 = nn.Conv2d(num_filters_1, num_filters_2, 3, stride=2, padding=1)
|
|
# Output (BHWC): B x 8 x 8 x num_filters_3
|
|
self.conv3 = nn.Conv2d(num_filters_2, num_filters_3, 3, stride=2, padding=1)
|
|
|
|
# Batch Normalization
|
|
self.bn1 = nn.BatchNorm(num_filters_1)
|
|
self.bn2 = nn.BatchNorm(num_filters_2)
|
|
self.bn3 = nn.BatchNorm(num_filters_3)
|
|
|
|
# Divide the spatial dimensions by 8 because of the 3 strided convolutions
|
|
output_shape = [num_filters_3] + [
|
|
dimension // 8 for dimension in image_shape[:-1]
|
|
]
|
|
|
|
flattened_dim = math.prod(output_shape)
|
|
|
|
# Linear mappings to mean and standard deviation
|
|
self.proj_mu = nn.Linear(flattened_dim, num_latent_dims)
|
|
self.proj_log_var = nn.Linear(flattened_dim, num_latent_dims)
|
|
|
|
def __call__(self, x):
|
|
x = nn.leaky_relu(self.bn1(self.conv1(x)))
|
|
x = nn.leaky_relu(self.bn2(self.conv2(x)))
|
|
x = nn.leaky_relu(self.bn3(self.conv3(x)))
|
|
x = mx.flatten(x, 1) # flatten all dimensions except batch
|
|
|
|
mu = self.proj_mu(x)
|
|
logvar = self.proj_log_var(x)
|
|
# Ensure this is the std deviation, not variance
|
|
sigma = mx.exp(logvar * 0.5)
|
|
|
|
# Generate a tensor of random values from a normal distribution
|
|
eps = mx.random.normal(sigma.shape)
|
|
|
|
# Reparametrization trick to brackpropagate through sampling.
|
|
z = eps * sigma + mu
|
|
|
|
return z, mu, logvar
|
|
|
|
|
|
class Decoder(nn.Module):
|
|
"""A convolutional decoder"""
|
|
|
|
def __init__(self, num_latent_dims, image_shape, max_num_filters):
|
|
super().__init__()
|
|
self.num_latent_dims = num_latent_dims
|
|
num_img_channels = image_shape[-1]
|
|
self.max_num_filters = max_num_filters
|
|
|
|
# decoder layers
|
|
num_filters_1 = max_num_filters
|
|
num_filters_2 = max_num_filters // 2
|
|
num_filters_3 = max_num_filters // 4
|
|
|
|
# divide the last two dimensions by 8 because of the 3 upsampling convolutions
|
|
self.input_shape = [dimension // 8 for dimension in image_shape[:-1]] + [
|
|
num_filters_1
|
|
]
|
|
flattened_dim = math.prod(self.input_shape)
|
|
|
|
# Output: flattened_dim
|
|
self.lin1 = nn.Linear(num_latent_dims, flattened_dim)
|
|
# Output (BHWC): B x 16 x 16 x num_filters_2
|
|
self.upconv1 = UpsamplingConv2d(
|
|
num_filters_1, num_filters_2, 3, stride=1, padding=1
|
|
)
|
|
# Output (BHWC): B x 32 x 32 x num_filters_1
|
|
self.upconv2 = UpsamplingConv2d(
|
|
num_filters_2, num_filters_3, 3, stride=1, padding=1
|
|
)
|
|
# Output (BHWC): B x 64 x 64 x #img_channels
|
|
self.upconv3 = UpsamplingConv2d(
|
|
num_filters_3, num_img_channels, 3, stride=1, padding=1
|
|
)
|
|
|
|
# Batch Normalizations
|
|
self.bn1 = nn.BatchNorm(num_filters_2)
|
|
self.bn2 = nn.BatchNorm(num_filters_3)
|
|
|
|
def __call__(self, z):
|
|
x = self.lin1(z)
|
|
|
|
# reshape to BHWC
|
|
x = x.reshape(
|
|
-1, self.input_shape[0], self.input_shape[1], self.max_num_filters
|
|
)
|
|
|
|
# approximate transposed convolutions with nearest neighbor upsampling
|
|
x = nn.leaky_relu(self.bn1(self.upconv1(x)))
|
|
x = nn.leaky_relu(self.bn2(self.upconv2(x)))
|
|
# sigmoid to ensure pixel values are in [0,1]
|
|
x = mx.sigmoid(self.upconv3(x))
|
|
return x
|
|
|
|
|
|
class CVAE(nn.Module):
|
|
"""
|
|
A convolutional variational autoencoder consisting of an encoder and a
|
|
decoder.
|
|
"""
|
|
|
|
def __init__(self, num_latent_dims, input_shape, max_num_filters):
|
|
super().__init__()
|
|
self.num_latent_dims = num_latent_dims
|
|
self.encoder = Encoder(num_latent_dims, input_shape, max_num_filters)
|
|
self.decoder = Decoder(num_latent_dims, input_shape, max_num_filters)
|
|
|
|
def __call__(self, x):
|
|
# image to latent vector
|
|
z, mu, logvar = self.encoder(x)
|
|
# latent vector to image
|
|
x = self.decode(z)
|
|
return x, mu, logvar
|
|
|
|
def encode(self, x):
|
|
return self.encoder(x)[0]
|
|
|
|
def decode(self, z):
|
|
return self.decoder(z)
|