mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
Example of a Convolutional Variational Autoencoder (CVAE) on MNIST (#264)
* initial commit * style fixes * update of ACKNOWLEDGMENTS * fixed comment * minor refactoring; removed unused imports * added cifar and cvae to top-level README.md * removed mention of cuda/mps in argparse * fixed training status output * load_weights() with strict=True * pretrained model update * fixed imports and style * requires mlx>=0.0.9 * updated with results using mlx 0.0.9 * removed mention of private repo * simplify and combine to one file, more consistency with other exmaples * few more nits * nits * spell * format --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
8071aacd98
commit
9b387007ab
@ -11,3 +11,4 @@ MLX Examples was developed with contributions from the following individuals:
|
|||||||
- Sarthak Yadav: Added the `cifar` and `speechcommands` examples.
|
- Sarthak Yadav: Added the `cifar` and `speechcommands` examples.
|
||||||
- Shunta Saito: Added support for PLaMo models.
|
- Shunta Saito: Added support for PLaMo models.
|
||||||
- Gabrijel Boduljak: Implemented `CLIP`.
|
- Gabrijel Boduljak: Implemented `CLIP`.
|
||||||
|
- Markus Enzweiler: Added the `cvae` examples.
|
||||||
|
@ -20,7 +20,9 @@ Some more useful examples are listed below.
|
|||||||
|
|
||||||
### Image Models
|
### Image Models
|
||||||
|
|
||||||
|
- Image classification using [ResNets on CIFAR-10](cifar).
|
||||||
- Generating images with [Stable Diffusion](stable_diffusion).
|
- Generating images with [Stable Diffusion](stable_diffusion).
|
||||||
|
- Convolutional variational autoencoder [(CVAE) on MNIST](cvae).
|
||||||
|
|
||||||
### Audio Models
|
### Audio Models
|
||||||
|
|
||||||
|
1
cvae/.gitignore
vendored
Normal file
1
cvae/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
models/
|
68
cvae/README.md
Normal file
68
cvae/README.md
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
# Convolutional Variational Autoencoder (CVAE) on MNIST
|
||||||
|
|
||||||
|
Convolutional variational autoencoder (CVAE) implementation in MLX using
|
||||||
|
MNIST.[^1]
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Install the requirements:
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run
|
||||||
|
|
||||||
|
|
||||||
|
To train a VAE run:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
To see the supported options, do `python main.py -h`.
|
||||||
|
|
||||||
|
Training with the default options should give:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
$ python train.py
|
||||||
|
Options:
|
||||||
|
Device: GPU
|
||||||
|
Seed: 0
|
||||||
|
Batch size: 128
|
||||||
|
Max number of filters: 64
|
||||||
|
Number of epochs: 50
|
||||||
|
Learning rate: 0.001
|
||||||
|
Number of latent dimensions: 8
|
||||||
|
Number of trainable params: 0.1493 M
|
||||||
|
Epoch 1 | Loss 14626.96 | Throughput 1803.44 im/s | Time 34.3 (s)
|
||||||
|
Epoch 2 | Loss 10462.21 | Throughput 1802.20 im/s | Time 34.3 (s)
|
||||||
|
...
|
||||||
|
Epoch 50 | Loss 8293.13 | Throughput 1804.91 im/s | Time 34.2 (s)
|
||||||
|
```
|
||||||
|
|
||||||
|
The throughput was measured on a 32GB M1 Max.
|
||||||
|
|
||||||
|
Reconstructed and generated images will be saved after each epoch in the
|
||||||
|
`models/` path. Below are examples of reconstructed training set images and
|
||||||
|
generated images.
|
||||||
|
|
||||||
|
#### Reconstruction
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
#### Generation
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
At the time of writing, MLX does not have transposed 2D convolutions. The
|
||||||
|
example approximates them with a combination of nearest neighbor upsampling and
|
||||||
|
regular convolutions, similar to the original U-Net. We intend to update this
|
||||||
|
example once transposed 2D convolutions are available.
|
||||||
|
|
||||||
|
[^1]: For a good overview of VAEs see the original paper [Auto-Encoding
|
||||||
|
Variational Bayes](https://arxiv.org/abs/1312.6114) or [An Introduction to
|
||||||
|
Variational Autoencoders](https://arxiv.org/abs/1906.02691).
|
BIN
cvae/assets/rec_mnist.png
Normal file
BIN
cvae/assets/rec_mnist.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 136 KiB |
BIN
cvae/assets/samples_mnist.png
Normal file
BIN
cvae/assets/samples_mnist.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 158 KiB |
53
cvae/dataset.py
Normal file
53
cvae/dataset.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
from mlx.data.datasets import load_mnist
|
||||||
|
|
||||||
|
|
||||||
|
def mnist(batch_size, img_size, root=None):
|
||||||
|
# load train and test sets using mlx-data
|
||||||
|
load_fn = load_mnist
|
||||||
|
tr = load_fn(root=root, train=True)
|
||||||
|
test = load_fn(root=root, train=False)
|
||||||
|
|
||||||
|
# number of image channels is 1 for MNIST
|
||||||
|
num_img_channels = 1
|
||||||
|
|
||||||
|
# normalize to [0,1]
|
||||||
|
def normalize(x):
|
||||||
|
return x.astype("float32") / 255.0
|
||||||
|
|
||||||
|
# iterator over training set
|
||||||
|
tr_iter = (
|
||||||
|
tr.shuffle()
|
||||||
|
.to_stream()
|
||||||
|
.image_resize("image", h=img_size[0], w=img_size[1])
|
||||||
|
.key_transform("image", normalize)
|
||||||
|
.batch(batch_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
# iterator over test set
|
||||||
|
test_iter = (
|
||||||
|
test.to_stream()
|
||||||
|
.image_resize("image", h=img_size[0], w=img_size[1])
|
||||||
|
.key_transform("image", normalize)
|
||||||
|
.batch(batch_size)
|
||||||
|
)
|
||||||
|
return tr_iter, test_iter
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
batch_size = 32
|
||||||
|
img_size = (64, 64) # (H, W)
|
||||||
|
|
||||||
|
tr_iter, test_iter = mnist(batch_size=batch_size, img_size=img_size)
|
||||||
|
|
||||||
|
B, H, W, C = batch_size, img_size[0], img_size[1], 1
|
||||||
|
print(f"Batch size: {B}, Channels: {C}, Height: {H}, Width: {W}")
|
||||||
|
|
||||||
|
batch_tr_iter = next(tr_iter)
|
||||||
|
assert batch_tr_iter["image"].shape == (B, H, W, C), "Wrong training set size"
|
||||||
|
assert batch_tr_iter["label"].shape == (batch_size,), "Wrong training set size"
|
||||||
|
|
||||||
|
batch_test_iter = next(test_iter)
|
||||||
|
assert batch_test_iter["image"].shape == (B, H, W, C), "Wrong training set size"
|
||||||
|
assert batch_test_iter["label"].shape == (batch_size,), "Wrong training set size"
|
229
cvae/main.py
Normal file
229
cvae/main.py
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import dataset
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import mlx.optimizers as optim
|
||||||
|
import model
|
||||||
|
import numpy as np
|
||||||
|
from mlx.utils import tree_flatten
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def grid_image_from_batch(image_batch, num_rows):
|
||||||
|
"""
|
||||||
|
Generate a grid image from a batch of images.
|
||||||
|
Assumes input has shape (B, H, W, C).
|
||||||
|
"""
|
||||||
|
|
||||||
|
B, H, W, _ = image_batch.shape
|
||||||
|
|
||||||
|
num_cols = B // num_rows
|
||||||
|
|
||||||
|
# Calculate the size of the output grid image
|
||||||
|
grid_height = num_rows * H
|
||||||
|
grid_width = num_cols * W
|
||||||
|
|
||||||
|
# Normalize and convert to the desired data type
|
||||||
|
image_batch = np.array(image_batch * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
# Reshape the batch of images into a 2D grid
|
||||||
|
grid_image = image_batch.reshape(num_rows, num_cols, H, W, -1)
|
||||||
|
grid_image = grid_image.swapaxes(1, 2)
|
||||||
|
grid_image = grid_image.reshape(grid_height, grid_width, -1)
|
||||||
|
|
||||||
|
# Convert the grid to a PIL Image
|
||||||
|
return Image.fromarray(grid_image.squeeze())
|
||||||
|
|
||||||
|
|
||||||
|
def loss_fn(model, X):
|
||||||
|
X_recon, mu, logvar = model(X)
|
||||||
|
|
||||||
|
# Reconstruction loss
|
||||||
|
recon_loss = nn.losses.mse_loss(X_recon, X, reduction="sum")
|
||||||
|
|
||||||
|
# KL divergence between encoder distribution and standard normal:
|
||||||
|
kl_div = -0.5 * mx.sum(1 + logvar - mu.square() - logvar.exp())
|
||||||
|
|
||||||
|
# Total loss
|
||||||
|
return recon_loss + kl_div
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch(model, data, optimizer, epoch):
|
||||||
|
loss_acc = 0.0
|
||||||
|
throughput_acc = 0.0
|
||||||
|
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||||
|
|
||||||
|
# Iterate over training batches
|
||||||
|
for batch_count, batch in enumerate(data):
|
||||||
|
X = mx.array(batch["image"])
|
||||||
|
|
||||||
|
throughput_tic = time.perf_counter()
|
||||||
|
|
||||||
|
# Forward pass + backward pass + update
|
||||||
|
loss, grads = loss_and_grad_fn(model, X)
|
||||||
|
optimizer.update(model, grads)
|
||||||
|
|
||||||
|
# Evaluate updated model parameters
|
||||||
|
mx.eval(model.parameters(), optimizer.state)
|
||||||
|
|
||||||
|
throughput_toc = time.perf_counter()
|
||||||
|
throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
|
||||||
|
loss_acc += loss.item()
|
||||||
|
|
||||||
|
if batch_count > 0 and (batch_count % 10 == 0):
|
||||||
|
print(
|
||||||
|
" | ".join(
|
||||||
|
[
|
||||||
|
f"Epoch {epoch:4d}",
|
||||||
|
f"Loss {(loss_acc / batch_count):10.2f}",
|
||||||
|
f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
|
||||||
|
f"Batch {batch_count:5d}",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
end="\r",
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss_acc, throughput_acc, batch_count
|
||||||
|
|
||||||
|
|
||||||
|
def reconstruct(model, batch, out_file):
|
||||||
|
# Reconstruct a single batch only
|
||||||
|
images = mx.array(batch["image"])
|
||||||
|
images_recon = model(images)[0]
|
||||||
|
paired_images = mx.stack([images, images_recon]).swapaxes(0, 1).flatten(0, 1)
|
||||||
|
grid_image = grid_image_from_batch(paired_images, num_rows=16)
|
||||||
|
grid_image.save(out_file)
|
||||||
|
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
model,
|
||||||
|
out_file,
|
||||||
|
num_samples=128,
|
||||||
|
):
|
||||||
|
# Sample from the latent distribution:
|
||||||
|
z = mx.random.normal([num_samples, model.num_latent_dims])
|
||||||
|
|
||||||
|
# Decode the latent vectors to images:
|
||||||
|
images = model.decode(z)
|
||||||
|
|
||||||
|
# Save all images in a single file
|
||||||
|
grid_image = grid_image_from_batch(images, num_rows=8)
|
||||||
|
grid_image.save(out_file)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
# Load the data
|
||||||
|
img_size = (64, 64, 1)
|
||||||
|
train_iter, test_iter = dataset.mnist(
|
||||||
|
batch_size=args.batch_size, img_size=img_size[:2]
|
||||||
|
)
|
||||||
|
|
||||||
|
save_dir = Path(args.save_dir)
|
||||||
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Load the model
|
||||||
|
vae = model.CVAE(args.latent_dims, img_size, args.max_filters)
|
||||||
|
mx.eval(vae.parameters())
|
||||||
|
|
||||||
|
num_params = sum(x.size for _, x in tree_flatten(vae.trainable_parameters()))
|
||||||
|
print("Number of trainable params: {:0.04f} M".format(num_params / 1e6))
|
||||||
|
|
||||||
|
optimizer = optim.AdamW(learning_rate=args.lr)
|
||||||
|
|
||||||
|
# Batches for reconstruction
|
||||||
|
train_batch = next(train_iter)
|
||||||
|
test_batch = next(test_iter)
|
||||||
|
|
||||||
|
for e in range(1, args.epochs + 1):
|
||||||
|
# Reset iterators and stats at the beginning of each epoch
|
||||||
|
train_iter.reset()
|
||||||
|
vae.train()
|
||||||
|
|
||||||
|
# Train one epoch
|
||||||
|
tic = time.perf_counter()
|
||||||
|
loss_acc, throughput_acc, batch_count = train_epoch(
|
||||||
|
vae, train_iter, optimizer, e
|
||||||
|
)
|
||||||
|
toc = time.perf_counter()
|
||||||
|
|
||||||
|
vae.eval()
|
||||||
|
|
||||||
|
print(
|
||||||
|
" | ".join(
|
||||||
|
[
|
||||||
|
f"Epoch {e:4d}",
|
||||||
|
f"Loss {(loss_acc / batch_count):10.2f}",
|
||||||
|
f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
|
||||||
|
f"Time {toc - tic:8.1f} (s)",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Reconstruct a batch of training and test images
|
||||||
|
reconstruct(vae, train_batch, save_dir / f"train_{e:03d}.png")
|
||||||
|
reconstruct(vae, test_batch, save_dir / f"test_{e:03d}.png")
|
||||||
|
|
||||||
|
# Generate images
|
||||||
|
generate(vae, save_dir / f"generated_{e:03d}.png")
|
||||||
|
|
||||||
|
vae.save_weights(str(save_dir / "weights.npz"))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--cpu",
|
||||||
|
action="store_true",
|
||||||
|
help="Use CPU instead of GPU acceleration",
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size", type=int, default=128, help="Batch size for training"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-filters",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Maximum number of filters in the convolutional layers",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--epochs", type=int, default=50, help="Number of training epochs"
|
||||||
|
)
|
||||||
|
parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--latent-dims",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Number of latent dimensions (positive integer)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-dir",
|
||||||
|
type=str,
|
||||||
|
default="models/",
|
||||||
|
help="Path to save the model and reconstructed images.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.cpu:
|
||||||
|
mx.set_default_device(mx.cpu)
|
||||||
|
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
|
print("Options: ")
|
||||||
|
print(f" Device: {'GPU' if not args.cpu else 'CPU'}")
|
||||||
|
print(f" Seed: {args.seed}")
|
||||||
|
print(f" Batch size: {args.batch_size}")
|
||||||
|
print(f" Max number of filters: {args.max_filters}")
|
||||||
|
print(f" Number of epochs: {args.epochs}")
|
||||||
|
print(f" Learning rate: {args.lr}")
|
||||||
|
print(f" Number of latent dimensions: {args.latent_dims}")
|
||||||
|
|
||||||
|
main(args)
|
172
cvae/model.py
Normal file
172
cvae/model.py
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
# from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
|
||||||
|
def upsample_nearest(x, scale: int = 2):
|
||||||
|
B, H, W, C = x.shape
|
||||||
|
x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
|
||||||
|
x = x.reshape(B, H * scale, W * scale, C)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class UpsamplingConv2d(nn.Module):
|
||||||
|
"""
|
||||||
|
A convolutional layer that upsamples the input by a factor of 2. MLX does
|
||||||
|
not yet support transposed convolutions, so we approximate them with
|
||||||
|
nearest neighbor upsampling followed by a convolution. This is similar to
|
||||||
|
the approach used in the original U-Net.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size, stride=stride, padding=padding
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
x = self.conv(upsample_nearest(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
"""
|
||||||
|
A convolutional variational encoder.
|
||||||
|
Maps the input to a normal distribution in latent space and sample a latent
|
||||||
|
vector from that distribution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_latent_dims, image_shape, max_num_filters):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# number of filters in the convolutional layers
|
||||||
|
num_filters_1 = max_num_filters // 4
|
||||||
|
num_filters_2 = max_num_filters // 2
|
||||||
|
num_filters_3 = max_num_filters
|
||||||
|
|
||||||
|
# Output (BHWC): B x 32 x 32 x num_filters_1
|
||||||
|
self.conv1 = nn.Conv2d(image_shape[-1], num_filters_1, 3, stride=2, padding=1)
|
||||||
|
# Output (BHWC): B x 16 x 16 x num_filters_2
|
||||||
|
self.conv2 = nn.Conv2d(num_filters_1, num_filters_2, 3, stride=2, padding=1)
|
||||||
|
# Output (BHWC): B x 8 x 8 x num_filters_3
|
||||||
|
self.conv3 = nn.Conv2d(num_filters_2, num_filters_3, 3, stride=2, padding=1)
|
||||||
|
|
||||||
|
# Batch Normalization
|
||||||
|
self.bn1 = nn.BatchNorm(num_filters_1)
|
||||||
|
self.bn2 = nn.BatchNorm(num_filters_2)
|
||||||
|
self.bn3 = nn.BatchNorm(num_filters_3)
|
||||||
|
|
||||||
|
# Divide the spatial dimensions by 8 because of the 3 strided convolutions
|
||||||
|
output_shape = [num_filters_3] + [
|
||||||
|
dimension // 8 for dimension in image_shape[:-1]
|
||||||
|
]
|
||||||
|
|
||||||
|
flattened_dim = math.prod(output_shape)
|
||||||
|
|
||||||
|
# Linear mappings to mean and standard deviation
|
||||||
|
self.proj_mu = nn.Linear(flattened_dim, num_latent_dims)
|
||||||
|
self.proj_log_var = nn.Linear(flattened_dim, num_latent_dims)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
x = nn.leaky_relu(self.bn1(self.conv1(x)))
|
||||||
|
x = nn.leaky_relu(self.bn2(self.conv2(x)))
|
||||||
|
x = nn.leaky_relu(self.bn3(self.conv3(x)))
|
||||||
|
x = mx.flatten(x, 1) # flatten all dimensions except batch
|
||||||
|
|
||||||
|
mu = self.proj_mu(x)
|
||||||
|
logvar = self.proj_log_var(x)
|
||||||
|
# Ensure this is the std deviation, not variance
|
||||||
|
sigma = mx.exp(logvar * 0.5)
|
||||||
|
|
||||||
|
# Generate a tensor of random values from a normal distribution
|
||||||
|
eps = mx.random.normal(sigma.shape)
|
||||||
|
|
||||||
|
# Reparametrization trick to brackpropagate through sampling.
|
||||||
|
z = eps * sigma + mu
|
||||||
|
|
||||||
|
return z, mu, logvar
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
"""A convolutional decoder"""
|
||||||
|
|
||||||
|
def __init__(self, num_latent_dims, image_shape, max_num_filters):
|
||||||
|
super().__init__()
|
||||||
|
self.num_latent_dims = num_latent_dims
|
||||||
|
num_img_channels = image_shape[-1]
|
||||||
|
self.max_num_filters = max_num_filters
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
num_filters_1 = max_num_filters
|
||||||
|
num_filters_2 = max_num_filters // 2
|
||||||
|
num_filters_3 = max_num_filters // 4
|
||||||
|
|
||||||
|
# divide the last two dimensions by 8 because of the 3 upsampling convolutions
|
||||||
|
self.input_shape = [dimension // 8 for dimension in image_shape[:-1]] + [
|
||||||
|
num_filters_1
|
||||||
|
]
|
||||||
|
flattened_dim = math.prod(self.input_shape)
|
||||||
|
|
||||||
|
# Output: flattened_dim
|
||||||
|
self.lin1 = nn.Linear(num_latent_dims, flattened_dim)
|
||||||
|
# Output (BHWC): B x 16 x 16 x num_filters_2
|
||||||
|
self.upconv1 = UpsamplingConv2d(
|
||||||
|
num_filters_1, num_filters_2, 3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
# Output (BHWC): B x 32 x 32 x num_filters_1
|
||||||
|
self.upconv2 = UpsamplingConv2d(
|
||||||
|
num_filters_2, num_filters_3, 3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
# Output (BHWC): B x 64 x 64 x #img_channels
|
||||||
|
self.upconv3 = UpsamplingConv2d(
|
||||||
|
num_filters_3, num_img_channels, 3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Batch Normalizations
|
||||||
|
self.bn1 = nn.BatchNorm(num_filters_2)
|
||||||
|
self.bn2 = nn.BatchNorm(num_filters_3)
|
||||||
|
|
||||||
|
def __call__(self, z):
|
||||||
|
x = self.lin1(z)
|
||||||
|
|
||||||
|
# reshape to BHWC
|
||||||
|
x = x.reshape(
|
||||||
|
-1, self.input_shape[0], self.input_shape[1], self.max_num_filters
|
||||||
|
)
|
||||||
|
|
||||||
|
# approximate transposed convolutions with nearest neighbor upsampling
|
||||||
|
x = nn.leaky_relu(self.bn1(self.upconv1(x)))
|
||||||
|
x = nn.leaky_relu(self.bn2(self.upconv2(x)))
|
||||||
|
# sigmoid to ensure pixel values are in [0,1]
|
||||||
|
x = mx.sigmoid(self.upconv3(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CVAE(nn.Module):
|
||||||
|
"""
|
||||||
|
A convolutional variational autoencoder consisting of an encoder and a
|
||||||
|
decoder.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_latent_dims, input_shape, max_num_filters):
|
||||||
|
super().__init__()
|
||||||
|
self.num_latent_dims = num_latent_dims
|
||||||
|
self.encoder = Encoder(num_latent_dims, input_shape, max_num_filters)
|
||||||
|
self.decoder = Decoder(num_latent_dims, input_shape, max_num_filters)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
# image to latent vector
|
||||||
|
z, mu, logvar = self.encoder(x)
|
||||||
|
# latent vector to image
|
||||||
|
x = self.decode(z)
|
||||||
|
return x, mu, logvar
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
return self.encoder(x)[0]
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
return self.decoder(z)
|
4
cvae/requirements.txt
Normal file
4
cvae/requirements.txt
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
mlx>=0.0.9
|
||||||
|
mlx-data
|
||||||
|
numpy
|
||||||
|
Pillow
|
@ -1,4 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from sys import exit
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@ -6,8 +7,6 @@ import mlx.nn as nn
|
|||||||
|
|
||||||
from .base import BaseModelArgs
|
from .base import BaseModelArgs
|
||||||
|
|
||||||
from sys import exit
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import hf_olmo
|
import hf_olmo
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
Loading…
Reference in New Issue
Block a user