| 
									
										
										
										
											2024-02-07 05:02:27 +01:00
										 |  |  |  | # Copyright © 2023-2024 Apple Inc. | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | import argparse | 
					
						
							|  |  |  |  | import time | 
					
						
							| 
									
										
										
										
											2024-02-08 13:00:41 -08:00
										 |  |  |  | from functools import partial | 
					
						
							| 
									
										
										
										
											2024-02-07 05:02:27 +01:00
										 |  |  |  | from pathlib import Path | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | import dataset | 
					
						
							|  |  |  |  | import mlx.core as mx | 
					
						
							|  |  |  |  | import mlx.nn as nn | 
					
						
							|  |  |  |  | import mlx.optimizers as optim | 
					
						
							|  |  |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2024-02-08 13:00:41 -08:00
										 |  |  |  | import vae | 
					
						
							| 
									
										
										
										
											2024-02-07 05:02:27 +01:00
										 |  |  |  | from mlx.utils import tree_flatten | 
					
						
							|  |  |  |  | from PIL import Image | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def grid_image_from_batch(image_batch, num_rows): | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  |     Generate a grid image from a batch of images. | 
					
						
							|  |  |  |  |     Assumes input has shape (B, H, W, C). | 
					
						
							|  |  |  |  |     """
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     B, H, W, _ = image_batch.shape | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     num_cols = B // num_rows | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # Calculate the size of the output grid image | 
					
						
							|  |  |  |  |     grid_height = num_rows * H | 
					
						
							|  |  |  |  |     grid_width = num_cols * W | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # Normalize and convert to the desired data type | 
					
						
							|  |  |  |  |     image_batch = np.array(image_batch * 255).astype(np.uint8) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # Reshape the batch of images into a 2D grid | 
					
						
							|  |  |  |  |     grid_image = image_batch.reshape(num_rows, num_cols, H, W, -1) | 
					
						
							|  |  |  |  |     grid_image = grid_image.swapaxes(1, 2) | 
					
						
							|  |  |  |  |     grid_image = grid_image.reshape(grid_height, grid_width, -1) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # Convert the grid to a PIL Image | 
					
						
							|  |  |  |  |     return Image.fromarray(grid_image.squeeze()) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def loss_fn(model, X): | 
					
						
							|  |  |  |  |     X_recon, mu, logvar = model(X) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # Reconstruction loss | 
					
						
							|  |  |  |  |     recon_loss = nn.losses.mse_loss(X_recon, X, reduction="sum") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # KL divergence between encoder distribution and standard normal: | 
					
						
							|  |  |  |  |     kl_div = -0.5 * mx.sum(1 + logvar - mu.square() - logvar.exp()) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # Total loss | 
					
						
							|  |  |  |  |     return recon_loss + kl_div | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def reconstruct(model, batch, out_file): | 
					
						
							|  |  |  |  |     # Reconstruct a single batch only | 
					
						
							|  |  |  |  |     images = mx.array(batch["image"]) | 
					
						
							|  |  |  |  |     images_recon = model(images)[0] | 
					
						
							|  |  |  |  |     paired_images = mx.stack([images, images_recon]).swapaxes(0, 1).flatten(0, 1) | 
					
						
							|  |  |  |  |     grid_image = grid_image_from_batch(paired_images, num_rows=16) | 
					
						
							|  |  |  |  |     grid_image.save(out_file) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def generate( | 
					
						
							|  |  |  |  |     model, | 
					
						
							|  |  |  |  |     out_file, | 
					
						
							|  |  |  |  |     num_samples=128, | 
					
						
							|  |  |  |  | ): | 
					
						
							|  |  |  |  |     # Sample from the latent distribution: | 
					
						
							|  |  |  |  |     z = mx.random.normal([num_samples, model.num_latent_dims]) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # Decode the latent vectors to images: | 
					
						
							|  |  |  |  |     images = model.decode(z) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # Save all images in a single file | 
					
						
							|  |  |  |  |     grid_image = grid_image_from_batch(images, num_rows=8) | 
					
						
							|  |  |  |  |     grid_image.save(out_file) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def main(args): | 
					
						
							|  |  |  |  |     # Load the data | 
					
						
							|  |  |  |  |     img_size = (64, 64, 1) | 
					
						
							|  |  |  |  |     train_iter, test_iter = dataset.mnist( | 
					
						
							|  |  |  |  |         batch_size=args.batch_size, img_size=img_size[:2] | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     save_dir = Path(args.save_dir) | 
					
						
							|  |  |  |  |     save_dir.mkdir(parents=True, exist_ok=True) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # Load the model | 
					
						
							| 
									
										
										
										
											2024-02-08 13:00:41 -08:00
										 |  |  |  |     model = vae.CVAE(args.latent_dims, img_size, args.max_filters) | 
					
						
							|  |  |  |  |     mx.eval(model.parameters()) | 
					
						
							| 
									
										
										
										
											2024-02-07 05:02:27 +01:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-08 13:00:41 -08:00
										 |  |  |  |     num_params = sum(x.size for _, x in tree_flatten(model.trainable_parameters())) | 
					
						
							| 
									
										
										
										
											2024-02-07 05:02:27 +01:00
										 |  |  |  |     print("Number of trainable params: {:0.04f} M".format(num_params / 1e6)) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     optimizer = optim.AdamW(learning_rate=args.lr) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # Batches for reconstruction | 
					
						
							|  |  |  |  |     train_batch = next(train_iter) | 
					
						
							|  |  |  |  |     test_batch = next(test_iter) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-08 13:00:41 -08:00
										 |  |  |  |     state = [model.state, optimizer.state] | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     @partial(mx.compile, inputs=state, outputs=state) | 
					
						
							|  |  |  |  |     def step(X): | 
					
						
							|  |  |  |  |         loss_and_grad_fn = nn.value_and_grad(model, loss_fn) | 
					
						
							|  |  |  |  |         loss, grads = loss_and_grad_fn(model, X) | 
					
						
							|  |  |  |  |         optimizer.update(model, grads) | 
					
						
							|  |  |  |  |         return loss | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-07 05:02:27 +01:00
										 |  |  |  |     for e in range(1, args.epochs + 1): | 
					
						
							|  |  |  |  |         # Reset iterators and stats at the beginning of each epoch | 
					
						
							|  |  |  |  |         train_iter.reset() | 
					
						
							| 
									
										
										
										
											2024-02-08 13:00:41 -08:00
										 |  |  |  |         model.train() | 
					
						
							| 
									
										
										
										
											2024-02-07 05:02:27 +01:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |         # Train one epoch | 
					
						
							|  |  |  |  |         tic = time.perf_counter() | 
					
						
							| 
									
										
										
										
											2024-02-08 13:00:41 -08:00
										 |  |  |  |         loss_acc = 0.0 | 
					
						
							|  |  |  |  |         throughput_acc = 0.0 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # Iterate over training batches | 
					
						
							|  |  |  |  |         for batch_count, batch in enumerate(train_iter): | 
					
						
							|  |  |  |  |             X = mx.array(batch["image"]) | 
					
						
							|  |  |  |  |             throughput_tic = time.perf_counter() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |             # Forward pass + backward pass + update | 
					
						
							|  |  |  |  |             loss = step(X) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |             # Evaluate updated model parameters | 
					
						
							|  |  |  |  |             mx.eval(state) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |             throughput_toc = time.perf_counter() | 
					
						
							|  |  |  |  |             throughput_acc += X.shape[0] / (throughput_toc - throughput_tic) | 
					
						
							|  |  |  |  |             loss_acc += loss.item() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |             if batch_count > 0 and (batch_count % 10 == 0): | 
					
						
							|  |  |  |  |                 print( | 
					
						
							|  |  |  |  |                     " | ".join( | 
					
						
							|  |  |  |  |                         [ | 
					
						
							|  |  |  |  |                             f"Epoch {e:4d}", | 
					
						
							|  |  |  |  |                             f"Loss {(loss_acc / batch_count):10.2f}", | 
					
						
							|  |  |  |  |                             f"Throughput {(throughput_acc / batch_count):8.2f} im/s", | 
					
						
							|  |  |  |  |                             f"Batch {batch_count:5d}", | 
					
						
							|  |  |  |  |                         ] | 
					
						
							|  |  |  |  |                     ), | 
					
						
							|  |  |  |  |                     end="\r", | 
					
						
							|  |  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-02-07 05:02:27 +01:00
										 |  |  |  |         toc = time.perf_counter() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         print( | 
					
						
							|  |  |  |  |             " | ".join( | 
					
						
							|  |  |  |  |                 [ | 
					
						
							|  |  |  |  |                     f"Epoch {e:4d}", | 
					
						
							|  |  |  |  |                     f"Loss {(loss_acc / batch_count):10.2f}", | 
					
						
							|  |  |  |  |                     f"Throughput {(throughput_acc / batch_count):8.2f} im/s", | 
					
						
							|  |  |  |  |                     f"Time {toc - tic:8.1f} (s)", | 
					
						
							|  |  |  |  |                 ] | 
					
						
							|  |  |  |  |             ) | 
					
						
							|  |  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-02-08 13:00:41 -08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |         model.eval() | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-07 05:02:27 +01:00
										 |  |  |  |         # Reconstruct a batch of training and test images | 
					
						
							| 
									
										
										
										
											2024-02-08 13:00:41 -08:00
										 |  |  |  |         reconstruct(model, train_batch, save_dir / f"train_{e:03d}.png") | 
					
						
							|  |  |  |  |         reconstruct(model, test_batch, save_dir / f"test_{e:03d}.png") | 
					
						
							| 
									
										
										
										
											2024-02-07 05:02:27 +01:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |         # Generate images | 
					
						
							| 
									
										
										
										
											2024-02-08 13:00:41 -08:00
										 |  |  |  |         generate(model, save_dir / f"generated_{e:03d}.png") | 
					
						
							| 
									
										
										
										
											2024-02-07 05:02:27 +01:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-08 13:00:41 -08:00
										 |  |  |  |         model.save_weights(str(save_dir / "weights.npz")) | 
					
						
							| 
									
										
										
										
											2024-02-07 05:02:27 +01:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |  |     parser = argparse.ArgumentParser() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |  |         "--cpu", | 
					
						
							|  |  |  |  |         action="store_true", | 
					
						
							|  |  |  |  |         help="Use CPU instead of GPU acceleration", | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  |     parser.add_argument("--seed", type=int, default=0, help="Random seed") | 
					
						
							|  |  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |  |         "--batch-size", type=int, default=128, help="Batch size for training" | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |  |         "--max-filters", | 
					
						
							|  |  |  |  |         type=int, | 
					
						
							|  |  |  |  |         default=64, | 
					
						
							|  |  |  |  |         help="Maximum number of filters in the convolutional layers", | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |  |         "--epochs", type=int, default=50, help="Number of training epochs" | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  |     parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |  |         "--latent-dims", | 
					
						
							|  |  |  |  |         type=int, | 
					
						
							|  |  |  |  |         default=8, | 
					
						
							|  |  |  |  |         help="Number of latent dimensions (positive integer)", | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |  |         "--save-dir", | 
					
						
							|  |  |  |  |         type=str, | 
					
						
							|  |  |  |  |         default="models/", | 
					
						
							|  |  |  |  |         help="Path to save the model and reconstructed images.", | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     args = parser.parse_args() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     if args.cpu: | 
					
						
							|  |  |  |  |         mx.set_default_device(mx.cpu) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     np.random.seed(args.seed) | 
					
						
							|  |  |  |  |     mx.random.seed(args.seed) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     print("Options: ") | 
					
						
							|  |  |  |  |     print(f"  Device: {'GPU' if not args.cpu else 'CPU'}") | 
					
						
							|  |  |  |  |     print(f"  Seed: {args.seed}") | 
					
						
							|  |  |  |  |     print(f"  Batch size: {args.batch_size}") | 
					
						
							|  |  |  |  |     print(f"  Max number of filters: {args.max_filters}") | 
					
						
							|  |  |  |  |     print(f"  Number of epochs: {args.epochs}") | 
					
						
							|  |  |  |  |     print(f"  Learning rate: {args.lr}") | 
					
						
							|  |  |  |  |     print(f"  Number of latent dimensions: {args.latent_dims}") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     main(args) |