mlx-examples/gan/playground.ipynb

637 lines
885 KiB
Plaintext
Raw Normal View History

2024-07-26 21:07:40 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Library"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 1,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 2,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
2024-07-30 07:17:12 +08:00
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 3,
2024-07-30 07:17:12 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-31 16:25:39 +08:00
"# mx.set_default_device(mx.gpu)\n",
"mx.random.seed(42)"
2024-07-30 07:17:12 +08:00
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 4,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Build one generator stage: Linear -> BatchNorm -> LeakyReLU(0.2).\n",
"\n",
"    Args:\n",
"        in_dim: number of input features of the linear layer.\n",
"        out_dim: number of output features (also the BatchNorm width).\n",
"\n",
"    Returns:\n",
"        An nn.Sequential holding the three layers in that order.\n",
"    \"\"\"\n",
"    return nn.Sequential(\n",
"        nn.Linear(in_dim, out_dim),\n",
"        # The second positional argument of nn.BatchNorm is eps, so the former\n",
"        # nn.BatchNorm(out_dim, 0.8) silently set eps=0.8 and distorted the\n",
"        # normalisation. NOTE(review): 0.8 is assumed to be the intended\n",
"        # momentum - confirm, or drop it to use the default.\n",
"        nn.BatchNorm(out_dim, momentum=0.8),\n",
"        nn.LeakyReLU(0.2),\n",
"    )"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 5,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"    \"\"\"MLP generator mapping a latent vector to a flattened image in [-1, 1].\n",
"\n",
"    Three GenBlock stages widen the features from hidden_dim to\n",
"    4 * hidden_dim, a final Linear layer projects to im_dim, and tanh\n",
"    squashes the result.\n",
"\n",
"    Args:\n",
"        z_dim: size of the latent (noise) vector.\n",
"        im_dim: number of output features per sample.\n",
"        hidden_dim: width of the first hidden stage.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, z_dim: int = 32, im_dim: int = 784, hidden_dim: int = 256):\n",
"        super().__init__()\n",
"\n",
"        # Consecutive pairs of widths define the GenBlock stages, built in order.\n",
"        widths = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4]\n",
"        stages = [GenBlock(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]\n",
"        self.gen = nn.Sequential(*stages, nn.Linear(hidden_dim * 4, im_dim))\n",
"\n",
"    def __call__(self, noise):\n",
"        \"\"\"Run the network on `noise` and apply tanh to the output.\"\"\"\n",
"        return mx.tanh(self.gen(noise))"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 6,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=100, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
" (layers.1): BatchNorm(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
" (layers.1): BatchNorm(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-31 00:50:02 +08:00
" (layers.3): Linear(input_dims=1024, output_dims=784, bias=True)\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-08-01 00:59:36 +08:00
"execution_count": 6,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gen = Generator(100)\n",
"gen"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 7,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# make 2D noise with shape n_samples x z_dim\n",
"def get_noise(n_samples: int, z_dim: int) -> mx.array:\n",
"    \"\"\"Draw an (n_samples, z_dim) array from mx.random.normal.\n",
"\n",
"    Args:\n",
"        n_samples: number of rows (one latent vector per row).\n",
"        z_dim: number of columns (latent dimensionality).\n",
"\n",
"    Returns:\n",
"        The sampled mx.array of shape (n_samples, z_dim).\n",
"    \"\"\"\n",
"    return mx.random.normal(shape=(n_samples, z_dim))"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 8,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-31 16:25:39 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWNklEQVR4nO3ca2zW9d3H8U9LobRFQCiM80GtdMgKioJYg8VJBIwRmIsYcCw76FTCxsZmZjZM3IxzU4kYNZqFLZ5ZHJODGkCshIMoIudxGEUqZwQpbS1toVz3s29yP7quz+9OvO/ceb8eX+//tdSWz/5PvnmZTCYjAAAk5f9v/w8AAPzfwSgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgFOT6wYceesh+eLt27eymrKzMbiRp165ddjNq1Ci72bFjh90UFxfbzdatW+1Gkr773e/aTWNjo93cfPPNdvPWW2/ZjSQdPnzYbu666y67+c9//mM3Q4YMsZtXXnnFbiRpwIABdnPmzBm7qaystJv6+nq7+eijj+xGSvu36KuvvrKblP9OtbW1diNJ06dPt5uU39c//vGPWT/DmwIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIOR/E69+/v/3w8vJyu3n33XftRpIaGhrspq2tzW66d+9uNx9//PE38j2S9K1vfctutmzZYjeHDh2ymw4dOtiNlHZQMOVnfv3119tNYWGh3XTu3NluJOnkyZN2k3JobenSpXZTUlJiN8OHD7cbSZozZ47djBw50m4GDx5sN3369LEbSdq0aZPd3HrrrUnflQ1vCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACDkfBCvqqrKfvjTTz9tN6WlpXaT6uzZs3aTyWTsZvTo0XbT1NRkN5L0zDPP2M3tt99uNymH9woKcv51+29OnDhhN3v37rWblKNua9assZuysjK7kaRu3brZTcphxZSfd0VFhd2k/K5K0g033GA3AwcOtJtjx47ZTcq/KVLa8dCUI3q5/K3zpgAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACDmfrUy5GDhkyBC7OX78uN1IaZdI6+rq7Gbw4MF209LSYjepV1JTrtmmXN88fPiw3RQVFdmNJHXt2tVuLly4YDdr1661m5TLr9u2bbMbKe3SZ8rvXsr3nDt3zm5Srp1K0m9/+1u7efDBB+1m2rRpdnP99dfbjSRdeeWVdpN6ZTYb3hQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAyPmaV3V1tf3wUaNG2U27du3sRpI+++wzu5k6dardrFixwm5uuukmuzl9+rTdSNKUKVPsZt26dXbT3NxsN3l5eXYjSYMGDbKblENwmUzGbn7/+9/bzbx58+xGki5evGg3KYcBd+/ebTc/+clP7Gbu3Ll2I0nvvPOO3aQcsuzdu7fdNDY22o0kPfXUU3ZTUlKS9F3Z8KYAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAQs4H8S6//HL74Rs3brSbDh062I2UdjRt586ddnPJJZfYzYIFC+xmzpw5diNJZ86csZtNmzbZTcrhvRdeeMFuJOncuXN206tXL7spKMj5zyFceumldrNv3z67kaSOHTvaTcoxxm7dutnN888/bzePPPKI3UjSww8/bDdVVVV2s3z5cru544477EZK+/dr//79Sd+VDW8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIOR8ASzl8NeYMWPspqWlxW4k6eTJk3YzceJEu3niiSfs5tprr7Wb3b
t3240klZaW2k1zc7PddOnSxW7Ky8vtRpJOnTplNyk/h5QDjj/84Q/tpqamxm4kafLkyXbz6KOP2s1zzz1nN3V1dXazZs0au5GkDz/80G5aW1vtJuUA4QcffGA3kjR37ly7uf/++5O+KxveFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEDI+SDenj177IfX19fbTVFRkd1I0qeffmo3e/futZsNGzbYTcr/tq1bt9qNJHXr1s1utm/fbjfV1dV2M3jwYLuRpGHDhtnNokWL7KaystJuVqxYYTedO3e2G0k6ePCg3SxYsMBu3n77bbspKSmxm+LiYruRpC1bttjN7373O7tp37693aT8/UnSn/70J7upra1N+q5seFMAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAIS8TCaTyeWDP//5z+2Hf//737ebzz//3G4kaeTIkXazc+dOuzlx4oTdpFxb/OKLL+xGklpbW+2mqanJbsaPH283K1eutBtJuvvuu+1m8+bNdnPkyBG7ue222+xm/fr1diNJHTt2tJsvv/zSblJ+3ilXaVOunUrS448/bjfvv/++3bS0tNhN6gXcrl272s0///lPu1myZEnWz/CmAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAELOB/E2btxoP/yJJ56wmwkTJtiNlHb4K+UIVYqUQ2YbNmxI+q6U43t9+vT5RpqUY4KS1K5dO7vp1KmT3XTp0sVu3njjDbtJOegmScuXL7ebmTNn2s2dd95pN++8847dVFVV2Y0kzZ49225SDiT++Mc/tpuUI5uS1NzcbDcXL160m1mzZmX9DG8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIBTk+sGFCxfaD29sbLSbhoYGu5GkgQMH2k1ZWZnd1NXV2c3LL79sNylH4CTpmmuusZv+/fvbzdq1a+1m0KBBdiOlHS6sr6+3m2PHjtnNpEmT7CY/P+3/i1VUVNjNX/7yF7t56qmn7Obo0aN2M2fOHLuRpKlTp9rNlVdeaTcjRoywm2XLltmNlPbv3rBhw5K+KxveFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEDI+SBeU1OT/fCnn37abv71r3/ZjSQVFhbazZIlS+ymb9++djNx4kS7ufrqq+1Gkv7whz/YzW9+8xu76dGjh9106dLFbiRp8eLFdjN79my7Wbdund2k/N59/vnndiNJw4cPt5s9e/bYzeuvv243d999t9289tprdiNJjz/+uN1MmTLFbjp37mw3o0ePthsp7Xfv3//+d9J3ZcObAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAg5HwltaioyH74Aw88YDcp1xYl6eWXX7abq666ym5aW1vtprq62m5qa2vtRpImT55sNy+99JLdlJeX203KlU9J2r59u92kXLjs16+f3ZSUlNhNp06d7EZKu6TZ3NxsNyl/gymXVW+66Sa7kaRp06bZzaZNm+xm7ty5dnPvvffajSTdcsstdpPyO54L3hQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAyPkg3tixY/2HF+T8+P+xyspKu+nWrZvd7N27126GDRtmN01NTXYjpf3M29ra7KZ9+/bfSCNJFRUVdvP888/bzYMPPmg3b7/9tt1MmTLFblK/q6yszG42b95sN9dee63drFy50m4kad68eXaT8jf4ox/9yG4OHTpkN1LaAcyjR4/azahRo7J+hjcFAEBgFAAAgVEAAARGAQAQGAUAQG
AUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEHK+nta7d2/74ePGjbObMWPG2I0kvfDCC3Zz5MgRu7niiivs5rP
2024-07-30 00:44:16 +08:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"img = get_noise(28,28)\n",
"plt.imshow(img, cmap='gray')\n",
"plt.axis('off')\n",
"plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 9,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Build one discriminator stage: Linear -> LeakyReLU(0.2) -> Dropout(0.3).\n",
"\n",
"    Args:\n",
"        in_dim: number of input features of the linear layer.\n",
"        out_dim: number of output features of the linear layer.\n",
"\n",
"    Returns:\n",
"        An nn.Sequential holding the three layers in that order.\n",
"    \"\"\"\n",
"    projection = nn.Linear(in_dim, out_dim)\n",
"    activation = nn.LeakyReLU(negative_slope=0.2)\n",
"    dropout = nn.Dropout(0.3)\n",
"    return nn.Sequential(projection, activation, dropout)"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 10,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"    \"\"\"MLP discriminator: three DisBlock stages, a 1-unit Linear, then Sigmoid.\n",
"\n",
"    The stages narrow the features from 4 * hidden_dim down to hidden_dim.\n",
"    Because the last layer is nn.Sigmoid, the output is already a value in\n",
"    (0, 1) rather than a raw logit.\n",
"\n",
"    Args:\n",
"        im_dim: number of input features per sample.\n",
"        hidden_dim: width of the last hidden stage.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, im_dim: int = 784, hidden_dim: int = 256):\n",
"        super().__init__()\n",
"\n",
"        # Consecutive pairs of widths define the DisBlock stages, built in order.\n",
"        widths = [im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]\n",
"        stages = [DisBlock(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]\n",
"        self.disc = nn.Sequential(*stages, nn.Linear(hidden_dim, 1), nn.Sigmoid())\n",
"\n",
"    def __call__(self, noise):\n",
"        \"\"\"Score the batch passed as `noise`; returns the sigmoid output.\"\"\"\n",
"        scores = self.disc(noise)\n",
"        return scores"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 11,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=784, output_dims=1024, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=1024, output_dims=512, bias=True)\n",
2024-07-30 07:37:09 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-30 07:37:09 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=512, output_dims=256, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-31 00:50:02 +08:00
" (layers.3): Linear(input_dims=256, output_dims=1, bias=True)\n",
2024-07-30 18:21:38 +08:00
" (layers.4): Sigmoid()\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-08-01 00:59:36 +08:00
"execution_count": 11,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"disc = Discriminator()\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
2024-07-30 00:44:16 +08:00
"cell_type": "markdown",
2024-07-26 21:07:40 +08:00
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### Losses"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"#### Discriminator Loss"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 12,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
"    \"\"\"Discriminator loss: mean of the BCE on a fake batch and on a real batch.\n",
"\n",
"    Fake images are generated from fresh noise and labelled 0; the images in\n",
"    `real` are labelled 1.\n",
"\n",
"    Args:\n",
"        gen: generator, called as gen(noise).\n",
"        disc: discriminator, called as disc(images); it ends in a Sigmoid, so\n",
"            its output is treated as a probability here.\n",
"        real: batch of real images fed to disc.\n",
"        num_images: number of fake images to generate.\n",
"        z_dim: latent dimensionality passed to get_noise.\n",
"\n",
"    Returns:\n",
"        Scalar mx.array: (fake_loss + real_loss) / 2.\n",
"    \"\"\"\n",
"    noise = get_noise(num_images, z_dim)\n",
"    fake_images = gen(noise)\n",
"\n",
"    fake_pred = disc(fake_images)\n",
"    fake_labels = mx.zeros((fake_images.shape[0], 1))\n",
"    # with_logits=False: disc already applies a sigmoid, so with_logits=True\n",
"    # would squash the predictions a second time.\n",
"    fake_loss = mx.mean(nn.losses.binary_cross_entropy(fake_pred, fake_labels, with_logits=False))\n",
"\n",
"    real_pred = disc(real)\n",
"    real_labels = mx.ones((real.shape[0], 1))\n",
"    real_loss = mx.mean(nn.losses.binary_cross_entropy(real_pred, real_labels, with_logits=False))\n",
"\n",
"    return (fake_loss + real_loss) / 2.0"
]
},
2024-07-30 00:44:16 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generator Loss"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 13,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def gen_loss(gen, disc, num_images, z_dim):\n",
"    \"\"\"Generator loss: BCE of the discriminator's output on fakes against label 1.\n",
"\n",
"    Args:\n",
"        gen: generator, called as gen(noise).\n",
"        disc: discriminator, called as disc(images); it ends in a Sigmoid, so\n",
"            its output is treated as a probability here.\n",
"        num_images: number of fake images to generate.\n",
"        z_dim: latent dimensionality passed to get_noise.\n",
"\n",
"    Returns:\n",
"        Scalar mx.array with the mean loss over the generated batch.\n",
"    \"\"\"\n",
"    noise = get_noise(num_images, z_dim)\n",
"\n",
"    fake_images = gen(noise)\n",
"    fake_pred = disc(fake_images)\n",
"\n",
"    # The generator is rewarded when its fakes are scored as real (label 1).\n",
"    fake_labels = mx.ones((fake_images.shape[0], 1))\n",
"\n",
"    # with_logits=False: disc already applies a sigmoid, so with_logits=True\n",
"    # would squash the predictions a second time.\n",
"    loss = nn.losses.binary_cross_entropy(fake_pred, fake_labels, with_logits=False)\n",
"\n",
"    return mx.mean(loss)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 14,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"# Keep the first two arrays returned by mnist.mnist() (training images and labels); ignore the rest\n",
2024-07-30 18:21:38 +08:00
"train_images,train_labels,*_ = map(np.array, mnist.mnist())"
2024-07-30 00:44:16 +08:00
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 15,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# Normalize the images to fall between -1,1\n",
"train_images = train_images * 2.0 - 1.0"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 16,
2024-07-29 06:24:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-08-01 06:04:14 +08:00
"<matplotlib.image.AxesImage at 0x11e6a3880>"
2024-07-29 06:24:50 +08:00
]
},
2024-08-01 00:59:36 +08:00
"execution_count": 16,
2024-07-29 06:24:50 +08:00
"metadata": {},
"output_type": "execute_result"
2024-07-30 00:44:16 +08:00
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaI0lEQVR4nO3df2jU9x3H8dfVH1d1lytBk7vUmGVF202dpWrVYP3R1cxApf4oWMtGZEPa+YOJ/cGsDNNBjdgpRdI6V0amW239Y9a6KdUMTXRkijpdRYtYjDOdCcFM72LUSMxnf4hHz1j1e975vkueD/iCufu+vY/ffuvTby75xueccwIAwMBD1gsAAHRfRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjpab2AW3V0dOjcuXMKBALy+XzWywEAeOScU0tLi/Ly8vTQQ3e+1km7CJ07d075+fnWywAA3Kf6+noNHDjwjvuk3afjAoGA9RIAAElwL3+fpyxCH3zwgQoLC/Xwww9r5MiR2rdv3z3N8Sk4AOga7uXv85REaPPmzVq8eLGWLVumI0eO6JlnnlFJSYnOnj2bipcDAGQoXyruoj1mzBg99dRTWrduXeyx73//+5o+fbrKy8vvOBuNRhUMBpO9JADAAxaJRJSVlXXHfZJ+JXTt2jUdPnxYxcXFcY8XFxertra20/5tbW2KRqNxGwCge0h6hM6fP6/r168rNzc37vHc3Fw1NjZ22r+8vFzBYDC28ZVxANB9pOwLE259Q8o5d9s3qZYuXapIJBLb6uvrU7UkAECaSfr3CfXv3189evTodNXT1NTU6epIkvx+v/x+f7KXAQDIAEm/Eurdu7dGjhypqqqquMerqqpUVFSU7JcDAGSwlNwxYcmSJfrpT3+qUaNGady4cfr973+vs2fP6tVXX03FywEAMlRKIjR79mw1NzfrN7/5jRoaGjRs2DDt2LFDBQUFqXg5AECGSsn3Cd0Pvk8IALoGk+8TAgDgXhEhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmelovAEgnPXr08DwTDAZTsJLkWLhwYUJzffv29Tzz+OOPe55ZsGCB55nf/va3nmfmzJnjeUaSrl696nlm5cqVnmfefvttzzNdBVdCAAAzRAgAYCbpESorK5PP54vbQqFQsl8GANAFpOQ9oaFDh+rvf/977ONEPs8OAOj6UhKhnj17cvUDALirlLwndOrUKeXl5amwsFAvvfSSTp8+/a37trW1KRqNxm0AgO4h6REaM2aMNm7cqJ07d+rDDz9UY2OjioqK1NzcfNv9y8vLFQwGY1t+fn6ylwQASFNJj1BJSYlmzZql4cOH67nnntP27dslSRs2bLjt/kuXLlUkEolt9fX1yV4SACBNpfybVfv166fhw4fr1KlTt33e7/fL7/enehkAgDSU8u8Tamtr05dffqlwOJzqlwIAZJikR+j1119XTU2N6urqdODAAb344ouKRqMqLS1N9ksBADJc0j8d9/XXX2vOnDk6f/68BgwYoLFjx2r//v0qKChI9ksBADJc0iP0ySefJPu3RJoaNGiQ55nevXt7nikqKvI8M378eM8zkvTII494npk1a1ZCr9XVfP31155n1q5d63lmxowZnmdaWlo8z0jSv//9b88zNTU1Cb1Wd8W94wAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQ
DMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAMz7nnLNexDdFo1EFg0HrZXQrTz75ZEJzu3fv9jzDf9vM0NHR4XnmZz/7meeZS5cueZ5JRENDQ0JzFy5c8Dxz8uTJhF6rK4pEIsrKyrrjPlwJAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwExP6wXA3tmzZxOaa25u9jzDXbRvOHDggOeZixcvep6ZPHmy5xlJunbtmueZP/3pTwm9Fro3roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT63//+l9DcG2+84Xnm+eef9zxz5MgRzzNr1671PJOoo0ePep6ZMmWK55nW1lbPM0OHDvU8I0m//OUvE5oDvOJKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAw43POOetFfFM0GlUwGLReBlIkKyvL80xLS4vnmfXr13uekaSf//znnmd+8pOfeJ75+OOPPc8AmSYSidz1/3muhAAAZogQAMCM5wjt3btX06ZNU15ennw+n7Zu3Rr3vHNOZWVlysvLU58+fTRp0iQdP348WesFAHQhniPU2tqqESNGqKKi4rbPr1q1SmvWrFFFRYUOHjyoUCikKVOmJPR5fQBA1+b5J6uWlJSopKTkts855/Tee+9p2bJlmjlzpiRpw4YNys3N1aZNm/TKK6/c32oBAF1KUt8TqqurU2Njo4qLi2OP+f1+TZw4UbW1tbedaWtrUzQajdsAAN1DUiPU2NgoScrNzY17PDc3N/bcrcrLyxUMBmNbfn5+MpcEAEhjKfnqOJ/PF/exc67TYzctXbpUkUgkttXX16diSQCANOT5PaE7CYVCkm5cEYXD4djjTU1Nna6ObvL7/fL7/clcBgAgQyT1SqiwsFChUEhVVVWxx65du6aamhoVFRUl86UAAF2A5yuhS5cu6auvvop9XFdXp6NHjyo7O1uDBg3S4sWLtWLFCg0ePFiDBw/WihUr1LdvX7388stJXTgAIPN5jtChQ4c0efLk2MdLliyRJJWWluqPf/yj3nzzTV25ckXz58/XhQsXNGbMGO3atUuBQCB5qwYAdAncwBRd0rvvvpvQ3M1/VHlRU1Pjeea5557zPNPR0eF5BrDEDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNrqkfv36JTT317/+1fPMxIkTPc+UlJR4ntm1a5fnGcASd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzAFvuGxxx7zPPOvf/3L88zFixc9z+zZs8fzzKFDhzzPSNL777/veSbN/ipBGuAGpgCAtEaEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpsB9mjFjhueZyspKzzOBQMDzTKLeeustzzMbN270PNPQ0OB5BpmDG5gCANIaEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gCBoYNG+Z5Zs2aNZ5nfvSjH3meSdT69es9z7zzzjueZ/773/96noENbmAKAEhrRAgAYMZzhPbu3atp06YpLy9PPp9PW7dujXt+7ty58vl8cdvYsWOTtV4AQBfiOUKtra0aMWKEKioqvnWfqVOnqqGhIbbt2LHjvhYJAOiaenodKCkpUUlJyR338fv9CoVCCS8KANA9pOQ9oerqauXk5GjIkCGaN2+empqavnXftrY2RaPRuA0A0D0kPUIlJSX66KOPtHv3bq1evVoHDx7Us88+q7
a2ttvuX15ermAwGNvy8/OTvSQAQJry/Om4u5k9e3bs18OGDdOoUaNUUFCg7du3a+bMmZ32X7p0qZYsWRL7OBqNEiI
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
2024-07-29 06:24:50 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"plt.imshow(train_images[0].reshape(28,28),cmap='gray')"
2024-07-29 06:24:50 +08:00
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 17,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"def batch_iterate(batch_size: int, ipt):\n",
"    \"\"\"Yield shuffled mini-batches of `ipt`.\n",
"\n",
"    A fresh random permutation of the indices is drawn on every call, then\n",
"    consecutive slices of at most `batch_size` indices are used to index\n",
"    `ipt`. The last batch is smaller when len(ipt) is not a multiple of\n",
"    batch_size.\n",
"\n",
"    Args:\n",
"        batch_size: maximum number of items per yielded batch.\n",
"        ipt: container supporting len() and indexing with an integer array\n",
"            (e.g. a NumPy array).\n",
"\n",
"    Yields:\n",
"        ipt[ids] for each consecutive slice `ids` of the permutation.\n",
"    \"\"\"\n",
"    # NOTE(review): this draws from NumPy's global RNG; seed NumPy as well if\n",
"    # a reproducible batch order is required.\n",
"    perm = np.random.permutation(len(ipt))\n",
"    for s in range(0, len(ipt), batch_size):\n",
"        ids = perm[s : s + batch_size]\n",
"        yield ipt[ids]"
]
},
{
"cell_type": "code",
2024-08-01 06:04:14 +08:00
"execution_count": 18,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-08-01 01:23:57 +08:00
"def show_images(epoch_num: int, imgs, num_imgs: int = 25):\n",
"    \"\"\"Plot up to `num_imgs` images in a square grid, save the figure, show it.\n",
"\n",
"    The figure is written to gen_images/img_<epoch_num>.png; the directory is\n",
"    created when missing. Nothing is drawn or saved when there is no image to\n",
"    show.\n",
"\n",
"    Args:\n",
"        epoch_num: value substituted into the output file name.\n",
"        imgs: batch of images exposing .shape; each imgs[i] must reshape to\n",
"            28x28 (presumably flattened MNIST-sized samples - TODO confirm).\n",
"        num_imgs: maximum number of images to draw. The grid side is\n",
"            ceil(sqrt(num_imgs)), so the default 25 keeps the 5x5 layout.\n",
"    \"\"\"\n",
"    import os  # local import: only needed to create the output directory\n",
"\n",
"    # Never index past the end of the batch.\n",
"    n_show = min(num_imgs, imgs.shape[0])\n",
"    if n_show <= 0:\n",
"        return\n",
"\n",
"    side = int(np.ceil(np.sqrt(num_imgs)))\n",
"    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.\n",
"    fig, axes = plt.subplots(side, side, figsize=(side, side), squeeze=False)\n",
"\n",
"    for i, ax in enumerate(axes.flat):\n",
"        if i < n_show:\n",
"            img = mx.array(imgs[i]).reshape(28, 28)\n",
"            ax.imshow(img, cmap='gray')\n",
"        # Hide the frame of every cell, including unused ones.\n",
"        ax.axis('off')\n",
"\n",
"    plt.tight_layout()\n",
"    os.makedirs('gen_images', exist_ok=True)\n",
"    plt.savefig('gen_images/img_{}.png'.format(epoch_num))\n",
"    plt.show()"
2024-07-26 21:07:40 +08:00
]
},
2024-07-30 07:06:52 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training Cycle"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-08-01 06:04:14 +08:00
"execution_count": 19,
2024-07-27 06:09:51 +08:00
"metadata": {},
2024-07-29 06:24:50 +08:00
"outputs": [],
2024-07-27 06:09:51 +08:00
"source": [
2024-07-31 16:25:39 +08:00
"lr = 2e-5\n",
2024-07-30 18:21:38 +08:00
"z_dim = 128\n",
2024-07-30 07:06:52 +08:00
"\n",
2024-07-27 06:09:51 +08:00
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
2024-07-30 07:44:41 +08:00
"gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999]) #,betas=[0.5, 0.9]\n",
2024-07-27 06:09:51 +08:00
"\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
2024-07-30 07:44:41 +08:00
"disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])"
2024-07-29 06:30:08 +08:00
]
},
2024-07-27 06:09:51 +08:00
{
"cell_type": "code",
2024-08-01 06:04:14 +08:00
"execution_count": 20,
2024-07-27 06:09:51 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-31 16:25:39 +08:00
" 0%| | 0/500 [00:00<?, ?it/s]"
2024-07-30 18:24:53 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-08-01 06:04:14 +08:00
"Epoch: 0, iteration: 468, Discriminator Loss:array(0.5373, dtype=float32), Generator Loss: array(0.670858, dtype=float32)\n"
2024-07-30 18:24:53 +08:00
]
},
{
"data": {
2024-08-01 06:04:14 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAekAAAHqCAYAAAAgWrY5AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9SW/keZLmhz++O5100t3ppHMNMiIYay6VWZWVXTO9jdCQhMGgMTdBGGCAucxN70IHvQkBgg4SdNRlIA1GM92SulXVVZVVkVvswX0nfXf6vujg8zGae2apGUn2X/mHaABRUUzS+ft9F1see8wsMBgMBrqVW7mVW7mVW7mVH50E/99+gFu5lVu5lVu5lVv5frk10rdyK7dyK7dyKz9SuTXSt3Irt3Irt3IrP1K5NdK3ciu3ciu3cis/Urk10rdyK7dyK7dyKz9SuTXSt3Irt3Irt3IrP1K5NdK3ciu3ciu3cis/Urk10rdyK7dyK7dyKz9SuTXSt3Irt3Irt3IrP1IJX/UH//RP/1SSFAqFFI1GFQwGdXFxoWKxqFarpWazqWq1ql6vp0AgoEAgoGAwqPn5eS0uLioYDOrs7EzHx8fqdDrqdrvqdDoKhUKamZnR9PS0+v2+arWa6vW6+v2+/dzk5KQ2NjaUy+VUrVa1ubmp09NTpVIpPXz4UNlsVoVCQVtbWyqVSur3++p2uxoMBpqcnFQ6nbZnDofDCgaD9sy9Xk8XFxeqVqsaDAaam5tTLpdTIBBQo9FQvV5XMBhUMpnU9PS0gsGgYrGYotGo2u22Tk5OVC6X1ev11G631el07Nl7vZ6i0aiSyaRisZgkaXt7+wdv1r/+1/9akjQxMaFMJqNoNKr9/X29ePFCpVJJ8XhcU1NTCofDCoVCikQiCgQCuri4UK1WU7/fVzweVzwe12AwUKlUst/7/PPP9cknn6jT6ejrr7/W27dvFQgEND09rUQiocFgoHa7beu1s7OjYrGoRCKhXC6nqakp1et1nZ6eqtFoKBgMKhQKKRAIKJlMam5uTvF43PYmGAxqdnZWS0tLCoVCevnypb788ktbb/Zqbm5Oc3NzisViWlhY0NzcnMLhsGKxmOLxuHq9nmq1mhqNhi4uLrS1taWTkxPF43HNz89renpa7XZblUpFrVZLkvQf/sN/+MF78C/+xb+QJE1NTSmbzSoej+v09FTv3r1TtVrVxMSEJicn7ZyFQiFJsvMmSZFIRLFYTP1+X2dnZzo/P1cwGNTa2ppWVlY0GAxUrVbVaDTUarVULBZVq9UUDoeVTqc1NTWlWq2mra0tFQqFkT2oVqs6ODiwn4/FYgqFQpqcnFQ2m1U0GpUk9ft9BYNB5XI5ra+vKxqN6tmzZ/q7v/s7XVxc2DMGg0FNTU1pampK8Xhc9+/f19ramoLBoLrdrrrdrprNpnZ3d3V2diZJdtfi8bhmZ2c1NTWli4sLHR4eqlKpSJJ+//vf/+A9+G/+m/9GklQqlbS7u6taraZ4PK5kMqlIJKKzszPt7OyoXq9renpaqVTK9iEYHI1L+v2+isWiisWiAoGAstmsstmsAoGA3aNms6m9vT2dnp4qGo0qm81qenpa9XpdBwcHKhaLSqfTevLkiebn55XP5/Xy5UsVCgUNBgPR1DGTyWh1dVUTExMqlUo6OztTp9PR3bt39fTpU8XjcT179ky//vWvVa/Xbd159kAgoMnJSf3xH/+xPv/8cwUCAeXzeZXLZZXLZX355Zd6+/atotGocrmc6dRms2ln38t19uCf/tN/KkmanZ3V3bt3NT09rdevX+tv/uZvlM/n9ejRI/3xH/+xUqmUjo+Ptb29rUajoUqlomq1Kkm6e/euNjY2FI1GTW8OBgN73na7rePjYxUKBU1MTOjBgwdaWVlRu93W6empKpWKGo2G6WB0XjgctnPZ6/WUSCSUTqdN7/V6PUlSLBZTIpGw9WWfms2mms2mut2ura0kpdNppVIpBQ
IBtVotdTodRSIRra2taWFhQd1uV+fn56pUKopEIpqdnTX9U61W1Ww2dXFxobOzM9VqNUnSixcv/t61vo2kneBcXEV+LN1Ux5XOrdzKTchV78Gt/MPKdffh/+v7+EPf/8ei36X3iKSXlpYkDSOCi4sL8zYajYba7bZarZZ6vZ5F0tLQgNTrdZVKJYVCIdVqNXW7XfV6PfX7fYu24/H4iNfHf2eh+v2+Li4uVC6XLcqORCIaDAYql8sWgbdaLfX7/eGLhYevFgqFFAwGFQgE7PORSCRintfU1JSkYZTK8/M+kpRIJJTJZCRJFxcXKpVKFkkEAgFFo1FlMhnF4/ER746o+iZkYmJCkizCCQQCCoVCmpiYMOSg0WhIGkYzsVhMgUBA7XZb7XZbkmy/iGj9GoRCIXsnPicQCNjz+/2Ynp5WOBwe+b1gMKiZmRlNTk6OPPfU1JQSiYTi8bh9DlH69PS0QqGQMpmMcrmcReG8XzgcVq1WU7vdViaTUSAQ0GAwUL1et/NUKpV0cXGhVqulbrereDyucDisfr9vHq8/TzchrFGv11On01E4HFY0GlWv11OpVJIkWx88b85nOBxWJBKRJLXbbUWjUUWjUaVSKc3NzanX6xmKA5pERNbpdNRsNtXv9+3MRqPRkX0iOiCyCAaDSiQS9ndYh0AgYM8hSclkUsvLy2o2m3Zv2O9ms2l/n71pt9umA0Br+LtE7DwrkclNCBF7vV43fdPpdGzdQR6466y5JHt2dA/3gLWam5vT2tqa+v2+Tk5OlM/nTedx11utlhqNhrrdrqLRqKanpxWLxVSv15XP51WpVNTv90ciYP426ydJk5OT6vV6GgwGKhQKikQiSiQS+uijj0bWCv12cXGhTqejer2uarWqQCCgcrls79tqtexuxWIxJZNJdTod++r3+4aGXVf4jG63a+sRCoWUy+UM1ajX6xoMBmq1WnYeOL98Rq1WszMYCoXU6/Xs81ireDxuSCTvgp5iHTmvIEfsL5/L97xwP9E3oVDInrder9vno3f7/b4qlYrC4bASiYRmZmbsDBWLRXU6HZXLZYukg8Gg6VqQq3g8rmw2q5mZmSuv9ZWN9EcffSRJOjk50bfffqtyuaxaraZyuWzKp9vtmpFEisWims3miKLiIGFkUqmUlpeXDZYcV6qdTkf5fN4MdK/X08TEhAaDgQ4PD23DOYAozFAoNGLQWPxer6epqSmlUimFw2FNTEwomUxKGhrgSqVicDwHLJVK6d69e2q323r+/LkODg5s4YEEP/jgA62urur8/Fy/+tWv7LICW15X0um0JBkUzGVMpVKKRqMql8s6PT1Vu91WIpHQ1NSUgsGgms2mrU2n01G73VYwGNTExITBs9FoVOFw2C4RaQMuUSAQMOMXiUSUy+UUiUTU6XRUrVYN+llcXDRj1el0NBgMNDExYc/IZQDuXlxcVDgcNses1WrZfnW7XR0fH+vo6MjgI0kGE5NqOTs7U6lUUjgc1szMjKampkxx87k8y00JF5KLyDkrFAo6OjpSs9k0SJ494Ox7pZBMJs2JWV5e1sOHD9VsNg3m88qo1+uZsQuFQspms1pYWFCn07E0UTAYNGcSJ0caOnYoMy/eSJNWwCngTu/u7mpvb89gSZRfo9FQPp9Xo9HQ6empjo+PFY/HNTExYc5ApVJRr9cz+PEmDPWbN2/s/VjHer2uQqGgdrutWq2mQqGgTqejSqViezPu/HGu0RdTU1Pa2NjQz372M9Xrdf37f//vtbW1Zc4IxgHnMBQKaWpqygKMfD6vs7Mz2zMUNQaC58Qxm52dNTj93bt3CgaDevjwof70T/9UsVjM1rTZbGpzc1OlUsnSVKenpxoMBuZIYLhxWkgxcfeBj73zch1hH0lz4rA8fPjQ7MD5+bkk2Tqh931K8OzszJzIRCJhAVmlUrG7MjMzY3vFPhB4oEd8ii8UCimRSGh2dlaJREK1Wk0nJyd2P7yDhkSjUYPDy+WySqWSOp
2OJicnlUqlJEmVSkXFYlGRSESPHj3SvXv3bA92d3dNJ5BqKpfLisfjikajSqfTSiQSmpiY0PLysjkdV5ErG2kUJFg
2024-07-30 18:24:53 +08:00
"text/plain": [
2024-08-01 00:59:36 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-31 01:01:14 +08:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-08-01 06:04:14 +08:00
" 20%|██ | 100/500 [10:18<38:54, 5.84s/it]"
2024-07-31 01:01:14 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-08-01 06:04:14 +08:00
"Epoch: 100, iteration: 468, Discriminator Loss:array(0.536409, dtype=float32), Generator Loss: array(0.665009, dtype=float32)\n"
2024-07-31 01:01:14 +08:00
]
},
{
"data": {
2024-08-01 06:04:14 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAekAAAHqCAYAAAAgWrY5AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9x3Nk6ZUeDj/pvUUmkLCFQqG6y7CqLbtHJOenUZDDkZmRFNpIG+0UoT9JC62kjUITipA0M2JoNBSDHDqxm6x25VEFb9N7b74Fvufg5MVNIBNIlCHzRCAKlbh573tfc8xznKXX6/UwoQlNaEITmtCE3jiyvu4BTGhCE5rQhCY0IXOaCOkJTWhCE5rQhN5QmgjpCU1oQhOa0ITeUJoI6QlNaEITmtCE3lCaCOkJTWhCE5rQhN5QmgjpCU1oQhOa0ITeUJoI6QlNaEITmtCE3lCaCOkJTWhCE5rQhN5QmgjpCU1oQhOa0ITeULIPe6HFYjn37zabDQDQ7XbR7XaHHsSge4+7GJrFYpEfff9erzf2Z5k9G8BI8zLoHoPIbrfD6XTCYrGg3W6j2WwO9V7RaBT//t//e/zbf/tvUSgU8B/+w3/AX/7lX8r3X2dROv3OFxmHxWKBy+VCOByG2+0GAGxsbFx4PP/0n/5TAEAqlcLa2hoKhQKWlpbw3nvvIRQKYXNzE1999RVKpVLf+M3GbrFY4HQ64XK5AACNRgONRmOocQQCAaysrCAWi2F6ehoff/wx5ufn8fjxY/zX//pf8ezZs77rrVYrbDYbLBYLut0u2u32qXvyGqD/DI8y73a7HX6/Hx6PB+12G5VKRfZRt9vtO3MXpfPOwWXIarXCaj22XUblY+Mki8Vy4f1utVpht9tht9vR7XbRbDbR6XRgt9vh8/ngcDgAHO/hi1IsFpOzFQgE4HK5UCgUkEwm0Wg0+vab5iEOhwMulwsWiwWdTgedTkf+3uv1EAgE8E/+yT/Bn/7pn6LdbuOrr77C8+fPUS6XsbGxgYODA1mXbrcrcsdiscDtdiMSicDj8aBWqyGXy6Fer6Pb7cpzOEd6D1ksFtjtdhkX96rNZsPi4iKuXbsGi8WCnZ0dbG9vo91uy9j1/VwuF2ZmZhCJRNBqtZDJZFAul9HtdtFqtdDpdBCNRnH37l3MzMwAAP7Tf/pP5871xJL+PaSrZGITejPpdStTE/rDogmPeXU0tCV9HtlstgtZcYPoKhjO62Ji49rQZ1llAJBIJHDr1i14vV5sb2/j2bNnqNVq596n1WphfX0dv/jFL1CpVLC3t9dn9Zz3XKvV2qeBvknU6/XQ6XTQaDTGsv60PorFItrtNiwWC+r1OpLJJMrlMnK5XJ+VavZMm80m2j/PC3CCsjgcDiQSCUxNTaHVauHw8BCZTKbvHp1OB7VaDaVSCVarFc+fP0cqlcL29jYqlYrpPNDy0Os6NTWFRCIBp9OJTCaDo6MjtNvtkSxem80Gn88Hj8cjz2k0GqcsJT7zVRDnls+2WCzw+/0IBoOwWq0olUooFoum+3VUa994PoyohbbizrsH7zMIeRlmXFyDTqdzCr1otVpjOQderxfA8btyvblver0erFYrXC4XrFYr2u22/E3vP21B0zJuNpvY39/HV199hU6ng83NTaRSKdTrdbHQNVmtVjgcDthsNjgcDpkjPZ8OhwM+n0/Gwr1JIvpgs9nkGuD4PNZqNWQyGVitVlQqFZlP4xr1ej1Bjnim6/V631nivXO53Gjo1CgLc+aN/v9Qis1mQ61WE0hgVLpqQXrZA3BRGidzMoPCrFYr7t69i3/37/4dEokEfvSjH+Hw8LBPSBvhfm64Wq2GX/ziF3j69Ck6nQ4ODw/lMJu5B4zPtdvtsNlsAuuMU1CPYz3a7TbK5fJY1uDly5cAjhUbQtOFQgHPnj2DzWY7F7ImxE3ovdFooNls9jFTr9eLjz/+GJ9++ikKhQL+7u/+Drlcrm9e2+028v
k8ms0mksmkjKteryObzZ56rpkCZbFYcOPGDfzwhz9EIBDAr3/9a/zkJz8ZWaFxOByYm5tDIpFAvV7H9vY2UqnUKYh73MqqkTQzpOJIYWWxWDAzM4Nbt27BbrfjxYsXWFtbO7VWVGT0/c4bi/E82mw2eDwe2O12tFot1Gq1MwW12XgvSlqA8D58LoVps9m88P1JsVgMANBsNlEul9HpdARaBk4Eo91uR71eR7VaPeU+MY610+mgXC7jiy++wMuXL9Hr9VCtVtFoNESIU5Dq/3s8HjidTthsNtjtp0Wax+PB9PQ03G43SqWSQPJ6PegqtFqtfcpMNpsVHkrhzvFSmHMPdLtdZDIZ5PN5mW/yUCrmjUYDW1tb2NvbG3quR/ZJn7XR6AdpNpsTOOQK6Lw1iEQiuHXrFpaXl/Hw4UM4nc5T3+XvWqttt9s4ODjA0dERAJj6b4xasCYymTcVbuWhGwcVCoW++wLHjGpYX7L2GfK7RgZut9sxPT2NGzduIJPJIBAInDpPZLjAiRJyHvM1ro/FYkEoFMLy8jIikQjW1tYGruNZe4+WdCQSQblchtVqPTXfZhbOVZJRIaXPcmpqCg6HA4eHh30IkKaL+oL5Xfo4HQ6H/P88ooU46Nqz7mG2NmbGiBbclyXyFgoi+ly1ouRwOGC32wVxMo7F+EPBm06nkU6n+663Wq1wOp2w2+196AjPksPh6LOGNdntdni9Xni9XrTbbZFRmngv7gmOkQiBccz8v/GdKPwpmHk/IitUREYxZMZmSVOTslqtYhkMS8Nu5KukN1XAaBrEPLm59vf38ZOf/ATxeBxffvklqtXqqe9y03DD6L+fBf0B6AsE0TAm76OtplHorCAZvt9l7j9OuujztdAg/A7AlGk2m01sbW3ht7/9LSqVCiqVCnw+n7iR2u22oBbAseW9tLQEv9+PfD6PnZ0dFItFOJ1OsWYajQYqlUpfsIvVakU6ncaXX34Jn8+H7e1tdDodEdTDviuteofDgVqtZupiIV3V+jmdTng8HthsNrRaLbHqtPVWKpWwvb0Nm812CpkwIy3Ez5oP/TmvJw/U1pbxWn6ug5a4phpW5XU2m61PmJk9/1VRuVwGcLxXKaA9Hg9CoRBsNptA7frv3W63TxnRigytZU2EzBnoxnfXPIFuVl5DhEujU1o2tdttOBwOQbI4d8FgENFoFFarVRAqjUDYbDZMT08jFouh1+vh6OhI0CIqI/p9OXYq41yzUc4VydIb8hvnaXicuHFANq+TjPDuOA6Avudl5kVriBwXfTHamrHb7cjn80gmk6csGg0JjeKSoCbrcDjQ6XREWOj3u9AGHPBdzhdholHHexZdZk2NiITZ/cw+11q6/tws4tlutyMajSIYDMo1FMr5fF58zlSalpeX8Rd/8Re4ffs21tbW8Fd/9Vd4+fIlQqEQFhYW4PP5kM1msbOzg2q12mdxeL1e8dNWKhWUSiWBHs0iwM3IarXC7XbD6XSKH2+QJa3f+6JkvBcRgUQiAbfbjXw+j8PDw1OQptvthsfjgcViEbfEWYyT54po01l7z8xyN5IZ/G+xWODxeBAIBGCz2fqgYb0GDodDBFqr1Rp6bc6iy5yDRCIBAKKI9Ho9zMzMYGVlBYFAAEdHR1hfX0e1WkWz2RSlia4eWtpOpxO9Xg+VSkXem/Nit9sRiUQQDAbFAq1Wq33wsc1mg9frhdPpRKvVQqlUEli6Xq+L8hAOh2V/at85DcTp6WksLS3B4XBga2sLL1++RL1el2tcLhc+/fRTfPLJJ2i323jw4AG+/vprdLtdeDweuN1udLtdOZ82mw2BQECyHMrlct/9OPfakBpEY7OkX2e6wh8aGYUZmX+xWBTtbhCZCZlhaRCsfVFlRjO1Ya553WjLZciICBgDVzR1Oh3kcjkUi0U4HA4Eg0H4fD4A/Yoafa12ux0zMzNYXl5GuVwWQUS/YCAQQLVaFWVH7xkK5l6vdwruO+tdjIpGrVbrY0JnXX
8VZLPZ4Ha74fP5hPEZn1mv11Gv12VM50HIw+xPkkYDtTDWUKeZEgr0x3VQeTKbw1HHdJVEJEivt91uRzgcRigUkhR
2024-07-31 01:01:14 +08:00
"text/plain": [
2024-08-01 00:59:36 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-31 01:01:14 +08:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-08-01 06:04:14 +08:00
" 40%|████ | 200/500 [20:10<29:27, 5.89s/it]"
2024-07-31 16:50:32 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-08-01 06:04:14 +08:00
"Epoch: 200, iteration: 468, Discriminator Loss:array(0.556644, dtype=float32), Generator Loss: array(0.689756, dtype=float32)\n"
2024-07-31 16:50:32 +08:00
]
},
{
"data": {
2024-08-01 06:04:14 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAekAAAHqCAYAAAAgWrY5AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9Z5Nj6XkefCHn0EDnMKlnZnc2cYN2KVKkRJp2lS0nlcpfXeXf5h9gfVBJJalki6SK0lLkxtndyR2mE7qRcwbeD/Ne99x4+gANoIEOS1xVXZ0OznnOE+4cbN1ut4sZZphhhhlmmOHKwX7ZA5hhhhlmmGGGGawxY9IzzDDDDDPMcEUxY9IzzDDDDDPMcEUxY9IzzDDDDDPMcEUxY9IzzDDDDDPMcEUxY9IzzDDDDDPMcEUxY9IzzDDDDDPMcEUxY9IzzDDDDDPMcEUxY9IzzDDDDDPMcEXhHPZCm80m3x0OB2w2G7rdLtrtNqyKlvH6UQqaud1urK6uYmlpCfV6HQcHB0ilUuh2u7DZbHJPotvtyv0dDge8Xi+cTifa7TYajQY6nY58DftuvKfNZoPb7YbH4wEA1Ot11Ov1vp8f9j3PU+DNfH8TdrsdDocDANDpdNBut2Gz2bCwsICVlRU4HA6cnJwgkUig1Wqd+izv3+l00O12Ybfb4fV64XK50O120Wg00Gq10O125ZppgGOx2WzynEHPsho7YbUPp7kGgz6nxzLKGBwOBxwOh5y3s/az+Vw+c1zw+QDQbrfRbrdPPcPpdMLpdKLT6aDVap26xsR5xuNyueQeVvvQ7XbD6/XCbrej0WigVquh0+mcoiFW69Bvnfg3u90Ol8sFh8OBTqeDZrN56iyZMJ856Bpz79rtdnkur7FaAz32YffXtM+B+d6aFtlsNiSTSRwfH1vSIr43+YvD4UA4HEYgEECr1UKxWES1WgWAkc7DRWCUMzfMNTNNeoYZZphhhhmuKIbWpDXI/YfRpkbRMjudDmq1GnK5HNrtNlqtFux2u3x+0H2oZQAQbYMS1jBjsNLSqRVwbFaghEvp9SJKobtcLvj9fjidTtEU+O7UGPQ4Go0G8vk8HA4HarWapaYJnJ5fzinvd5ES6zDaxzC4qNL0lPz77YHzvsMwFoV+nx0VNpsNLpdLNFZqboPeTVvVpj3n2nrm8/ngcDjQarVQr9dlnkg7+ln6TA2Zmp7+H6Etbfy/lZY9zHsPYw3jddqipM/fJGnueeH1ehGJROB2u1GtVlEsFtFsNgG8HucgWsR30tqzFS3tdrtoNptC67QlYZT3taLzZ10/rgVsUrAN22BDv9xZm3LcF6OJ2el0ymf1oRtkQuMzzzL3WP1dbxCTuVNI6Gca06bIVqt1JiObhIkpEolgbW0Nfr8fhUIBR0dHsuGtDrLT6YTL5YLNZkOz2ZRDpM1p/QQubWbTmLa522o9Bl3fz9xthUmb+ehmoHn+LPPneZ47bSLBMxQOhxGJRGCz2ZDP51EoFAbO7ahC1Xneg3sjEAhgfn4eXq8XlUoFmUwGjUbjFEMz94+5hv1M4FYuNp550gUqBMPQOX3eNPTn9F7S7qtGo4FmszlQYbmMc7C0tIQ333wTc3NzSCQSeP78OYrFYo9gp9GPFrndbrjdbmHG+l31Z016Zc7HMGugz+pZbhm9BqO6mobBMGtwLk160uh2u+L7pTTvdDqHPvT9rtMS2iBYSVlnLaI+yOP6K0eF0+mE3+9HJBJBq9WC0+k8JRhptFotS8YxzHj1ptRWA35+lL1wkRL+RUPPzTTe8yLnjWfP6/XCZrOhXC6fyYQucnxak/Z6vQgGg+h0Oj2xMoOYlNaaifOu2bCf14qEFSigmgrAsL7miz5jbrcbc3NzmJ+fR7VaFQbcj7ZY0SItkFDZsXoHfk4rZOPQoFE+fxn03cRYTPosTEJC4+RNwoRm9Vkt1Zo/j3JfbV6+qMPRaDSQy+VQq9VQKpXEJO12u+WQtF
qtHq2in7DB96W0CPTXXk0zebfb7ZH89XppBINBxONxeDwelEolZDIZkaCtnqnn1ePxiNm1Xq/LO/Ub12UIAXxvzaTP2hNnacbU1oDJSPDafKq1S7vdLpYrWgEoLJdKJXmXSCSCdrs9MIDyotFsNlEqldBsNlGtVk+ZQAHrYErTTK0/c5b7px/M/5n0xGZ7FVzndrtht9t7LC6RSATxeBwOhwPZbBbpdFrWodt9FTS1vLyMaDSKbreL4+NjJJPJnvMyiDFOE5VKBYeHhyiVSkilUj0mbO4rjUG0SK/LIDenlUXTZLb91srpdCIQCMDhcKDRaKBSqch4rKy/PCsOhwPz8/OyBul0GplM5kLozVSYNDBaFKP5f17Tz2QyiTGROJkbe1Ri2M/EPE1UKhUcHBxIhCkPqd/vRywWg9PpRLlcRqFQEGZtpVloaVKb8Hj9WaDU7/F44HQ6xS9oEsuFhQV88MEHmJubw+7uLh4+fIhisSjXUHqmiUsTp0AggGg0CgDIZrPI5/OWpq7L1NC1xUUzw34E6SxmQIHL7XZLnMZ5mbTdbofb7e4xyXe7XXg8HtGY6e/rdDqoVCpoNBqw2+2IRCKIxWLodDpIp9OWgtJloF6vI5lMyr41rT7Aq/gNvjfjNwYxNdN9xfXrxwCs9p4WXPV+oNbvdDpFsACA27dv46OPPoLX68VXX32FTz/9FNVqVXzqPp8P9+/fxw9/+EO0Wi388z//M7LZrJyXfsLGRaxRoVDA48ePe96p0+mIWdv0q2tapOkP55Lj5xr0o62cd1P4tDKBa3i9XlEYCoVCzxxqV4GO8el2u3A6nbh16xbee+89tFotfPHFF8jn81NxbZmYGpM+D6wm/DwYtJHPu8Evg1i122054EwJodDBtDGawantmtASviYmZI5Wfjbt8+J3HkYKPM1ms4dYkuHQJJbNZuF2u0W74f07nY5oGPrvTqcTXq9Xfr6q0ARmGK1mGGFV+ybPC21SbLVaPZomfX3aukEGxTF4PB4hVlfFbaEDiPR8EXxnjlmbS/V300LEc9Bvjfqtr3mmKMTyiwKDOZ5QKITFxUUEAgFsb2/3uOd4DkOhEFZWVtBqtRAKhXo0dSv6dlHQfmUT+t214tVv/5jrwO/a3K/vbc6zFpbNe2qLIQXTWq3Wcx9N+600eb/fj/n5ebTbbQQCgQub8ytH9UyJeJL31dB5xFbXTYIQXcQiahNzrVZDoVCA0+kUjbbfO1jNR7fbhdvtxo0bNyRX/eXLlzg+PkYoFMJbb72FjY0NdDod1Ot1NJvNnkOSzWaxvb2NXC53iskcHx+Ldu/3+4Wh03S6tLSExcVF2O127O/v4+DgQISRfD4v7zdJ4W0aMF0gg64bBJqWrQjPOPD7/VhcXITX60U+n0cqlUKj0QBgLajpcVarVeRyuZ41uGowiSy/9Dk3BUCuk4a2MliZxV0uFyKRCPx+v5jbSewpLGtNOhAIYHFxET6fr2esOsg0Fouh230VMBWLxfDOO++gVCohnU4jmUyi3W5jf38fv/vd79DpdHB0dHTKb2sKFZe9RpxHu92OQCCAcDgMp9OJbDYrUfgcK4AeWqzdO8xJ5/xQSHnw4AHW19fRbrfF6lOtVpFIJFAqlcQSwvtTqIlEIohGo/D5fGi327KvKUA5HA5Eo1HEYjHYbDYcHR1hf38f3W4XJycn+O6779DpdHrcDdPGlWTSF7HB+gWWaM3lPBHMplQ9LZCIt9ttIRhaa+I1g8bJe3Q6Hfh8PvzRH/0RfvrTnyKXy+Hv/u7vkM/nsbKygv/xP/4HfvGLX6DVaiGVSqFcLkvhmFarhe3tbfzjP/4jdnZ2xAxOYrW9vQ3glTbMyOFSqYRcLgcAeOutt/CTn/wELpcLv/rVr+T+5XJZrAb9ijjwPc5614vAMEz1rD1Os6CZynIehMNh3L17F5FIBLu7uxLToE19WiPRDK5YLKJSqQBA36CeqwQt6DOVk0za1I5MLdTcY6YVw+
PxYGFhAfPz86hUKkgkEuh0OnC73QgGg3C73QBeW33W1tbw3nvvYX5+XhhvvV7v8U+TYdTrdaysrGBhYQGNRgOff/6
2024-07-31 16:50:32 +08:00
"text/plain": [
2024-08-01 00:59:36 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-31 16:50:32 +08:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-08-01 06:04:14 +08:00
" 60%|██████ | 300/500 [30:02<18:58, 5.69s/it]"
2024-07-31 16:50:32 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-08-01 06:04:14 +08:00
"Epoch: 300, iteration: 468, Discriminator Loss:array(0.58722, dtype=float32), Generator Loss: array(0.661292, dtype=float32)\n"
2024-07-31 16:50:32 +08:00
]
},
{
"data": {
2024-08-01 06:04:14 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAekAAAHqCAYAAAAgWrY5AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy915dcZ3Ydvm/lnKtzRE4kSJAYDMNwhpoZWpolaWw5e3kt/1l+8ZO9LIcHW/pJI0uaIFIzjEMwgMgNNDpXV1fOuer3AO2Dry6qgQ7V3VXdd6+FBaBD1a0vnLjPOVqn0+nAgAEDBgwYMDBwMB31AxgwYMCAAQMGesNQ0gYMGDBgwMCAwlDSBgwYMGDAwIDCUNIGDBgwYMDAgMJQ0gYMGDBgwMCAwlDSBgwYMGDAwIDCUNIGDBgwYMDAgMJQ0gYMGDBgwMCAwlDSBgwYMGDAwIDCstMf1DTtpd/nn06ng3a7ve+H28kzmc1mmEwmdDodNJtNvKiBmsvlQiAQgMViQbFYRC6XQ6vVOvDnVLGfBm8v2wMDO0M/9sBqtcLtdsNisaBWq6FcLqPVakHTNJhMJmiahna7va97wPPtdrvhdDrRbDZRLBZRrVa3fS7g5Z/P5/MhFArBZrMhn88jmUyi2Wzu+Tm3e/YXPYtxD44eg7QH1B3A0+c6To0wVd0IAO12Wz7fTj6n4UkbMGDAgAEDA4ode9I7xWFbQfTad/KerVYL1WoVZrMZjUbjEJ7ucGG1WuFwOGAymVCv11GtVvu6F/QQDytSMsjodDpoNBpot9sSjdF7AyosFgssFgs0TUOz2dzx+WOEqFarodVq9WXdm80mqtUqGo0GarXagdzX4+QJGdg9TCbTcxGlF50J9XvH7ex0Oh2Rm/z/bqDtdMDGduGNnYTZXhb62g/U939ZuMRkMsFsNvc8OId1MA4yxOT3+xGNRmGz2ZDNZpFIJPpmjGiaBovFArPZ3KWghhH92AM1rK1/ba6L+j5utxsejwcmkwmlUgnFYvGl68fXpsDrdDpotVr7PqsmkwlWqxWapqHVakmo+7CN673CCHf3BwexBzRULRYL7HY7NE1Do9FAvV4XuXzclPB+sJO16IsnrVoJhw2+ryo89d8j1ByhPod+HGCxWOByuWCz2VAul/suzKgs2u32iReUVJjAs9wxz5L+PPH7NpsNJpOpZ075RWi1WrviTrzMKG6326jVas/9/G7y2gYM9AJlKh0iVV4YZ2pv6IuS1gsmCnN+r5fHwI0EsONw9cueYbvnednvqM+iPi89JbPZDI/HA5fLhXa7jXw+j2KxuK/nPQjU63XkcjlYrVaUy+W+e7pUFNutL9eRXp8apRhWr3un2E4QUWg1m01UKhXxLPTfB55f173eie32hn/v9n3UZ1QVeT/urYHdw2w2S0SrXymQfoHnodVqodFoyNnfiUzulyI/yMjtUWDfSrrXQtBr0Ic69D9jsVgk59YPlvWLlMeLnp3PwgPFQ8+v22w2TE1NYXJyEo1GAwsLCweiBF+Glx2+crncdTH6yVxXBcKL0gk2m00ECH+WIdrjcmn0eNG5o4fdaDSQy+UAPFNuqsfRr1D2dujFOt/pe/FzAN2h90ajcejVEScdmqbBZrPBbrej0+mgWq12RUWOGmpIm9UOOzHS9ZHNvd4DlTdzXORN34ljQLdH1au0Q90Q1Uo/SLxIwakCTO81mEwmCSMHAgFUq1XY7fYDf969YLdh0d3iZQdfr3TU3zmJ4S592YW6N73O/0Hfg73eN/X3VCV90lMeRwX1jg3qHuxFSfbTAz5OqcwDUdLtdls8ul5Kg5aVGj49aOjfgwcdeMqKpidNpQxAiFJmsxmFQgHr6+uo1+soFotHsvmDfuB4MVVvW33ml13CnV5SehNWqxWdTgf1en0g2fpci+2Ehepx8P9H9Sw7+V3138fJUxk2kO2vnp3jgH6dqcMmAx809s3u3u5nd1KYvhOhfFD5BZ
vNJuVKVMT6z6iyd9XQbalUQqVS2dP7HmdWK1nDNH7U/JQayu2VD92OF9ALZrMZfr8fHo8HrVYLuVxuVxyBg9iD7UJ1R3nG9/s+L/qs+xWox/keHDR2Kl9fhuO6B8PkQe/kOQ/Ek94tceugQeGvKgLmdfRKutfhIxu2Wq1KlMDA9niZJbsdwWq77/WC2WyWMqJBEBgv4z28CDv5mZ1GIV70M3sNj+pfb1gE4HGFEcV4MXa6NmqKE+hdwjsI5MgDUdL9xH4XyOPxIBwOw+FwIBqNYnp6Gi6XSwhhaj4aeJY7bDabSCaTSCQSqNVq2NraEgKcyl4fJGblUaMX+UkNe79I2Kv/VlMO6vdVwlWtVkOhUBgYo+mgLjIrC5xOJ1qtVs+2oBaLBW63G1arVRr2kLTDc602l9CTe9Sa96MWSAYMHCTY8MlisWB0dBSnT58WfWC1WgEA+XwemUwGtVoNm5ub2NjYQKPROLK7MfBKer/w+/04f/48gsEgrl69ip/85CeIRCJdVpTq+VWrVfnz7bff4ptvvkGxWESz2UQ6nRYlrZKjDEX9FFSk+vXYieWvEsxsNpsYUOr3VMZ6pVJBpVLZEXP0MHBQF9hisSASiSASiaBarWJ9ff05Nq/dbkckEoHH40G1WkU2m0W1Wu1SzIw8sG6VBk+lUnlOQQ8yIcmAgb1C0zTY7XYEg0E4HA68+eab+NM//VOMjo7CbrfD5XIBAJaXl7GwsIB8Po/PP/8cqVSqq7LlsHEslTQ9MZPJBLfbDb/fj2AwiJGREUxPT2NkZKRneLDdbqNaraJSqaBcLmN1dRVerxftdht2ux1msxntdlsEnYHn0S9mpj4Upfe61VK54+r9qURGeslqWoZn3Ol0wuVywePxwGKxoF6vi4LmOvI1eHZZu8ozD6CngWXAwHEA5YjZbIbD4ZBqnfHxcYyPj8sdAp72m8hms7BarXC5XF3pUkNJ7xPciOnpaXzve9/DyMgIwuEwpqen4Xa7MTs7C5fL1aWgVSVAgehwONDpdOB2u+H1etHpdBAIBBAMBtFoNFAsFlEul7tIT9uFdE8C9pOP3e71yBNwu92yZ7VaTbxpNpQ5bkxOgp/fbDaLd0xjMRwOw+fz4fz58xgZGYHD4YDf74fD4ZDwP71h1dihgq9UKiiVSmg0Gtjc3MTS0hIqlQqKxSIKhYIQ/VSypKG89wZ13fUleS+b2megf+B9CAaDuHz5MsLhMM6dOyd3yWq1wmazyR1zOp2o1+uw2Wxdr3EUOHZKWtM0nDp1Cv/xP/5HvPrqq7BYLHA6ndJghTXOekVNRWuz2aQcy+v1IhAIwGw2IxwOY2RkRMhj9Xq9q8Vor3zsSYGe2bzfNVDDs4FAAGNjY9A0DfV6HbVaTbgBDNUOArmjn6DnyzailUoFzWZTvACv14uZmRn86Z/+KV5//fUuRWCxWOD1ertIdfqUTi6Xw9bWFiqVCh48eIAvv/wSuVwO8XhcygxZ1sYzbijpvUGteFCNpnq9fmLlxVEiHA7j2rVrmJ6extzcHEZHR+Hz+QBASoYdDgc8Hg+azaak3Y5yn46NkjabzRKS9vv9GBkZwfj4eNdQDRW98m4qyUbtNsaaXIYPVTb4SRFeFDZ6QpcqeDhdab+1m+r668O0/NNr746TwFM/n76OmgaMx+NBIBDo6pNst9vh8/lEuKhKmsaM1WqVNqUsZWs2m3A6nbDZbGJocZ2PUy3uYUC9E5QhLEskGEI1cLiwWq3w+XwIBoPweDxde0OdMGh8jGOhpDVNw9jYGL7//e9jfHwcFy9exMjIyHNlVaqnp88zqJ4G/+12uzEyMiLTXAqFAmq1GhwOB0ZHR9FoNJDJZFAoFI5lqJtCxmw2IxQK4caNGzh79uy25VIrKyv47W9/i7W1tT2/p8lkgsfjwejoqDAxOXKTaQaO4dR70IetqLerjd4vmHOn0Gg2m13DTSqVCpxOJ+7duwcACAQCOHv2LCKRiBg26vnu9dw0fg
KBAGZnZ6UCwmw2o1aroVgsCmEyl8t1cQB2i8OqAz9qUDG7XC6Ew2G4XC44HA4xmrLZLDY2NqTHwqApg+MM3lW3243
2024-07-31 16:50:32 +08:00
"text/plain": [
2024-08-01 00:59:36 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-31 16:50:32 +08:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-08-01 06:04:14 +08:00
" 80%|████████ | 400/500 [39:54<10:08, 6.08s/it]"
2024-08-01 01:23:57 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-08-01 06:04:14 +08:00
"Epoch: 400, iteration: 468, Discriminator Loss:array(0.587837, dtype=float32), Generator Loss: array(0.683923, dtype=float32)\n"
2024-08-01 01:23:57 +08:00
]
},
{
"data": {
2024-08-01 06:04:14 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAekAAAHqCAYAAAAgWrY5AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz913Nc2ZUlDq/03lt4EAQ9y7CqVKoqmWpJHT3VPeqY7oiJmL9vnmYmJmKeplutX2vUklotqbwhiw4k4YH03t503wO/tXnyMkHCIxO4KwIBEEykufecs93aa5sGg8EABgwYMGDAgIGxg/ms34ABAwYMGDBgYDQMI23AgAEDBgyMKQwjbcCAAQMGDIwpDCNtwIABAwYMjCkMI23AgAEDBgyMKQwjbcCAAQMGDIwpDCNtwIABAwYMjCkMI23AgAEDBgyMKQwjbcCAAQMGDIwprPt9oMlkOsn3cWFwFIE3i8UizzHqeUwm09Bj+v3+S69pNpthNpvl8epj+fjDwu12Ix6Pw+l0otFoIJPJoNVqyWvp15D6OSwWC6zW58ux1+uh1+vt+1rxudXnM5vNsNlssFqt6Pf76Ha78vm63e6hP+Mk7QOXy4VAIACHw4FGo4FyuYxOp3OkNXhcOMp7mKR7MM4w7sHZYz/3wIikDZxbGAeJAQMGJh37jqQNnD3243XxMXs91mazweFwAHgeUR4lqgSeG0Kz2SxRPF+33++LkXydsWQkzMdZLJahyFj9THt9Lv3vTSYTrFYr7Hb7gSPz84Jer4d2u41+vw9N046cKRkFNRtjwICB44dhpM8Z1IN4lOFiCtRkMqFWq6FarR457WW328UgAhhKLasGeq8UPb/rv4Dnhqbf72MwGIix1WMwGLzkCJjNZrhcLrhcLmiaJsZqHHFShq7T6aBWq8FkMqHf7x+7o0LnDIDcIwMGDBwvDCN9TsDI83WPsVgscDgcMJlMaLVaLxnFw7yuxWKBxWKRWrca8e7HAI0ysHwuQq2x7wX1Nfi+bDbbUFR/kTAYDNDpdE70NS7idTVg4DRhGOkJhT5C1RuoUUZxMBhA0zSJrlqtFnq9HiwWC5LJJKLRKLrdLlKpFHK53EvGVjWe/D8aaKa6afjb7bZEbqoTsNd7ZaTHf+tfRyWEjUp7j/q3pmkwm80nluq96NDfGwMGDBw/DCM9gVCjX32aUU0fjzJcjUYD7XYbJpNJ0scOhwO3bt3CD3/4Q7RaLfzud79DqVSSejUNrd1uh81mk9Rzv9+H1WoVFnW320W1WkW325XHEKrRVX+v1rAZ8dIhUD8Pf6aR3iu9qqZf+VnpAOij83HBpBo5/b00YMDA8cMw0hOK/ZKy9BhV1zWZTAgEApiZmUGj0YDP55N6oz6SZsSsRrZqhN3tdqFp2p7v9XXEL319WZ+O389z8Xm63a581kk1hAYMGLjYMIz0BIKRpN5QsdeYUTIj2v2kmwuFAtbW1tBqtVAul4ciVT6m1+tJny3/X08Q07+W+poHSY2qj9srda9+hlH/r2eGHwfIjGfv9X7KDAaGuwAMGDCwfxhG+v8PVRADGP/Ia5TRs1gscDqdMJvNL9WEKXLC1C/w4uDs9/vY2dnBt99+i06ng0wmIwxt9VDtdDrodrtDBlmt9eqjXvVaHiQtul9jztdRo371b/V16OO4p263G8Dza6G/lvw+7mvnLMCWuHEtORgwMK64sEZarXvqCVH8rh7C44a96rFMSesPQ9WIqPVeGpVWqyV1aKar9dGPagT1z8HX1Ds7p4VRLPKTeA90dkhyM7B/qGvPgAED+8OFNNImkwkzMzNYXl6Gy+WC1+uF1+uVemu/30en08HDhw/x6NEjiSCPKvxx0qB4hdlsfkn+kYY1FArB7/cDACqVCiqVCgaDAZrNphjpXq8nJLFIJIJgMIhut4tisYh6vQ6z2Sy90U6nE/F4HF6vF6VSCevr65IuPy0JSn17lt5JUJ2Io4IypyTOqe
/hqM7JKHnT8wJmU87b5zJg4KRxIY202WzG8vIy/vEf/xHxeByJRAKzs7OiTtXv91GtVvE//+f/RCaTQa1WQ7PZPPNDZlQdVv2dKiIy6qA3m82Ix+NYXl6GyWTCkydPUK/X0ev1UK1W0W63ATxPEzudTni9Xty4cQNLS0uo1+tYWVnB7u4urFYr3G43nE4n/H4/lpaWEAqFsLu7i0qlglarhU6ns6f4yHFDrYvre77pnNjtdomCj4JGoyHP+7o2sINiP+x1/eOP67VPGiTyGTBg4GC4kEaayluxWAzJZBLT09OYn58fMtKlUgnBYBA2m21sCS+jDDUNltlsloEVKqxWK1wuF0wmE5xOpzyGwhdqytxqtcLj8cDv98NqtcLr9cLtdsvPLpcLfr8fwWAQ4XAY9XodbrdbxFL0LO+TgD5bsNd9Oq5U60n2W+vfIzkBahmBjxuFUZ9frc/vx8nci4R3HBh3R8KAgXHEhTPSPPScTicCgQBCoRBcLtcQe7nT6QxFg+OWfrRarXA4HDCbzeh2uy9JXqrpfKfTic3NTayurko6vFwuw2q1wufz4erVq+h2u6jVaqjVagCe63tbLBb4/X643W64XC44nU688847eOutt2CxWBAIBOB0OsXZ8Xg8SCaTcDqdyGaz2N7exldffYVcLndi12Ev1ree/EcHZBwFTdS+cL/fD5/PBwBSdnA4HJidnUU0GoXZbIbD4YDFYpFyAsVoeM8sFotkDdrtNqrVKjqdDnZ3d/Hw4UNUKpWh16dDRmKX3W4fek/9fh+tVgvNZnPICdBnLEZF/nxOvl8102PAgIH94UIZaVU1y+l0IhwOIxKJiJGmweOhRKUqHkBnbah5KNpsNng8HthsNjSbzZcOP7PZjIWFBXzyyScIBAL4/e9/j93dXdRqNTQaDRSLRTgcDoRCISwuLqLX62FtbQ1bW1vo9/twu92w2+1iNNxuNzweDxYWFhCLxeT/yCQna7dareLy5cuoVqv44osvsL6+jnw+D+D0oqhRr0ODNg7p1lH1co7VDIVCSCaTMJlM8n79fj8++ugj3Lp1CzabDV6vFw6HA91uF/V6He12W35vt9vhcDjg8/lgt9tRLpexs7ODarWKr776CqlUCtVqdej90MCbzWa43W54vd4hwz0YDJDP59FqtcRIq4TLUax6gmUGm80mGSrDSBswcDBcKCPNw5CHmc1mg81mk75iAGi326jX63IAjtPggFEiH2p6VP2dzWaTARNMP6vPYTKZYLPZ4HQ60e/3YbfbxYGx2+0SQfNauVwuBINBRKNRWK1WiaT1zxmJROB0OhEMBuXvR/UUnzbG4R6+qmSiRrLAi3KF0+mE2+0WDkAwGITH40G325XhITabDT6fDw6HAw6HA36/X+5bq9USp25Uul8l1TkcDiln0ACTRPiqMsKrPps+4jZgwMDBcKGMdDgcxsLCAnw+H5aWluD1emGz2dDtdtFoNNDpdHDv3j18++23KJfL+PLLL1Gr1dBut8ciCqOh4fu1WCzo9XpyqFKi02KxoFAo4A9/+ANsNhuePXsGTdNgsVgQi8Vw48YNOJ1OdLtdtFot+WxOpxMOhwOXLl3C7OwsvF4vrl27hrm5OUlr+/1+SanqpTsZnbvdbiwuLuLOnTsIBoPI5XJYXV1FvV4/les0imU9DsZir/Q868XtdhvNZhN2u11KDFarFSsrK9ja2kIkEsEPf/hDhEIhcTb7/b44ZJRn5d8NBgNUq1UUCgXUarWRDHQacLvdjqtXr+L27dtwOp2SUWq1Wrh79y4KhcJQfZs98oygR0XIHJFpRNEGxhWjeD3jhgtjpBnlvfXWW4hGo1heXpa0YKPRQKlUQrVaxZ///Gf87//9v1Eul9FsNtFoNMYi1a2CqU4AEn3xy+12w2QyIZvNYmNjQxjW3W4XdrsdiUQCt2/fht1ux+bmJjY2NoTV7fF44Ha7ce3aNbz55pvwer24fPkypqamhhjSo1jUwHMjTe
Nw6dIlvP/++5iZmcGjR4+QyWROxUiPElPZ6/2OA1QjR6cJACKRCKLRKJrNJh4+fIidnR3Mz8/j0qVLuHXrltx3Omg
2024-08-01 01:23:57 +08:00
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-08-01 06:04:14 +08:00
"100%|██████████| 500/500 [49:45<00:00, 5.97s/it]\n"
2024-07-29 00:18:35 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"# Training configuration\n",
"n_epochs = 500\n",
"batch_size = 128\n",
"display_every = 100  # log losses and plot generated samples every N epochs\n",
"\n",
"# nn.value_and_grad(model, fn) returns a function that calls fn with the given\n",
"# arguments and returns (loss, gradients w.r.t. model's trainable parameters).\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"for epoch in tqdm(range(n_epochs)):\n",
"\n",
"    for idx, real in enumerate(batch_iterate(batch_size, train_images)):\n",
"\n",
"        # --- Discriminator step ---\n",
"        D_loss, D_grads = D_loss_grad(gen, disc, mx.array(real), batch_size, z_dim)\n",
"        disc_opt.update(disc, D_grads)\n",
"        # MLX is lazy: force the updated parameters and optimizer state to be computed now\n",
"        mx.eval(disc.parameters(), disc_opt.state)\n",
"\n",
"        # --- Generator step ---\n",
"        G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
"        gen_opt.update(gen, G_grads)\n",
"        mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
"    # Report every `display_every` epochs, and always after the final epoch so the\n",
"    # fully trained generator is shown too. Losses are those of the epoch's last batch.\n",
"    if epoch % display_every == 0 or epoch == n_epochs - 1:\n",
"        # .item() converts the scalar mx.array losses to Python floats for readable logs\n",
"        print(\"Epoch: {}, iteration: {}, Discriminator Loss:{:.6f}, Generator Loss: {:.6f}\".format(epoch, idx, D_loss.item(), G_loss.item()))\n",
"        fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
"        fake = gen(fake_noise)\n",
"        show_images(epoch, fake)"
2024-07-27 05:19:08 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}