mlx-examples/gan/playground.ipynb

950 lines
409 KiB
Plaintext
Raw Normal View History

2024-07-26 21:07:40 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Library"
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 657,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 658,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
2024-07-30 07:17:12 +08:00
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 659,
2024-07-30 07:17:12 +08:00
"metadata": {},
"outputs": [],
"source": [
"# mx.set_default_device(mx.gpu)"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 660,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim:int,out_dim:int):\n",
" \n",
" return nn.Sequential(\n",
" nn.Linear(in_dim,out_dim),\n",
2024-07-30 07:06:52 +08:00
" nn.BatchNorm(out_dim, 0.8),\n",
" nn.LeakyReLU(0.2)\n",
2024-07-26 21:07:40 +08:00
" )"
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 661,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"\n",
2024-07-30 07:06:52 +08:00
" def __init__(self, z_dim:int = 10, im_dim:int = 784, hidden_dim: int = 256):\n",
2024-07-26 21:07:40 +08:00
" super(Generator, self).__init__()\n",
" # Build the neural network\n",
" self.gen = nn.Sequential(\n",
" GenBlock(z_dim, hidden_dim),\n",
" GenBlock(hidden_dim, hidden_dim * 2),\n",
" GenBlock(hidden_dim * 2, hidden_dim * 4),\n",
" GenBlock(hidden_dim * 4, hidden_dim * 8),\n",
"\n",
"\n",
" nn.Linear(hidden_dim * 8,im_dim),\n",
" )\n",
" \n",
" def __call__(self, noise):\n",
2024-07-30 07:06:52 +08:00
" x = self.gen(noise)\n",
" return mx.tanh(x)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 662,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-30 07:06:52 +08:00
" (layers.0): Linear(input_dims=100, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-30 07:06:52 +08:00
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
" (layers.1): BatchNorm(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-30 07:06:52 +08:00
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
" (layers.1): BatchNorm(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.3): Sequential(\n",
2024-07-30 07:06:52 +08:00
" (layers.0): Linear(input_dims=1024, output_dims=2048, bias=True)\n",
" (layers.1): BatchNorm(2048, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-30 07:06:52 +08:00
" (layers.4): Linear(input_dims=2048, output_dims=784, bias=True)\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-07-30 07:44:41 +08:00
"execution_count": 662,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gen = Generator(100)\n",
"gen"
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 663,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# make 2D noise with shape n_samples x z_dim\n",
"def get_noise(n_samples:list[int], z_dim:int)->list[int]:\n",
" return mx.random.normal(shape=(n_samples, z_dim))"
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 664,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-30 07:44:41 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWFklEQVR4nO3ca2zW9d3H8U8ZpQdhLRZKS7shpRURCsI4SxlqHYfoAjNAMtDoDoYdHuiiRreZzZll2eayuGWZmRJlZGMiYW4Y5BCVKLOcQYRROayxlnZtEdqK5SS97mff5E7upNfnl8z7zp336/H1vi5KWz78n3xzMplMRgAASBrwv/0HAAD838EoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIAzM9oXr1q2z37yhocFuSktL7UaSPvOZz9jN6dOn7eaDDz6wm+uuu85u+vr67EZK+5rq6ursZvjw4XZTXl5uN5L0yiuv2M3JkyftZtSoUXZTUFBgN5WVlXYjSSUlJXbzxz/+0W5uuukmu7l48aLdTJgwwW4kqaOjw24GDPD//1tcXGw3vb29diNJb775pt08/PDDdpPN7zpPCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACBkfRCvqanJfvP6+nq7OX/+vN1I0rvvvms3gwYNspuKigq7ycvLs5uUo1+SNGLECLvp6uqym6NHj9rNsGHD7EaSlixZYjebNm2ym9zcXLv53e9+ZzcPPvig3UjS5MmT7Sbl52H69Ol2093dbTfV1dV2I0kTJ060m/Xr19tNylHKlIOZUtrvxurVq+2Gg3gAAAujAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkPVBvJSjZKdOnbKbI0eO2I2UdsSroaHBbkpLS+3mb3/7m93MmDHDbiSpqqrKbnbu3Gk3jz76qN2cPHnSbiSpra3NblK+t1OmTLGbLVu22E3KcTZJ+sc//mE3w4cPt5tDhw7ZzYcffmg3qXbv3m03o0ePtpuUr6m8vNxuJOn3v/+93dx///1Jn9UfnhQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAACHrK6kffPCB/eb//ve/7aagoMBuJKmsrMxu8vPz7aa3t9du6uvr7SbluqUk7d27125Svk/Dhg2zm127dtmNJDU3N9vNE088kfRZrnPnztlN6iXNxsZGuxk8eLDdtLa22s2iRYvsZvLkyXYjSVevXrWb48eP283ly5ft5pNPPrEbScrNzbWblMuv2eBJAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAIScTCaTyeaFDz30kP3mZ8+etZuUw3ZS2hGv2bNn203KMa5jx47ZTWFhod1I0oULF+zm4YcftpunnnrKbu677z67kaTTp0/bzZ49e+zmvffes5uUY4KVlZV2I6Udndu8ebPdTJgwwW5Sfm9TjktKaUf+nn76abuZO3eu3aQcipSkG264wW46OjrsZtWqVf2+hicFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEAZm/cKBWb80LF682G4OHDhgN5LU3d1tN01NTXYzevRou0k54JVyRE+ScnJy7GbhwoV2c+utt9pNY2Oj3UjSqVOn7Kavr89uysvL7WbUqFF2M2fOHLuRpIqKCrvZu3ev3Vx77bV2k3IQL+VzJGn37t12k/J3d9NNN9nNP//5T7uRpO3bt9tN6vHQ/vCkAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAELWV+56enrsNz958qTdHD582G4kaf/+/XazYMECuzlz5ozdfPvb37
ab733ve3YjSefPn7eb+++/325SjgmuWbPGbiRp5cqVdvPAAw/YzZtvvmk37733nt00NzfbjSS99dZbdvOd73znU2nmz59vNz//+c/tRpKeeuopuyktLbWblJ/xy5cv240kdXV12c3NN9+c9Fn94UkBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAAhJxMJpP5T735fffdZzfjxo1L+qwRI0bYzaFDh+zm0qVLdpNy7Grx4sV2I0m//e1v7eauu+6ym7KyMrvJy8uzG0k6ePCg3QwZMsRu3nnnHbtpb2+3myVLltiNJLW0tNhNYWGh3SxcuNBuXnrpJbtJOS6ZKuXflYEDs74XGlKOgErSjBkz7Cbl+OWqVav6fQ1PCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkPWV1F//+tf2m584ccJuUi9plpSU2E1vb6/dLFu2zG727dtnNzk5OXYjSadPn7abV1991W6mTZtmN319fXYjSR0dHXYzZcoUuzlw4IDdrFy50m5S/x7Wr19vN1evXrWblJ+9efPm2U3K91VK+5quXLliN8XFxXaTchFZSrtmm3KZNptLwDwpAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgDAw2xcWFRXZb37NNdfYTVNTk91IacerOjs77aahocFuysrK7Katrc1uJGncuHF28+6779rNpUuX7CblAKEkFRQU2M3SpUvt5vDhw3Zz4cIFu8nmKNn/JOV3sKKiwm7GjBljNynH2aqrq+1GkjZt2mQ3CxcutJuenh67SfldkqQvfvGLdpPyM54NnhQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAyMlkMplsXvj888/bb/7JJ598Ko0k3X333XZz77332s1XvvIVu0k54JWfn283kjR9+nS7STmatm7dOrupqamxGyntsOK5c+fsZurUqXazdu1au5kzZ47dSGnH1iZPnmw3dXV1dnPgwAG7efvtt+1GSjuQmJubazd79+61m5S/Oynt5+gXv/iF3WTz7yRPCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAMzPaFDz30kP3ms2bNspsPP/zQbiSps7PTbrq6uuymtbXVbr7xjW/YTUtLi91IaX8PAwb4/zdIOVJXW1trN5K0Z88euykqKrKbxx9/3G5uueUWu0n5eiRpyJAhdvPWW2/ZTcrPUG9vr93s37/fbiTpm9/8pt3Mnj3bbhYsWGA33d3ddiNJI0eOtJtTp04lfVZ/eFIAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAISsr6SuWLHCfvNBgwbZzXe/+127kaSf/OQndlNcXGw3ZWVldvPVr37Vbl577TW7kaRf/vKXdlNTU2M3KZcqN23aZDeSNHToULspKSmxmxtuuMFuUn6Ghg0bZjdS2uXXlM/aunWr3Xz88cd2k/I7K0mrV6+2mytXrthNyhXXqVOn2o0kZTIZu5k5c2bSZ/WHJwUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQcjJZXmL6/ve/b7/5mDFj7Kanp8duJOn999+3m8LCQruZNGmS3bS1tdnNjh077EaS6uvr7WbIkCF209XVZTe7du2yG0m65ppr7Gbs2LF2k5+fbzcbN260m+rqaruR0o7vpXxNKcftUo5fdnd3240knTlzxm6eeeYZu1m6dKndpPz7IEm1tbV209HRYTePPvpov6/hSQEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBAB
AYBQBAYBQAAIFRAACErA/izZkzx37zn/70p3Zz9uxZu5GkvLw8u9m2bdun8jnDhg2zm6KiIruRpPb2drt544037Ob
2024-07-30 00:44:16 +08:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"img = get_noise(28,28)\n",
"plt.imshow(img, cmap='gray')\n",
"plt.axis('off')\n",
"plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 665,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim:int,out_dim:int):\n",
" return nn.Sequential(\n",
" nn.Linear(in_dim,out_dim),\n",
2024-07-30 00:44:16 +08:00
" nn.LeakyReLU(negative_slope=0.2),\n",
2024-07-26 21:07:40 +08:00
" )"
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 666,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"\n",
2024-07-30 07:37:09 +08:00
" def __init__(self,im_dim:int = 784, hidden_dim:int = 256):\n",
2024-07-26 21:07:40 +08:00
" super(Discriminator, self).__init__()\n",
"\n",
" self.disc = nn.Sequential(\n",
2024-07-30 07:37:09 +08:00
" DisBlock(im_dim, hidden_dim * 4),\n",
" DisBlock(hidden_dim * 4, hidden_dim * 2),\n",
2024-07-26 21:07:40 +08:00
" DisBlock(hidden_dim * 2, hidden_dim),\n",
2024-07-30 07:37:09 +08:00
" \n",
" nn.Dropout(0.3),\n",
2024-07-26 21:07:40 +08:00
" nn.Linear(hidden_dim,1),\n",
2024-07-30 07:37:09 +08:00
" # nn.Sigmoid()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" \n",
" def __call__(self, noise):\n",
2024-07-30 07:37:09 +08:00
" x = noise + 1.0\n",
" x = self.disc(noise)\n",
" out = mx.log(mx.softmax(x)) \n",
" return out"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 667,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-30 07:37:09 +08:00
" (layers.0): Linear(input_dims=784, output_dims=1024, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.1): Sequential(\n",
2024-07-30 07:37:09 +08:00
" (layers.0): Linear(input_dims=1024, output_dims=512, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.2): Sequential(\n",
" (layers.0): Linear(input_dims=512, output_dims=256, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
" )\n",
2024-07-30 07:37:09 +08:00
" (layers.3): Dropout(p=0.30000000000000004)\n",
" (layers.4): Linear(input_dims=256, output_dims=1, bias=True)\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-07-30 07:44:41 +08:00
"execution_count": 667,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"disc = Discriminator()\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
2024-07-30 00:44:16 +08:00
"cell_type": "markdown",
2024-07-26 21:07:40 +08:00
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### Losses"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"#### Discriminator Loss"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 668,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
2024-07-30 00:44:16 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" noise = mx.array(get_noise(num_images, z_dim))\n",
" fake_images = gen(noise)\n",
2024-07-27 06:09:51 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" fake_disc = disc(fake_images)\n",
" \n",
2024-07-27 05:19:08 +08:00
" fake_labels = mx.zeros((fake_images.shape[0],1))\n",
2024-07-30 00:44:16 +08:00
" \n",
2024-07-30 07:06:52 +08:00
" fake_loss = nn.losses.binary_cross_entropy(fake_disc,fake_labels)\n",
2024-07-26 21:07:40 +08:00
" \n",
" real_disc = disc(real)\n",
2024-07-27 05:19:08 +08:00
" real_labels = mx.ones((real.shape[0],1))\n",
"\n",
2024-07-30 07:06:52 +08:00
" real_loss = nn.losses.binary_cross_entropy(real_disc,real_labels)\n",
2024-07-26 21:07:40 +08:00
"\n",
2024-07-30 07:06:52 +08:00
" disc_loss = (fake_loss + real_loss) / 2\n",
2024-07-26 21:07:40 +08:00
"\n",
" return disc_loss"
]
},
2024-07-30 00:44:16 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generator Loss"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 669,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def gen_loss(gen, disc, num_images, z_dim):\n",
2024-07-26 21:07:40 +08:00
"\n",
" noise = mx.array(get_noise(num_images, z_dim))\n",
2024-07-30 07:06:52 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" fake_images = gen(noise)\n",
" fake_disc = disc(fake_images)\n",
"\n",
2024-07-27 05:19:08 +08:00
" fake_labels = mx.ones((fake_images.shape[0],1))\n",
2024-07-30 00:44:16 +08:00
" \n",
2024-07-30 07:06:52 +08:00
" gen_loss = nn.losses.binary_cross_entropy(fake_disc,fake_labels)\n",
2024-07-26 21:07:40 +08:00
"\n",
" return gen_loss"
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 670,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"# Get only the training images\n",
"train_images,*_ = map(np.array, mnist.mnist())"
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 671,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# Normalize the images to fall between -1,1\n",
"train_images = train_images * 2.0 - 1.0"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 672,
2024-07-29 06:24:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-07-30 07:44:41 +08:00
"<matplotlib.image.AxesImage at 0x156ddd690>"
2024-07-29 06:24:50 +08:00
]
},
2024-07-30 07:44:41 +08:00
"execution_count": 672,
2024-07-29 06:24:50 +08:00
"metadata": {},
"output_type": "execute_result"
2024-07-30 00:44:16 +08:00
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaI0lEQVR4nO3df2jU9x3H8dfVH1d1lytBk7vUmGVF202dpWrVYP3R1cxApf4oWMtGZEPa+YOJ/cGsDNNBjdgpRdI6V0amW239Y9a6KdUMTXRkijpdRYtYjDOdCcFM72LUSMxnf4hHz1j1e975vkueD/iCufu+vY/ffuvTby75xueccwIAwMBD1gsAAHRfRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjpab2AW3V0dOjcuXMKBALy+XzWywEAeOScU0tLi/Ly8vTQQ3e+1km7CJ07d075+fnWywAA3Kf6+noNHDjwjvuk3afjAoGA9RIAAElwL3+fpyxCH3zwgQoLC/Xwww9r5MiR2rdv3z3N8Sk4AOga7uXv85REaPPmzVq8eLGWLVumI0eO6JlnnlFJSYnOnj2bipcDAGQoXyruoj1mzBg99dRTWrduXeyx73//+5o+fbrKy8vvOBuNRhUMBpO9JADAAxaJRJSVlXXHfZJ+JXTt2jUdPnxYxcXFcY8XFxertra20/5tbW2KRqNxGwCge0h6hM6fP6/r168rNzc37vHc3Fw1NjZ22r+8vFzBYDC28ZVxANB9pOwLE259Q8o5d9s3qZYuXapIJBLb6uvrU7UkAECaSfr3CfXv3189evTodNXT1NTU6epIkvx+v/x+f7KXAQDIAEm/Eurdu7dGjhypqqqquMerqqpUVFSU7JcDAGSwlNwxYcmSJfrpT3+qUaNGady4cfr973+vs2fP6tVXX03FywEAMlRKIjR79mw1NzfrN7/5jRoaGjRs2DDt2LFDBQUFqXg5AECGSsn3Cd0Pvk8IALoGk+8TAgDgXhEhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmelovAEgnPXr08DwTDAZTsJLkWLhwYUJzffv29Tzz+OOPe55ZsGCB55nf/va3nmfmzJnjeUaSrl696nlm5cqVnmfefvttzzNdBVdCAAAzRAgAYCbpESorK5PP54vbQqFQsl8GANAFpOQ9oaFDh+rvf/977ONEPs8OAOj6UhKhnj17cvUDALirlLwndOrUKeXl5amwsFAvvfSSTp8+/a37trW1KRqNxm0AgO4h6REaM2aMNm7cqJ07d+rDDz9UY2OjioqK1NzcfNv9y8vLFQwGY1t+fn6ylwQASFNJj1BJSYlmzZql4cOH67nnntP27dslSRs2bLjt/kuXLlUkEolt9fX1yV4SACBNpfybVfv166fhw4fr1KlTt33e7/fL7/enehkAgDSU8u8Tamtr05dffqlwOJzqlwIAZJikR+j1119XTU2N6urqdODAAb344ouKRqMqLS1N9ksBADJc0j8d9/XXX2vOnDk6f/68BgwYoLFjx2r//v0qKChI9ksBADJc0iP0ySefJPu3RJoaNGiQ55nevXt7nikqKvI8M378eM8zkvTII494npk1a1ZCr9XVfP31155n1q5d63lmxowZnmdaWlo8z0jSv//9b88zNTU1Cb1Wd8W94wAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQ
DMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAMz7nnLNexDdFo1EFg0HrZXQrTz75ZEJzu3fv9jzDf9vM0NHR4XnmZz/7meeZS5cueZ5JRENDQ0JzFy5c8Dxz8uTJhF6rK4pEIsrKyrrjPlwJAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwExP6wXA3tmzZxOaa25u9jzDXbRvOHDggOeZixcvep6ZPHmy5xlJunbtmueZP/3pTwm9Fro3roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT63//+l9DcG2+84Xnm+eef9zxz5MgRzzNr1671PJOoo0ePep6ZMmWK55nW1lbPM0OHDvU8I0m//OUvE5oDvOJKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAw43POOetFfFM0GlUwGLReBlIkKyvL80xLS4vnmfXr13uekaSf//znnmd+8pOfeJ75+OOPPc8AmSYSidz1/3muhAAAZogQAMCM5wjt3btX06ZNU15ennw+n7Zu3Rr3vHNOZWVlysvLU58+fTRp0iQdP348WesFAHQhniPU2tqqESNGqKKi4rbPr1q1SmvWrFFFRYUOHjyoUCikKVOmJPR5fQBA1+b5J6uWlJSopKTkts855/Tee+9p2bJlmjlzpiRpw4YNys3N1aZNm/TKK6/c32oBAF1KUt8TqqurU2Njo4qLi2OP+f1+TZw4UbW1tbedaWtrUzQajdsAAN1DUiPU2NgoScrNzY17PDc3N/bcrcrLyxUMBmNbfn5+MpcEAEhjKfnqOJ/PF/exc67TYzctXbpUkUgkttXX16diSQCANOT5PaE7CYVCkm5cEYXD4djjTU1Nna6ObvL7/fL7/clcBgAgQyT1SqiwsFChUEhVVVWxx65du6aamhoVFRUl86UAAF2A5yuhS5cu6auvvop9XFdXp6NHjyo7O1uDBg3S4sWLtWLFCg0ePFiDBw/WihUr1LdvX7388stJXTgAIPN5jtChQ4c0efLk2MdLliyRJJWWluqPf/yj3nzzTV25ckXz58/XhQsXNGbMGO3atUuBQCB5qwYAdAncwBRd0rvvvpvQ3M1/VHlRU1Pjeea5557zPNPR0eF5BrDEDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNrqkfv36JTT317/+1fPMxIkTPc+UlJR4ntm1a5fnGcASd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzAFvuGxxx7zPPOvf/3L88zFixc9z+zZs8fzzKFDhzzPSNL777/veSbN/ipBGuAGpgCAtEaEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpsB9mjFjhueZyspKzzOBQMDzTKLeeustzzMbN270PNPQ0OB5BpmDG5gCANIaEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gCBoYNG+Z5Zs2aNZ5nfvSjH3meSdT69es9z7zzzjueZ/773/96noENbmAKAEhrRAgAYMZzhPbu3atp06YpLy9PPp9PW7dujXt+7ty58vl8cdvYsWOTtV4AQBfiOUKtra0aMWKEKioqvnWfqVOnqqGhIbbt2LHjvhYJAOiaenodKCkpUUlJyR338fv9CoVCCS8KANA9pOQ9oerqauXk5GjIkCGaN2+empqavnXftrY2RaPRuA0A0D0kPUIlJSX66KOPtHv3bq1evVoHDx7Us88+q7
a2ttvuX15ermAwGNvy8/OTvSQAQJry/Om4u5k9e3bs18OGDdOoUaNUUFCg7du3a+bMmZ32X7p0qZYsWRL7OBqNEiI
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
2024-07-29 06:24:50 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"plt.imshow(train_images[0].reshape(28,28),cmap='gray')"
2024-07-29 06:24:50 +08:00
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 673,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"def batch_iterate(batch_size: int, ipt: list[int])-> list[int]:\n",
2024-07-29 06:24:50 +08:00
" perm = np.random.permutation(len(ipt))\n",
" for s in range(0, len(ipt), batch_size):\n",
2024-07-26 21:07:40 +08:00
" ids = perm[s : s + batch_size]\n",
2024-07-30 00:44:16 +08:00
" yield ipt[ids]"
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 674,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"def show_images(imgs:list[int],num_imgs:int = 25):\n",
" if (imgs.shape[0] > 0): \n",
" fig,axes = plt.subplots(5, 5, figsize=(5, 5))\n",
" \n",
" for i, ax in enumerate(axes.flat):\n",
" img = mx.array(imgs[i]).reshape(28,28)\n",
" ax.imshow(img,cmap='gray')\n",
" ax.axis('off')\n",
" plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### show first batch of images"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 675,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-30 07:44:41 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADhEUlEQVR4nOy9Z3OcV3rmf3XOOUc0ciJAMEASxaAsrTUzHnvssWu37ClX7QfYd/sp9hts2buu2rV3vZoZe8YjzUgaBQaJYiYygUbqnHOO/xf8n8NGIAmSDaC7cX5VKAU0we4Hz3Puc+77uq+b02w2m2AwGAwGo41wj/sNMBgMBqP3YMGFwWAwGG2HBRcGg8FgtB0WXBgMBoPRdlhwYTAYDEbbYcGFwWAwGG2HBRcGg8FgtB0WXBgMBoPRdlhwYTAYDEbb4R/0hRwO5zDfR1fxMqYG7Po9gV2/V+NlTTXYNXwCuwdfjYNcP3ZyYTAYDEbbYcGFwWAwGG2HBRcGg8FgtB0WXBgMBoPRdlhwYTAYDEbbYcGFwWAwGG3nwFJkBgN4uhyTzZxjMBitdGVw4XA44PP54HK59KvRaKBerwMAeDweOBwO6vU6Go3Gju8xXhyhUAilUgmZTIazZ8/C6XSCy+WCx+OhWCxiY2MD2WwWm5ub8Pv9x/12GQxGB9CVwYXL5UIkEoHH40EgEIDH46FWq6FSqYDD4UAoFILH46FcLqNSqdAgw3bXL4dYLIbZbIbJZMIvfvELXL58GQKBAEKhELFYDH/4wx/g8/nw+eefs+DCYDAAdElw4fF44HK5UKlUUCqVEAqFUKvV4PP5EIvFEAgEKJVKKBaL4PF4kEql4PP5yOfzKBQKKBaLiMViqFQqKBaLqFQqaDabLNgcEJFIBJPJBLPZDJVKBalUSgM7Ce5cLpd1MO9CLBZDLBbTe63ZbKJYLKJWqx33W2MwDp2ODy5cLhdSqRQikQgXLlzAa6+9BrVaDafTCYlEAo1GA6lUikwmg1gsBj6fD6PRCKFQiGg0ilgsBr/fjxs3biCZTMLtdiMajdKTDuP56HQ6XL58GXa7HTabDUKhEM1mE7VaDdVqlX6x1OMTOBwOzGYznE4nGo0GqtUqyuUyPB4PEonEcb89BuPQ6ejgwuFw6ElEIpHAZDJhYGAAarUa/f39kEgk0Gq1kMlkSKfT0Gg04PP5MJlMEIvFUKvVUKvVEAgE2NjYgEAgQDgcRjabRblcRrVaZaeXAyASiWAwGGAwGCCRSMDhcFCr1VAul1EsFlEqlVAqldBoNI77rR4KXO7+okoOh0NrfhwOh36R76nVahiNRtTrdZRKJVQqFUQiEXC5XHZyZvQ8HRtcBAIBDR4fffQRnE4nJiYmMDY2BrFYDIVCAT6fD5FIBACQSCQwGAzgcrkQCoXgcDhQqVS0GK3RaJBOp2E2m7G2tgafz4fV1VW6+2YP+l7IwimXy+FwOOBwOCCVSgEAq6ur+P777xGJRHD9+nWEw2EEg8FjfsftRyAQQK/XQyKR7Pken8+Hy+WCTqeDTCaDWq2mgYjL5cLpdMJms6FcLiOVSiGTyYDH44HP5yObzSKVSrH7jtGzdGxw4fP5kEql0Ov1ePPNNzE1NUWLyoTWB1MoFEIoFNL/5nA4kMlkkMlk0Ov1cLlcyOfzyGQyNK2zvb0NAKjVauwh3weyM5dIJDAajTAajTSY+3w+fPPNNwiFQrh//z5SqdTxvtlDgsfjQaPRQKlU7vmeWCzG6dOn0d/fD61WC6vVSpWKHA4HFosFRqMRhUIB4XAY8Xgca2triMfjaDabSKfT7L5j9CwdFVy4XC50Oh0UCgWMRiP6+/thNpvhcDigVqshFouf+zMajQaKxSLq9TpNUZDCP5/PR19fH01jpNNpZDIZrK+vI5/PH/bH6yo4HA4cDgfsdjsmJiboKbBQKCCfzyOVSqFYLPb8qU8qleLcuXPo6+vb8z0+n4/BwUEYjUbI5XJoNJ
odKTKpVAoulwuBQAClUgkOh4MzZ85AoVAgFApha2sL2WwW6+vryOVyx/Dp2g+Hw4FcLodYLIZSqYTBYIBIJIJGo6Ebk1ZqtRp8Ph/i8Tjy+TwSiQTq9Trq9TqazSZkMhmkUinNZHA4HFQqFdRqNbpZ7OX772kQxSzZALaKaQQCAbRaLUQiEfh8Png83g4RFPn/m5ubWFhYQLVaPZQNdkcFFx6Ph8HBQQwNDWFiYgLvvvsulEolLBYLfVCfR61WQyKRQLFYBAD6kJtMJohEIpw7dw6nT5+G0+mEUqmE1+tFLBZjwWUXXC4XZ86cwUcffQSHwwGz2QyRSITt7W3EYjEEAgFks1nkcrmeLuSr1Wr8/Oc/x6VLl/Z9+AQCAa0NkvuTPOjkoefxeBCLxdDr9dDr9bSwv7q6iq2tLfzjP/5jzwQXLpcLk8kEo9GIoaEhXLhwARqNBtPT09DpdHteXygU8O///u+Ym5vDxsYG7t+/j3K5TGt4Op0OVqsVSqUSVqsVHA4HqVQKhUIBXq+35++/pyEQCKBSqahak8fj0e8pFAqcOXOGpnPFYjF0Oh3Gx8chk8mg0+kglUrxz//8z/hv/+2/IZPJIJ/Pt13FeOzBhcPh0J2dVCqFzWaDzWaDxWKhBeR6vY5CoYBSqYRyuUx3g6SRj8vlUjVOuVxGMBhEsVikD7dWq4Ver6dRWyQSQalUQqvVIpPJ0AXiJO6AdkN+HyKRiBakyY6Hw+Egm80iHo8jlUohn8+jUqn0bCEfAK05qVQq+v+edZ+07iCr1SoqlQrdOfJ4PCgUCkgkEhQKBWQyGeRyOfD5x/4YvhDkHiFBUywWU9ECl8uFXq+nz7HZbIZGo4HRaIRGo9nzswqFAqxWK+LxOIrFIvx+P20ZaDabsFqtcDgckMvlNO0ol8vpCVoikaBSqfT8CZogkUggkUggk8lgMpkgFArpWkhQKBRwOBzQarX096PVamEymSCVSqHVauk/lUolzfa0m2O9q0lx0+Vy4ac//SnMZjNGR0fpTsVgMKBcLmNxcRHxeBwLCwtYWFiAQqGAy+WiBX+pVAq/34/V1VUUi0VEo1GUSiV6Yc+cOQOr1brjWK5WqzE0NAQOhwOFQgGhUIharXYid0GtiEQi9PX1QaVSYXJyEqdOnYJMJoNQKESxWMSDBw9w7949rK2tYW1tDZVKBeVy+bjfdsfRbDYRCATg9/uhVqsxMDCwI01hsVigVCohEAggk8mO++2+EEKhEC6Xi94jU1NTSKVSmJubQ7VaxQcffIDp6WmoVCqYzeZnfkaRSITXXnsNExMT2N7exszMDE17AUBfXx8cDgfdEHK5XNqr9u2336JUKiGVSsHv9x/KAtkpkI309PQ0zpw5A6PRiOnpacjlcnpfEfh8Pk1jE1GOUCiEXC6nG51mswmj0YjZ2VmEQiHcunWr7RL5Yw8uJBc4NTWFvr4+9PX1UdUXj8dDpVJBNBqF1+vFw4cPcePGDajVamQyGcjlcpjNZiiVSrjdbty5c4fmbSuVCmQyGSQSCRQKBUqlEprNJt1ZkjywSqWiv5xut4lpx+mLx+NBpVLBYDDQNA7ZGdXrdYTDYWrzkkqluvp6HZSnyYZ3/7/W/240GshmswiHw2g2m/Q6kRw5EZsEg0EIBILD/QBthKT51Go1DAYDhoeHMTs7i0gkglQqhXK5jNHRUczMzOyokzztGnK5XFitVgCP61ukpkKeVafTCbvdDoFAQL9fLpdRq9Xg8Xig1WrRaDQQCoWO+lIcGeSaCwQCmEwmjI2NwWaz4fXXX4dSqaT15N20Xu/dDc7NZhNSqRRGoxHVavVQ7sFjCy4cDgcTExOYnZ2F0+nE6Ogo9Ho95HI5OBwOCoUC0uk0IpEIbt68CbfbjY2NDeTzedTrdSwvL0MkEmFjYwNisRjxeByRSIQeqcnDXKlUsL29ja+//hpWqxUTExOwWq30whYKBQwPD6PZbMLv9yMcDh/XJXlhWgvHVqsVfD
4fwWAQyWTypX+mTCbDuXPnMDg4iMHBQYhEIpRKJQSDQcTjcWxtbWF7exupVKqn02GEcrmMR48eQSaTwWw2w2g0olQ
2024-07-26 21:07:40 +08:00
"text/plain": [
2024-07-29 06:24:50 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-26 21:07:40 +08:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
2024-07-30 00:44:16 +08:00
"X = batch_iterate(25, train_images)\n",
"for x in X: \n",
" show_images(x)\n",
2024-07-26 21:07:40 +08:00
" break"
]
},
2024-07-30 07:06:52 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training Cycle"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 676,
2024-07-27 06:09:51 +08:00
"metadata": {},
2024-07-29 06:24:50 +08:00
"outputs": [],
2024-07-27 06:09:51 +08:00
"source": [
2024-07-30 07:44:41 +08:00
"lr = 1e-4\n",
"z_dim = 128\n",
2024-07-30 07:06:52 +08:00
"\n",
2024-07-27 06:09:51 +08:00
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
2024-07-30 07:44:41 +08:00
"gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999]) #,betas=[0.5, 0.9]\n",
2024-07-27 06:09:51 +08:00
"\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
2024-07-30 07:44:41 +08:00
"disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])"
2024-07-29 06:30:08 +08:00
]
},
2024-07-27 06:09:51 +08:00
{
"cell_type": "code",
2024-07-30 07:44:41 +08:00
"execution_count": 679,
2024-07-27 06:09:51 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-30 07:44:41 +08:00
" 0%| | 1/200 [00:05<18:01, 5.43s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28961, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 1%| | 2/200 [00:10<17:20, 5.26s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28963, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 2%|▏ | 3/200 [00:15<17:08, 5.22s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28957, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 2%|▏ | 4/200 [00:20<16:55, 5.18s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28956, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 2%|▎ | 5/200 [00:26<17:02, 5.24s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28958, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 3%|▎ | 6/200 [00:31<16:52, 5.22s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28962, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 4%|▎ | 7/200 [00:36<16:38, 5.18s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28961, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 4%|▍ | 8/200 [00:41<16:49, 5.26s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28955, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 4%|▍ | 9/200 [00:47<16:56, 5.32s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.2896, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 5%|▌ | 10/200 [00:52<17:03, 5.39s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28964, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 6%|▌ | 11/200 [00:58<17:03, 5.42s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28956, dtype=float32) G=array(5.54911, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 6%|▌ | 12/200 [01:03<16:44, 5.35s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28952, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 6%|▋ | 13/200 [01:08<16:26, 5.28s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28963, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 7%|▋ | 14/200 [01:13<16:11, 5.23s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28959, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 8%|▊ | 15/200 [01:19<16:17, 5.28s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.2896, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 8%|▊ | 16/200 [01:24<16:18, 5.32s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28965, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 8%|▊ | 17/200 [01:29<16:12, 5.31s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28969, dtype=float32) G=array(5.54911, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 9%|▉ | 18/200 [01:35<16:05, 5.31s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28959, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 10%|▉ | 19/200 [01:40<15:58, 5.30s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28961, dtype=float32) G=array(5.54911, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 10%|█ | 20/200 [01:45<15:47, 5.26s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28956, dtype=float32) G=array(5.54911, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 10%|█ | 21/200 [01:50<15:35, 5.23s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.2896, dtype=float32) G=array(5.54911, dtype=float32)\n",
"Step 21: Generator loss: array(5.54911, dtype=float32), discriminator loss: array(2.77676, dtype=float32)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9SYyl6ZUW/Nx5nucbcWMecx6r0lVlu9rudgENqBuwaCRYIFYs2ECDWLBgyQJ1b9gAO4QEiwZBN8LGZXfZLteclZmVGfMccYe48zyP/yL6OfnetH/IvJnol37FK6Xsyoy4w/u97znPec5zztGMRqMRLtflulyX63Jdrje4tP9ff4DLdbku1+W6XP//W5fO5XJdrst1uS7XG1+XzuVyXa7Ldbku1xtfl87lcl2uy3W5LtcbX5fO5XJdrst1uS7XG1+XzuVyXa7Ldbku1xtfl87lcl2uy3W5LtcbX5fO5XJdrst1uS7XG1+XzuVyXa7Ldbku1xtf+pf9wX/xL/4FtFotfD4fXC4XcrkcDg4O0Ov1YDAYoNVqMRwOMRgM4HQ6sb6+DpPJhMePH+Pg4ABmsxkOhwN6vR52ux0mkwmNRgP1eh06nQ5WqxVarRb9fh+DwQBerxdTU1Po9/uIx+Oo1+sol8uoVquYnp7Gb/3Wb8FmsyEej6NQKAAAtFoter0e8vk8er0erl27hvX1dZTLZezt7aFUKuGjjz7CwcEBlpaWcPPmTQQCAbz77rtwu914+PAhNjY2oNPpYDKZYDabsba2hmAwiLm5Oayvr0Or1UKj0bzyRv/hH/4h9Ho93nnnHVy/fh2PHz/Gf/pP/wmNRgOBQAAWiwWZTAaJRAKxWAw//OEP4ff78ejRI+zv7yMYDGJ5eRl6vR7dbheDwUD2r9PpoFgsotvtIpFIoFAoIBaL4cqVKzAajdBqtRiNRtjd3cXu7i68Xi+uXLkCi8WC4XCI0WgEj8eDUCiEWq2Gzz77DLlcDl6vF263GyaTCQ6HA8PhELu7u0in07h16xa+/e1vo9fr4ejoCI1GAwCg0WhgNpvhdrthMBjgdDphNBrR7XbRbrcxGo3wD//hP3zl/fujP/ojeX2NRoNUKoWnT5+i2+3C5XLBZDLJvwGAXq9Hv9/H/v4+UqkUvF4v5ufnodPp0Ol0MBgMsLCwgOXlZQDAcDhEv99HJpNBpVKBy+VCKBQCADQaDfR6PWi1F1jMYrHA5/NhOBzi0aNHODk5wdTUFJaXl9Hv93F6eopms4nV1VUsLS2h1+uhVquh2Wxic3MT5+fniMViWFlZgc1mw9TUFIxGI7766itsbm7C7XZjbm4ODocDV69eRTAYRDwex+HhIQaDAf7ZP/tnr7x/APBv/+2/hV6vx82bN7G8vIxnz57hv/7X/4pWq4WpqSnY7XbU63XU63V4PB7cunULJpMJ//N//k98+eWXWFpawnvvvQetVotEIoF6vQ7g4t4ZDAZYrVYMh0N88803OD09xfXr1/GDH/xAzl6pVEKhUEC5XEY0GsV3v/td2Gw2nJ2dIZ/Pw2q1wm63o1Kp4OOPP0Yul8Py8jLm5+dhsVjg9/vR7/fxzTffIJVKYW1tDffv34dOp8NgMEC/38f29jaOjo7gcDgQjUZhNBoxHA4BAB6PB4FAABqNBr//+7//yvv3z//5P4dWq8XMzAxCoRD29/fx4YcfotlsQqvVQqvV4vbt23jw4AHMZjPsdjuGwyG+/PJL7O7uQqfTwWg0Qq/Xw+12w2w2o1wuo1QqwWKxIBqNwmAwoFAooFarwWKxwOVyod1uY2dnB6VSCXa7HTabTV5Lo9FAp9NBo9HA4XAgEAhgMBigWq2i1+vBbrfDarViMBig0+mg1Wrh0aNHiMfjCAaDmJ6ehs1mw8zMDCwWC4rFIiqVCux2O8LhMDQaDQqFAtrtNux2O1wuFzQaDf7xP/7H/8f9emnn4vP5oNFoYDKZMBqNZDP6/T76/T6Gwy
EcDgecTieq1So++ugjNBoNWK1WcUjcPBq8wWCAwWCAbreLarWKfr+P0Wgkf1cul2EwGDA3NwedTicbnM1m8cknn0Cv16NaraLdbmN5eRm3bt3CaDRCMplEvV7H4eEhfvWrX8FqtSIQCMBoNOKDDz7ABx98IM6qUCjgq6++gtlsxmAwQCwWg1arhV6vh0ajQTqdRjqdBgCEQiHo9Xp4PJ5XPpi1Wg06nQ75fB7pdBq9Xg+xWAztdhtmsxk6nU5e22QyYXNzE3q9HqVSCcPhENVqFfv7+9Dr9bBardDr9cjlcshkMtBqtTCbzTAajXC5XNBqtXA4HNDpdPJdAMjh0+l0OD09FWMJALFYDF6vF1qtFpFIBBaLBblcDvF4HIFAADdv3oTVakUsFoPH44HT6USxWESj0cDh4SEqlQoMBgP0ej1sNhu63S50Op3sncvlgtvtHnvPV1nBYBBarRYmk0neJ5VKAQCuX7+OUCiEVquFVquFdrstAMPr9YoDpZExGAzQ6XTQ6XRIpVLQaDTigAHAZrPJMwEg55XOq1Qq4csvv0S32xUjNj8/j6tXr6Lb7WI4HKJSqaDZbGJnZwdarVaMnM/ng8VigcPhQL/fR7PZRC6Xk71KpVLo9XoIhUIwmUwALhyl3++Xuzfpslqt0Gg0SCQSKBaLKBaLCAaDGI1G8Hq9Ysx6vR7a7TY+/vhjdDodAWpmsxlnZ2cwm81YXFyEw+FANptFOp1Gv99HtVoVYGg2m+FyuZBKpWA0GhGNRjE9PY1SqSRA6Be/+AWGwyHcbjesViscDgfm5+dRrVaRTCZhs9mg1WqRTqeh0+lQKBTEBtEYFgoFAZTtdlteG7g47xaLBU6nE2azWQDjJOCQ54bnYDgcwuPx4N69e2K0+/0+/H4/hsMhisUinj59ina7Da1Wi/n5ebRaLVSrVQCAyWQa+879fh+NRgOdTgdOpxMej0eetU6nw507dzAajdBut9Hr9dDpdFAqlTAYDOQ8m0wmeDweOeedTgedTkfek693/fp13LhxA61WC7VaDe12G8fHx/K9hsMhrFYrPB4PtFotqtWqgI5er/fS+/XSzsXlco39t8lkEi/ZaDTQ7XYRCoUwPT2Ng4MDfP311zg/P8eDBw9w5coV+P1+zM7OwmAwyMPlBa/X6zg+Pkar1ZKDQbRntVqxuLgIl8uFTCaD4XCIcrmMXC43tmHz8/NYXl6GRqOB0WhEuVzGF198gT/7sz9DLBbD7/zO7yAYDOLevXuIxWJ49OgRfvazn6HZbKJUKkGr1WJxcVGci1arRafTwdHREUqlEnw+H6rVKgwGw0TOpdFoQK/Xo1wui+GLRCLo9Xro9XoS8en1etRqNRwdHaHVasFiscBsNqNer4sBDwaDYvwTiQQsFgump6dhNBrlQjIS1Gq14mRsNhu8Xi8ajQYymYw4AADiELRaLbxeL6xWKxKJBA4ODtDv93H79m1YLBYEg0H4fD5otVqUy2VUKhUkEgmUSiVxcA6HQw58s9lEt9sVpzSpcyG4sdvtsFgs6HQ6cLvd0Ol0uH79OhYWFlAul1EsFsXItVotTE9Pw+/3o1gs4uzsDKPRCC6XC0ajUQyjRqMRw2E2m+V78LMSaPC/q9UqPv30UzSbTbz//vuYnp5GNBrF4uIi2u02CoUC9Ho9stkszs/PYTab4fP5YDAY4Ha74fV6BUSNRiNUq1UMh0MUCgXk83mYTCZ0Oh30+325J16vF36/f6K946JzymazqNfr0Ov18Hq9wiYYjUa5k4lEAo8ePUKpVMLt27exvLyMWq2GdDoNp9OJd955B4uLi9ja2hJHWq/X0e/34XQ6EQ6HMRqNUCgUYLVacfXqVXg8HhQKBXi9XpycnODnP/85Wq0W7t27h7m5OdhsNkSjUTidTkxNTckZy+Vycpb1ej2mp6cFqPC9Dw8PUa/XxxwHnXgkEoHT6cRoNHot56zaLoLp9fV19Pt9lEoltFotuFwuAYNPnz5FrV
bDzZs3sbCwIGdzNBrBYDDAZDLB5/MhEAigUqlgZ2cH3W4XgUAATqcTrVYL9XodBoNBIoxcLod8Po9KpSIOlXtDpmA
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADcNElEQVR4nOz9549cV37nj79v3Vt1K+ecujpHRpEUKUoaJUszsme8M96B1/aD9T7YZwv/OYtdGFhggcUahr3A16MZz4yyhiJFiRRDN8nOoXLOOf8e8HeOqppNiqHZXdV9XgChGXaxu+r2ufd9zie8P1y32+2CwWAwGIx9RHLYb4DBYDAYRw8mLgwGg8HYd5i4MBgMBmPfYeLCYDAYjH2HiQuDwWAw9h0mLgwGg8HYd5i4MBgMBmPfYeLCYDAYjH2HiQuDwWAw9h3haV/IcdzLfB9DxfOYGrDr9wPs+r0Yz2uqwa7hD7A1+GI8zfVjJxcGg8Fg7DtMXBgMBoOx7zBxYTAYDMa+w8SFwWAwGPsOExcGg8Fg7DtMXBgMBoOx7zx1KTKDwWAwDoanLXse5FmPTFwYDAbjkBEEAUqlEjKZDG63G2azGUajES6XCzzP09el02nkcjk0Gg0Ui0XUajXs7OwglUqh2+0OlNgwcWEwGIxDRhAE6PV6qNVqXLx4EXNzc5icnMSlS5cgiiI4jkO328Xa2hrW19dRLBYRCoWQz+dRr9eRyWTQ6XSYuDD64XkeEokE3W6XLpDdi0QikUAqlUIikUChUEAqlUIQBEil0h/93hqNBqIootVqodVqgeM4SCQSdDodpNNplEol1Ot11Gq1gdv9HCQcx9HrSpBIJJDL5RAEgf4BHl5XjuOgVqshiiJ9fbvdRqVSQbvdRrPZRKvVQrVaRSaTQbvdPvDPxBhspFIpZDIZDAYDZmZmoNfrMTY2BpfLBbPZDIVCAZlMBuBhCEyn08FqtUKtVkMikaBUKmFubg5SqRSlUqnvVNNsNg/1s3Hdp3ySMOuDH9hP6wiJRAK1Wg2ZTIZms4lGo4FOp4NGo9H3c5RKJQwGAxQKBcbHx6HX62EwGGAwGJ74c7VaLS5cuACr1YpcLodCoQCe56FQKFCv1/G73/0Od+/eRSwWg9/vR7PZRLvdfqkCM4jWGxKJBBKJBBaLBRqNhv48URTh8/mg1Wqh1+uh0+kglUohl8shl8sxOzsLp9MJ4OHnqlQqWF9fRy6XQzabRTabxc7ODj799FMUi8V9ea/M/uXFGZQ1aLFYYLVaMTU1hf/yX/4LXC4XTCYT3bQolUpIJA/rrrrdLhqNBur1OjqdDprNJprNJsLhMLLZLNbW1nD9+nWk02ncuHED6XR6398v4Wmu36GcXMiuj+wEu90ums0mOp3OS3+wDRoSiQQqlQpKpRL1eh31eh3tdhuCIKDT6dDXqdVqemy22+0wGo0wGo2wWCxPXPQ6nQ6zs7NwOBzIZDLIZDIQBAFqtRrVahWLi4uIxWJoNBpIJBLgeR7VavVY7LI5jgPHceB5HjKZDIIgQKfTQa/X09coFAo4HA7o9Xp6zaVSKRQKBRQKBWZnZ+HxeAA8vOHK5TIAIJvNIpVKQavVolarQa1Wo9FooNVqHYtr+zyQEzWAvjwD+Rr5L3l+ELrdLlqtFn1+9N43gwzHcVAoFDAYDLDZbBgdHYXb7YZcLodUKu3bZJJnIsdxkMlkkEgk4Hke3W4XMpkMNpsNrVYLwWAQPM9DLpeD5/lDDZUduLjIZDJ4vV7o9XqcPn0aZ8+eRTKZxNWrV5HJZOD3+5FMJg/6bR0aarUav/jFLzA/P49qtUpDKrVare8hpNfr4XK5IIoiLBYLlEol3T0/CVEUodFo6ALtdruQSCQQRREymQxvv/025ubmsLm5icXFRfq7OA6/AyIkNpsN58+fh16vh9vthtFopK8hsXAiKKIogud5CIIAnueh1+v7bl6ZTIaRkRHY7XZUKhVUq1X4fD
7I5XKkUil899138Pv9NER5nJFIJFTgyQneaDRCLpfD5XJBqVT2vVYqlUIURUxMTMBisdCv5fN53Lp1C6lUChsbG/D7/YfxcZ4Jsobm5+fx/vvvw+VywWKxQBRFlMtl1Ot1RKNRLC8vo9FooFKpoNVqQa1W0+s0NzcHpVIJqVQKg8GA2dlZaLVahEIhJJNJKBQKuqE8DA5cXKRSKZxOJxwOB9599138+te/xubmJsrlMvx+P7LZ7LF4sBEUCgVee+01vPPOO6hUKiiXyzRO3ysuRqMRXq+X3mBkZ/djuxKyw9tr9yMIAk6dOgUAWF9fh1arhd/vx9LS0rH4HahUKthsNkxMTOCDDz6Aw+GA1+uFyWR6pu/T+zsQBAFWqxUA0Ol00Ol0YLVawfM84vE4QqEQotEoABxrcSGCQk4iRFzsdjt0Oh0WFhb6Qr6CIEAURajVarzxxhsYHx8H8PDaR6NRyOVy7OzsIJ/PD7y4kM8sCAJ8Ph/eeOMN6HQ6GnIlFWFbW1u4evUqSqUS0uk0ms0mzGYzzGYz3cAADzeeKpUKKpUKXq8Xfr8ff/rTn1Aul9FoNI6PuHS7XbTbbXqM7Xa7UKvVOHHiBMxmMxKJBNLpNBqNBqrV6pEPkTWbTfj9fty/fx9arZYusN1hsU6ng0AgAAA0dFitVlGv16kY/Vg4oFgsolgsQhRFGI1GiKKI8fFxWCwWSCQSjIyMQCKRYGxsDO12G9lsFvl8/qV+/sOC4ziYTCbMzMxgZGQEFosFBoOhLzn/JIhYk1BX7w6chC3I/1coFHA6nZDJZPD5fEin00in04hGo0MTwnkROI6jGyKLxQKz2Qye5+nfkYIJs9kMt9sNlUoFj8cDlUpFvwfP8zTXpVar+76/KIrweDyQSCRYW1uDRqNBq9WiBSqDTCaTwcbGBhQKBUKhEDiOo9GbUCiEjY0N1Go1lEoltFotVCoVZLNZFAoFiKIIk8mE8+fPw+fz0TWoVCqxsLAAlUqFVqtF19lBh2MPRVyazSZ9KAIPk1p//ud/jnw+j0QigWQyiVwuh2g0euTj07VaDTdv3kQ6ncaFCxfwxhtv0FhpL9FoFEtLS6hUKshkMqhWq0in00ilUiiXy4jFYqjX60/8WaRggORcdDod/u7v/g4XL16ETqfDK6+8ApfLhbW1NZhMJiwuLh5ZcQGA0dFRvP3227Db7ZiYmIBWq6Ux/ydBNkjtdhv5fB6VSoXGwEmIQiaT0Ztdp9Nhfn4ehUIBGxsb4Hkey8vLiMfjx0JceJ6HVquFQqHAhQsXcOHCBRqulUql0Ov1UCqVMJlMcDgcfVV5vRDBlkqljxS7nDp1CqOjo9ja2sLq6ipKpRIajcbAPj9IRCEYDOLKlSvodDqo1Wqo1WpYXl5GMBhEq9VCvV5/JKRNBOTmzZu0oszhcNDKM71ej5/+9KfI5/Mol8u4d+8ems3mgYvtoYhLo9FArVajuz6e56FSqdDtdmGxWOBwONDtdhGLxQ767R047XYbmUwGSqUSkUgEoVDokWQmAEQiEUQikT5xyWQySKVSqFarSCQSaDQaT/xZJM4vCAKq1SpqtRoSiQQSiQRdmHK5HAqFgu4mjyLkRCGKIrRaLVQqFWQyWd91J6XEJKlKEsW9/221WrSUm4gLKRNXq9XgeZ6GfMjPI6Iul8uPXPVWbzKe5EjkcjlkMhksFgtUKhUNiZNrIZVKodPpoFAooNVqodVqwXEcjWr0hs96hb/3IUlOh61WCxqNBmq1Gp1OBxKJZGDFhbQdlMtlJJNJ+r9rtRpSqRSy2ewT2wK63S4Nd6XTaWQyGahUKuj1etp+QIqFBEFAu92mvTIHxYE/PcgxLZfLIRKJIJfLQaFQQKlUQqlU4s0334TH48GVK1ews7Nz5OPSjUYDy8vL2N7extLSEv7t3/5tz91zrVZDoVBAq9Xqe+iRBx8pT3wSvaGcer2OarWKr7/+GuFwGJcvX4bP5wMA2j/zNLv4YYNUKUqlUm
i1WlgsFnpD9pLP5xEMBlEul7G5uYlCoYB8Po9isYh2u41Go4Fms4lQKIRMJkMfpk6nE3/7t3+L8fFxmEymvuIAiUQ
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 11%|█ | 22/200 [01:56<15:59, 5.39s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28963, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 12%|█▏ | 23/200 [02:01<15:38, 5.30s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28961, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 12%|█▏ | 24/200 [02:06<15:28, 5.28s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28957, dtype=float32) G=array(5.5491, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 12%|█▎ | 25/200 [02:12<15:15, 5.23s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28959, dtype=float32) G=array(5.54911, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 13%|█▎ | 26/200 [02:17<15:09, 5.23s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Losses D=array(2.28964, dtype=float32) G=array(5.54911, dtype=float32)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 13%|█▎ | 26/200 [02:19<15:36, 5.38s/it]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[679], line 32\u001b[0m\n\u001b[1;32m 29\u001b[0m gen_opt\u001b[38;5;241m.\u001b[39mupdate(gen, G_grads)\n\u001b[1;32m 31\u001b[0m \u001b[38;5;66;03m# Update gradients\u001b[39;00m\n\u001b[0;32m---> 32\u001b[0m \u001b[43mmx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meval\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgen\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparameters\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgen_opt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstate\u001b[49m\u001b[43m)\u001b[49m \n\u001b[1;32m 34\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (cur_step \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m%\u001b[39m display_step \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 35\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mStep \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: Generator loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mG_loss\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, discriminator loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mD_loss\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
2024-07-29 00:18:35 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"source": [
"# Training hyper-parameters\n",
"n_epochs = 200\n",
"display_step = 5000  # show sample images every `display_step` optimizer steps\n",
"cur_step = 0  # global step counter across all epochs\n",
"\n",
"batch_size = 128\n",
"\n",
"# nn.value_and_grad differentiates each loss w.r.t. the trainable parameters of the given model only\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"\n",
"for epoch in tqdm(range(n_epochs)):\n",
"\n",
"    for real in batch_iterate(batch_size, train_images):\n",
"\n",
"        # Train Discriminator\n",
"        D_loss, D_grads = D_loss_grad(gen, disc, mx.array(real), batch_size, z_dim)\n",
"\n",
"        # Apply the gradients to the discriminator's parameters\n",
"        disc_opt.update(disc, D_grads)\n",
"\n",
"        # MLX is lazy: force the parameter / optimizer-state update to be computed now\n",
"        mx.eval(disc.parameters(), disc_opt.state)\n",
"\n",
"        # Train Generator\n",
"        G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
"\n",
"        # Apply the gradients to the generator's parameters\n",
"        gen_opt.update(gen, G_grads)\n",
"\n",
"        # Force evaluation of the lazy generator update\n",
"        mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
"        if (cur_step + 1) % display_step == 0:\n",
"            # Report the number of completed steps (not the epoch); tqdm.write keeps the progress bar intact\n",
"            tqdm.write(f\"Step {cur_step + 1}: Generator loss: {G_loss.item()}, discriminator loss: {D_loss.item()}\")\n",
"            fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
"            fake = gen(fake_noise)\n",
"            show_images(fake)\n",
"            show_images(real)\n",
"        cur_step += 1\n",
"\n",
"    # Losses of the last batch of this epoch, printed as plain Python floats\n",
"    tqdm.write('Losses D={0} G={1}'.format(D_loss.item(), G_loss.item()))"
2024-07-27 05:19:08 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}