{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "\n",
    "import mlx.core as mx\n",
    "import mlx.nn as nn\n",
    "import mlx.optimizers as optim\n",
    "\n",
    "# Local module; its mnist() function loads the dataset further below.\n",
    "import mnist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GAN Architecture"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generator 👨🏻‍🎨"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def GenBlock(in_dim: int, out_dim: int) -> nn.Sequential:\n",
    "    \"\"\"Generator building block: Linear -> BatchNorm -> ReLU.\"\"\"\n",
    "    return nn.Sequential(\n",
    "        nn.Linear(in_dim, out_dim),\n",
    "        nn.BatchNorm(out_dim),\n",
    "        nn.ReLU(),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Generator(nn.Module):\n",
    "    \"\"\"Fully connected generator mapping noise vectors to flattened images.\n",
    "\n",
    "    Args:\n",
    "        z_dim: dimension of the input noise vector.\n",
    "        im_dim: length of each flattened output image (784 = 28 * 28 for MNIST).\n",
    "        hidden_dim: width of the first hidden block; each later block doubles it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, z_dim: int = 10, im_dim: int = 784, hidden_dim: int = 128):\n",
    "        super().__init__()\n",
    "        self.gen = nn.Sequential(\n",
    "            GenBlock(z_dim, hidden_dim),\n",
    "            GenBlock(hidden_dim, hidden_dim * 2),\n",
    "            GenBlock(hidden_dim * 2, hidden_dim * 4),\n",
    "            GenBlock(hidden_dim * 4, hidden_dim * 8),\n",
    "            nn.Linear(hidden_dim * 8, im_dim),\n",
    "            nn.Sigmoid(),  # squashes each output pixel into (0, 1)\n",
    "        )\n",
    "\n",
    "    def __call__(self, noise):\n",
    "        \"\"\"Map an (n_samples, z_dim) noise batch to (n_samples, im_dim) images.\"\"\"\n",
    "        return self.gen(noise)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen = Generator(100)\n",
    "gen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_noise(n_samples: int, z_dim: int) -> np.ndarray:\n",
    "    \"\"\"Sample an (n_samples, z_dim) batch of standard-normal noise.\n",
    "\n",
    "    Cast to float32 explicitly: NumPy samples float64, which MLX does not\n",
    "    support on the GPU.\n",
    "    \"\"\"\n",
    "    return np.random.randn(n_samples, z_dim).astype(np.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Discriminator 🕵🏻‍♂️"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def DisBlock(in_dim: int, out_dim: int) -> nn.Sequential:\n",
    "    \"\"\"Discriminator building block: Linear -> LeakyReLU(0.2).\"\"\"\n",
    "    return nn.Sequential(\n",
    "        nn.Linear(in_dim, out_dim),\n",
    "        nn.LeakyReLU(negative_slope=0.2),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Discriminator(nn.Module):\n",
    "    \"\"\"Fully connected discriminator scoring flattened images.\n",
    "\n",
    "    Returns one raw logit per image (no final sigmoid); the losses below\n",
    "    therefore call binary cross-entropy with ``with_logits=True``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, im_dim: int = 784, hidden_dim: int = 128):\n",
    "        super().__init__()\n",
    "        self.disc = nn.Sequential(\n",
    "            DisBlock(im_dim, hidden_dim * 4),\n",
    "            DisBlock(hidden_dim * 4, hidden_dim * 2),\n",
    "            DisBlock(hidden_dim * 2, hidden_dim),\n",
    "            nn.Linear(hidden_dim, 1),\n",
    "        )\n",
    "\n",
    "    def __call__(self, noise):\n",
    "        \"\"\"Return (n, 1) logits for an (n, im_dim) image batch.\n",
    "\n",
    "        The argument is an image batch despite its name (kept for compatibility).\n",
    "        \"\"\"\n",
    "        return self.disc(noise)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "disc = Discriminator()\n",
    "disc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Training 🏋🏻‍♂️"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set your parameters\n",
    "n_epochs = 200\n",
    "z_dim = 64\n",
    "display_step = 500\n",
    "batch_size = 128\n",
    "lr = 0.00001\n",
    "seed = 0\n",
    "\n",
    "# Seed both RNGs used here: NumPy (noise, batch order) and MLX (weight init).\n",
    "np.random.seed(seed)\n",
    "mx.random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen = Generator(z_dim)\n",
    "mx.eval(gen.parameters())\n",
    "gen_opt = optim.Adam(learning_rate=lr)\n",
    "\n",
    "disc = Discriminator()\n",
    "mx.eval(disc.parameters())\n",
    "disc_opt = optim.Adam(learning_rate=lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def disc_loss(gen, disc, real, num_images, z_dim):\n",
    "    \"\"\"Discriminator loss: average of BCE(fakes vs. 0) and BCE(reals vs. 1).\n",
    "\n",
    "    Args:\n",
    "        gen: generator used to produce ``num_images`` fake images.\n",
    "        disc: discriminator returning one logit per image.\n",
    "        real: batch of real flattened images.\n",
    "        num_images: number of fake images to generate.\n",
    "        z_dim: noise dimension expected by ``gen``.\n",
    "\n",
    "    Returns:\n",
    "        A scalar loss (``nn.value_and_grad`` requires a scalar output).\n",
    "    \"\"\"\n",
    "    noise = mx.array(get_noise(num_images, z_dim))\n",
    "    fake_disc = disc(gen(noise))\n",
    "    fake_loss = nn.losses.binary_cross_entropy(\n",
    "        fake_disc, mx.zeros_like(fake_disc), with_logits=True, reduction=\"mean\"\n",
    "    )\n",
    "\n",
    "    real_disc = disc(real)\n",
    "    real_loss = nn.losses.binary_cross_entropy(\n",
    "        real_disc, mx.ones_like(real_disc), with_logits=True, reduction=\"mean\"\n",
    "    )\n",
    "\n",
    "    return (fake_loss + real_loss) / 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_loss(gen, disc, num_images, z_dim):\n",
    "    \"\"\"Generator loss: BCE of the discriminator's logits on fakes vs. label 1.\n",
    "\n",
    "    The loss is low when the discriminator scores the generated images as real.\n",
    "    Returns a scalar (``nn.value_and_grad`` requires a scalar output).\n",
    "    \"\"\"\n",
    "    noise = mx.array(get_noise(num_images, z_dim))\n",
    "    fake_disc = disc(gen(noise))\n",
    "    return nn.losses.binary_cross_entropy(\n",
    "        fake_disc, mx.ones_like(fake_disc), with_logits=True, reduction=\"mean\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mnist.mnist() returns four arrays; only the two image arrays are kept.\n",
    "train_images, _, test_images, _ = map(mx.array, mnist.mnist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_iterate(batch_size: int, ipt: mx.array):\n",
    "    \"\"\"Yield shuffled mini-batches of ``ipt`` taken along its first axis.\n",
    "\n",
    "    The final batch may hold fewer than ``batch_size`` rows.\n",
    "    \"\"\"\n",
    "    # Count rows with len(); ipt.size is the total element count, which made\n",
    "    # the loop run far past the data and yield empty batches.\n",
    "    n_rows = len(ipt)\n",
    "    perm = mx.array(np.random.permutation(n_rows))\n",
    "    for s in range(0, n_rows, batch_size):\n",
    "        yield ipt[perm[s : s + batch_size]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Show a batch of images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_images(images, n_rows: int = 4, n_cols: int = 4):\n",
    "    \"\"\"Plot up to n_rows * n_cols flattened 28x28 images in a grid.\"\"\"\n",
    "    images = np.array(images).reshape(-1, 28, 28)\n",
    "    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows))\n",
    "    for i, ax in enumerate(axes.flat):\n",
    "        if i < len(images):\n",
    "            ax.imshow(images[i], cmap='gray')\n",
    "        ax.axis('off')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_images(next(batch_iterate(16, train_images)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 8  # NOTE(review): overrides batch_size = 128 from the config cell - confirm intended\n",
    "cur_step = 0\n",
    "mean_generator_loss = 0\n",
    "mean_discriminator_loss = 0\n",
    "\n",
    "# Gradients are taken w.r.t. each model's trainable parameters only.\n",
    "# Do not rebind `disc_loss` / `gen_loss` in this cell: e.g. `gen_loss = False`\n",
    "# shadows the loss function and value_and_grad then tries to call a bool.\n",
    "D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
    "G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
    "\n",
    "n_batches = (len(train_images) + batch_size - 1) // batch_size\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    for real in tqdm(batch_iterate(batch_size, train_images), total=n_batches, desc=f\"epoch {epoch}\"):\n",
    "        # Discriminator step\n",
    "        D_loss, D_grads = D_loss_grad(gen, disc, real, batch_size, z_dim)\n",
    "        disc_opt.update(disc, D_grads)\n",
    "        # MLX is lazy: evaluate the update so the graph does not grow every step\n",
    "        mx.eval(disc.parameters(), disc_opt.state)\n",
    "\n",
    "        # Generator step\n",
    "        G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
    "        gen_opt.update(gen, G_grads)\n",
    "        mx.eval(gen.parameters(), gen_opt.state)\n",
    "\n",
    "        # Running averages over the last display_step steps\n",
    "        mean_discriminator_loss += D_loss.item() / display_step\n",
    "        mean_generator_loss += G_loss.item() / display_step\n",
    "\n",
    "        if cur_step % display_step == 0 and cur_step > 0:\n",
    "            tqdm.write(f\"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}\")\n",
    "            show_images(gen(mx.array(get_noise(batch_size, z_dim))))\n",
    "            show_images(real)\n",
    "            mean_generator_loss = 0\n",
    "            mean_discriminator_loss = 0\n",
    "        cur_step += 1"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}