mlx-examples/gan/playground.ipynb

637 lines
710 KiB
Plaintext
Raw Normal View History

2024-07-26 21:07:40 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Libraries"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 1,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 2,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import os\n",
"from collections.abc import Iterator\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"import numpy as np\n",
"from tqdm import tqdm"
]
},
2024-07-30 07:17:12 +08:00
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 3,
2024-07-30 07:17:12 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-31 16:25:39 +08:00
"# Seed every RNG the notebook draws from so Restart & Run All is reproducible:\n",
"# mx.random drives get_noise (and MLX layer init / dropout), while NumPy's\n",
"# global RNG drives the batch shuffling in batch_iterate.\n",
"SEED = 42\n",
"mx.random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"\n",
"# Optional: uncomment to pin computation to the GPU explicitly.\n",
"# mx.set_default_device(mx.gpu)"
2024-07-30 07:17:12 +08:00
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 4,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim: int, out_dim: int, bn_eps: float = 0.8, negative_slope: float = 0.2):\n",
"    \"\"\"Build one generator stage: Linear -> BatchNorm -> LeakyReLU.\n",
"\n",
"    Args:\n",
"        in_dim: number of input features.\n",
"        out_dim: number of output features (also the BatchNorm width).\n",
"        bn_eps: BatchNorm epsilon. The original code passed 0.8 positionally,\n",
"            which MLX binds to `eps` (the printed model shows eps=0.8,\n",
"            momentum=0.1); the default keeps that behaviour.\n",
"        negative_slope: LeakyReLU slope applied to negative inputs.\n",
"\n",
"    Returns:\n",
"        An nn.Sequential module.\n",
"    \"\"\"\n",
"    # NOTE(review): eps=0.8 is unusually large; the 0.8 was possibly meant as\n",
"    # the momentum argument (a pattern seen in PyTorch GAN examples). Kept as-is\n",
"    # so results stay comparable -- confirm before changing.\n",
"    return nn.Sequential(\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.BatchNorm(out_dim, eps=bn_eps),\n",
"        nn.LeakyReLU(negative_slope=negative_slope),\n",
"    )"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 5,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"    \"\"\"MLP generator mapping a noise vector to a flattened image.\n",
"\n",
"    Three GenBlock stages widen the features (hidden_dim -> 2x -> 4x), a final\n",
"    Linear projects to im_dim, and tanh squashes the output into [-1, 1].\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, z_dim:int = 32, im_dim:int = 784, hidden_dim: int = 256):\n",
"        super().__init__()\n",
"\n",
"        widths = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4]\n",
"        stages = [GenBlock(a, b) for a, b in zip(widths, widths[1:])]\n",
"        self.gen = nn.Sequential(*stages, nn.Linear(widths[-1], im_dim))\n",
"\n",
"    def __call__(self, noise):\n",
"        # noise is typically the (n_samples, z_dim) output of get_noise\n",
"        return mx.tanh(self.gen(noise))"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 6,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=100, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
" (layers.1): BatchNorm(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
" (layers.1): BatchNorm(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-31 00:50:02 +08:00
" (layers.3): Linear(input_dims=1024, output_dims=784, bias=True)\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-08-01 00:59:36 +08:00
"execution_count": 6,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Instantiate with a 100-dim latent just to inspect the layer stack.\n",
"gen = Generator(z_dim=100)\n",
"gen"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 7,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"def get_noise(n_samples: int, z_dim: int) -> mx.array:\n",
"    \"\"\"Sample standard-normal noise of shape (n_samples, z_dim).\n",
"\n",
"    Args:\n",
"        n_samples: number of noise vectors (rows).\n",
"        z_dim: dimensionality of each noise vector (columns).\n",
"\n",
"    Returns:\n",
"        An mx.array of shape (n_samples, z_dim) drawn from N(0, 1).\n",
"    \"\"\"\n",
"    return mx.random.normal(shape=(n_samples, z_dim))"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 8,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-31 16:25:39 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWNklEQVR4nO3ca2zW9d3H8U9LobRFQCiM80GtdMgKioJYg8VJBIwRmIsYcCw76FTCxsZmZjZM3IxzU4kYNZqFLZ5ZHJODGkCshIMoIudxGEUqZwQpbS1toVz3s29yP7quz+9OvO/ceb8eX+//tdSWz/5PvnmZTCYjAAAk5f9v/w8AAPzfwSgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgFOT6wYceesh+eLt27eymrKzMbiRp165ddjNq1Ci72bFjh90UFxfbzdatW+1Gkr773e/aTWNjo93cfPPNdvPWW2/ZjSQdPnzYbu666y67+c9//mM3Q4YMsZtXXnnFbiRpwIABdnPmzBm7qaystJv6+nq7+eijj+xGSvu36KuvvrKblP9OtbW1diNJ06dPt5uU39c//vGPWT/DmwIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIOR/E69+/v/3w8vJyu3n33XftRpIaGhrspq2tzW66d+9uNx9//PE38j2S9K1vfctutmzZYjeHDh2ymw4dOtiNlHZQMOVnfv3119tNYWGh3XTu3NluJOnkyZN2k3JobenSpXZTUlJiN8OHD7cbSZozZ47djBw50m4GDx5sN3369LEbSdq0aZPd3HrrrUnflQ1vCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACDkfBCvqqrKfvjTTz9tN6WlpXaT6uzZs3aTyWTsZvTo0XbT1NRkN5L0zDPP2M3tt99uNymH9woKcv51+29OnDhhN3v37rWblKNua9assZuysjK7kaRu3brZTcphxZSfd0VFhd2k/K5K0g033GA3AwcOtJtjx47ZTcq/KVLa8dCUI3q5/K3zpgAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACDmfrUy5GDhkyBC7OX78uN1IaZdI6+rq7Gbw4MF209LSYjepV1JTrtmmXN88fPiw3RQVFdmNJHXt2tVuLly4YDdr1661m5TLr9u2bbMbKe3SZ8rvXsr3nDt3zm5Srp1K0m9/+1u7efDBB+1m2rRpdnP99dfbjSRdeeWVdpN6ZTYb3hQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAyPmaV3V1tf3wUaNG2U27du3sRpI+++wzu5k6dardrFixwm5uuukmuzl9+rTdSNKUKVPsZt26dXbT3NxsN3l5eXYjSYMGDbKblENwmUzGbn7/+9/bzbx58+xGki5evGg3KYcBd+/ebTc/+clP7Gbu3Ll2I0nvvPOO3aQcsuzdu7fdNDY22o0kPfXUU3ZTUlKS9F3Z8KYAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAQs4H8S6//HL74Rs3brSbDh062I2UdjRt586ddnPJJZfYzYIFC+xmzpw5diNJZ86csZtNmzbZTcrhvRdeeMFuJOncuXN206tXL7spKMj5zyFceumldrNv3z67kaSOHTvaTcoxxm7dutnN888/bzePPPKI3UjSww8/bDdVVVV2s3z5cru544477EZK+/dr//79Sd+VDW8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIOR8ASzl8NeYMWPspqWlxW4k6eTJk3YzceJEu3niiSfs5tprr7Wb3b
t3240klZaW2k1zc7PddOnSxW7Ky8vtRpJOnTplNyk/h5QDjj/84Q/tpqamxm4kafLkyXbz6KOP2s1zzz1nN3V1dXazZs0au5GkDz/80G5aW1vtJuUA4QcffGA3kjR37ly7uf/++5O+KxveFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEDI+SDenj177IfX19fbTVFRkd1I0qeffmo3e/futZsNGzbYTcr/tq1bt9qNJHXr1s1utm/fbjfV1dV2M3jwYLuRpGHDhtnNokWL7KaystJuVqxYYTedO3e2G0k6ePCg3SxYsMBu3n77bbspKSmxm+LiYruRpC1bttjN7373O7tp37693aT8/UnSn/70J7upra1N+q5seFMAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAIS8TCaTyeWDP//5z+2Hf//737ebzz//3G4kaeTIkXazc+dOuzlx4oTdpFxb/OKLL+xGklpbW+2mqanJbsaPH283K1eutBtJuvvuu+1m8+bNdnPkyBG7ue222+xm/fr1diNJHTt2tJsvv/zSblJ+3ilXaVOunUrS448/bjfvv/++3bS0tNhN6gXcrl272s0///lPu1myZEnWz/CmAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAELOB/E2btxoP/yJJ56wmwkTJtiNlHb4K+UIVYqUQ2YbNmxI+q6U43t9+vT5RpqUY4KS1K5dO7vp1KmT3XTp0sVu3njjDbtJOegmScuXL7ebmTNn2s2dd95pN++8847dVFVV2Y0kzZ49225SDiT++Mc/tpuUI5uS1NzcbDcXL160m1mzZmX9DG8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIBTk+sGFCxfaD29sbLSbhoYGu5GkgQMH2k1ZWZnd1NXV2c3LL79sNylH4CTpmmuusZv+/fvbzdq1a+1m0KBBdiOlHS6sr6+3m2PHjtnNpEmT7CY/P+3/i1VUVNjNX/7yF7t56qmn7Obo0aN2M2fOHLuRpKlTp9rNlVdeaTcjRoywm2XLltmNlPbv3rBhw5K+KxveFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEDI+SBeU1OT/fCnn37abv71r3/ZjSQVFhbazZIlS+ymb9++djNx4kS7ufrqq+1Gkv7whz/YzW9+8xu76dGjh9106dLFbiRp8eLFdjN79my7Wbdund2k/N59/vnndiNJw4cPt5s9e/bYzeuvv243d999t9289tprdiNJjz/+uN1MmTLFbjp37mw3o0ePthsp7Xfv3//+d9J3ZcObAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAg5HwltaioyH74Aw88YDcp1xYl6eWXX7abq666ym5aW1vtprq62m5qa2vtRpImT55sNy+99JLdlJeX203KlU9J2r59u92kXLjs16+f3ZSUlNhNp06d7EZKu6TZ3NxsNyl/gymXVW+66Sa7kaRp06bZzaZNm+xm7ty5dnPvvffajSTdcsstdpPyO54L3hQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAyPkg3tixY/2HF+T8+P+xyspKu+nWrZvd7N27126GDRtmN01NTXYjpf3M29ra7KZ9+/bfSCNJFRUVdvP888/bzYMPPmg3b7/9tt1MmTLFblK/q6yszG42b95sN9dee63drFy50m4kad68eXaT8jf4ox/9yG4OHTpkN1LaAcyjR4/azahRo7J+hjcFAEBgFAAAgVEAAARGAQAQGAUAQG
AUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEHK+nta7d2/74ePGjbObMWPG2I0kvfDCC3Zz5MgRu7niiivs5rP
2024-07-30 00:44:16 +08:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Preview raw Gaussian noise rendered as a 28x28 grayscale image.\n",
"img = get_noise(28, 28)\n",
"fig, ax = plt.subplots()\n",
"ax.imshow(img, cmap='gray')\n",
"ax.axis('off')\n",
"plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 9,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Build one discriminator stage: Linear -> LeakyReLU(0.2) -> Dropout(0.3).\"\"\"\n",
"    layers = [\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.LeakyReLU(negative_slope=0.2),\n",
"        nn.Dropout(0.3),\n",
"    ]\n",
"    return nn.Sequential(*layers)"
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 10,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"    \"\"\"MLP discriminator scoring flattened images.\n",
"\n",
"    Three DisBlock stages narrow the features (4x -> 2x -> 1x hidden_dim), a\n",
"    Linear maps to a single unit, and the trailing Sigmoid turns it into a\n",
"    probability in (0, 1) -- i.e. the output is NOT a logit.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,im_dim:int = 784, hidden_dim:int = 256):\n",
"        super().__init__()\n",
"\n",
"        widths = [im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]\n",
"        self.disc = nn.Sequential(\n",
"            *(DisBlock(a, b) for a, b in zip(widths, widths[1:])),\n",
"            nn.Linear(hidden_dim, 1),\n",
"            nn.Sigmoid(),\n",
"        )\n",
"\n",
"    def __call__(self, noise):\n",
"        # `noise` is the image batch being judged (real or generated); the\n",
"        # parameter name is kept for compatibility with keyword callers.\n",
"        return self.disc(noise)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 11,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=784, output_dims=1024, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=1024, output_dims=512, bias=True)\n",
2024-07-30 07:37:09 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-30 07:37:09 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=512, output_dims=256, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-31 00:50:02 +08:00
" (layers.3): Linear(input_dims=256, output_dims=1, bias=True)\n",
2024-07-30 18:21:38 +08:00
" (layers.4): Sigmoid()\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-08-01 00:59:36 +08:00
"execution_count": 11,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Default sizes: 784-feature input, 256 base hidden width.\n",
"disc = Discriminator(im_dim=784, hidden_dim=256)\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
2024-07-30 00:44:16 +08:00
"cell_type": "markdown",
2024-07-26 21:07:40 +08:00
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### Losses"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"#### Discriminator Loss"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 12,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
"    \"\"\"Discriminator BCE loss: push D(real) toward 1 and D(G(z)) toward 0.\n",
"\n",
"    Args:\n",
"        gen: generator module, called on noise from get_noise(num_images, z_dim).\n",
"        disc: discriminator module. The Discriminator defined in this notebook\n",
"            ends in nn.Sigmoid, so it returns probabilities rather than logits.\n",
"        real: batch of real images with the feature size `disc` expects.\n",
"        num_images: number of fake images to generate.\n",
"        z_dim: noise dimensionality expected by `gen`.\n",
"\n",
"    Returns:\n",
"        Scalar mx.array: average of the fake and real BCE terms.\n",
"    \"\"\"\n",
"    # get_noise already returns an mx.array, so no extra mx.array(...) wrap.\n",
"    fake_images = gen(get_noise(num_images, z_dim))\n",
"    fake_disc = disc(fake_images)\n",
"    fake_labels = mx.zeros((fake_images.shape[0], 1))\n",
"    # with_logits=False: disc already applied a sigmoid. Feeding probabilities\n",
"    # with with_logits=True applies a second sigmoid, confining the loss to a\n",
"    # narrow band and weakening gradients. (Alternative: drop nn.Sigmoid from\n",
"    # Discriminator and keep with_logits=True for better numerics.)\n",
"    fake_loss = mx.mean(\n",
"        nn.losses.binary_cross_entropy(fake_disc, fake_labels, with_logits=False)\n",
"    )\n",
"\n",
"    real_disc = disc(real)\n",
"    real_labels = mx.ones((real.shape[0], 1))\n",
"    real_loss = mx.mean(\n",
"        nn.losses.binary_cross_entropy(real_disc, real_labels, with_logits=False)\n",
"    )\n",
"\n",
"    return (fake_loss + real_loss) / 2.0"
]
},
2024-07-30 00:44:16 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generator Loss"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 13,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def gen_loss(gen, disc, num_images, z_dim):\n",
"    \"\"\"Generator BCE loss: reward fakes that the discriminator scores as real.\n",
"\n",
"    Uses the non-saturating form -log D(G(z)), i.e. fake samples labelled 1.\n",
"\n",
"    Args:\n",
"        gen: generator module, called on noise from get_noise(num_images, z_dim).\n",
"        disc: discriminator module. The Discriminator defined in this notebook\n",
"            ends in nn.Sigmoid, so it returns probabilities rather than logits.\n",
"        num_images: number of fake images to generate.\n",
"        z_dim: noise dimensionality expected by `gen`.\n",
"\n",
"    Returns:\n",
"        Scalar mx.array: mean BCE of D(G(z)) against all-ones labels.\n",
"    \"\"\"\n",
"    fake_images = gen(get_noise(num_images, z_dim))\n",
"    fake_disc = disc(fake_images)\n",
"    fake_labels = mx.ones((fake_images.shape[0], 1))\n",
"    # with_logits=False: disc already applied a sigmoid; with_logits=True would\n",
"    # apply it twice and cap this loss near log(2) (~0.69).\n",
"    return mx.mean(\n",
"        nn.losses.binary_cross_entropy(fake_disc, fake_labels, with_logits=False)\n",
"    )"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 14,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"# Keep the training split (images + labels); the remaining returned arrays\n",
"# (presumably the test split) are discarded.\n",
"train_images, train_labels, *_ = (np.array(part) for part in mnist.mnist())"
2024-07-30 00:44:16 +08:00
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 15,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# Rescale pixels to [-1, 1] to match the generator's tanh output\n",
"# (assumes mnist() yields pixels in [0, 1]).\n",
"# NOTE: not idempotent -- re-running this cell rescales a second time.\n",
"train_images = 2.0 * train_images - 1.0"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 16,
2024-07-29 06:24:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-08-01 00:59:36 +08:00
"<matplotlib.image.AxesImage at 0x117af0490>"
2024-07-29 06:24:50 +08:00
]
},
2024-08-01 00:59:36 +08:00
"execution_count": 16,
2024-07-29 06:24:50 +08:00
"metadata": {},
"output_type": "execute_result"
2024-07-30 00:44:16 +08:00
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaI0lEQVR4nO3df2jU9x3H8dfVH1d1lytBk7vUmGVF202dpWrVYP3R1cxApf4oWMtGZEPa+YOJ/cGsDNNBjdgpRdI6V0amW239Y9a6KdUMTXRkijpdRYtYjDOdCcFM72LUSMxnf4hHz1j1e975vkueD/iCufu+vY/ffuvTby75xueccwIAwMBD1gsAAHRfRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjpab2AW3V0dOjcuXMKBALy+XzWywEAeOScU0tLi/Ly8vTQQ3e+1km7CJ07d075+fnWywAA3Kf6+noNHDjwjvuk3afjAoGA9RIAAElwL3+fpyxCH3zwgQoLC/Xwww9r5MiR2rdv3z3N8Sk4AOga7uXv85REaPPmzVq8eLGWLVumI0eO6JlnnlFJSYnOnj2bipcDAGQoXyruoj1mzBg99dRTWrduXeyx73//+5o+fbrKy8vvOBuNRhUMBpO9JADAAxaJRJSVlXXHfZJ+JXTt2jUdPnxYxcXFcY8XFxertra20/5tbW2KRqNxGwCge0h6hM6fP6/r168rNzc37vHc3Fw1NjZ22r+8vFzBYDC28ZVxANB9pOwLE259Q8o5d9s3qZYuXapIJBLb6uvrU7UkAECaSfr3CfXv3189evTodNXT1NTU6epIkvx+v/x+f7KXAQDIAEm/Eurdu7dGjhypqqqquMerqqpUVFSU7JcDAGSwlNwxYcmSJfrpT3+qUaNGady4cfr973+vs2fP6tVXX03FywEAMlRKIjR79mw1NzfrN7/5jRoaGjRs2DDt2LFDBQUFqXg5AECGSsn3Cd0Pvk8IALoGk+8TAgDgXhEhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmelovAEgnPXr08DwTDAZTsJLkWLhwYUJzffv29Tzz+OOPe55ZsGCB55nf/va3nmfmzJnjeUaSrl696nlm5cqVnmfefvttzzNdBVdCAAAzRAgAYCbpESorK5PP54vbQqFQsl8GANAFpOQ9oaFDh+rvf/977ONEPs8OAOj6UhKhnj17cvUDALirlLwndOrUKeXl5amwsFAvvfSSTp8+/a37trW1KRqNxm0AgO4h6REaM2aMNm7cqJ07d+rDDz9UY2OjioqK1NzcfNv9y8vLFQwGY1t+fn6ylwQASFNJj1BJSYlmzZql4cOH67nnntP27dslSRs2bLjt/kuXLlUkEolt9fX1yV4SACBNpfybVfv166fhw4fr1KlTt33e7/fL7/enehkAgDSU8u8Tamtr05dffqlwOJzqlwIAZJikR+j1119XTU2N6urqdODAAb344ouKRqMqLS1N9ksBADJc0j8d9/XXX2vOnDk6f/68BgwYoLFjx2r//v0qKChI9ksBADJc0iP0ySefJPu3RJoaNGiQ55nevXt7nikqKvI8M378eM8zkvTII494npk1a1ZCr9XVfP31155n1q5d63lmxowZnmdaWlo8z0jSv//9b88zNTU1Cb1Wd8W94wAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQ
DMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAMz7nnLNexDdFo1EFg0HrZXQrTz75ZEJzu3fv9jzDf9vM0NHR4XnmZz/7meeZS5cueZ5JRENDQ0JzFy5c8Dxz8uTJhF6rK4pEIsrKyrrjPlwJAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwExP6wXA3tmzZxOaa25u9jzDXbRvOHDggOeZixcvep6ZPHmy5xlJunbtmueZP/3pTwm9Fro3roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT63//+l9DcG2+84Xnm+eef9zxz5MgRzzNr1671PJOoo0ePep6ZMmWK55nW1lbPM0OHDvU8I0m//OUvE5oDvOJKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAw43POOetFfFM0GlUwGLReBlIkKyvL80xLS4vnmfXr13uekaSf//znnmd+8pOfeJ75+OOPPc8AmSYSidz1/3muhAAAZogQAMCM5wjt3btX06ZNU15ennw+n7Zu3Rr3vHNOZWVlysvLU58+fTRp0iQdP348WesFAHQhniPU2tqqESNGqKKi4rbPr1q1SmvWrFFFRYUOHjyoUCikKVOmJPR5fQBA1+b5J6uWlJSopKTkts855/Tee+9p2bJlmjlzpiRpw4YNys3N1aZNm/TKK6/c32oBAF1KUt8TqqurU2Njo4qLi2OP+f1+TZw4UbW1tbedaWtrUzQajdsAAN1DUiPU2NgoScrNzY17PDc3N/bcrcrLyxUMBmNbfn5+MpcEAEhjKfnqOJ/PF/exc67TYzctXbpUkUgkttXX16diSQCANOT5PaE7CYVCkm5cEYXD4djjTU1Nna6ObvL7/fL7/clcBgAgQyT1SqiwsFChUEhVVVWxx65du6aamhoVFRUl86UAAF2A5yuhS5cu6auvvop9XFdXp6NHjyo7O1uDBg3S4sWLtWLFCg0ePFiDBw/WihUr1LdvX7388stJXTgAIPN5jtChQ4c0efLk2MdLliyRJJWWluqPf/yj3nzzTV25ckXz58/XhQsXNGbMGO3atUuBQCB5qwYAdAncwBRd0rvvvpvQ3M1/VHlRU1Pjeea5557zPNPR0eF5BrDEDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNrqkfv36JTT317/+1fPMxIkTPc+UlJR4ntm1a5fnGcASd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzAFvuGxxx7zPPOvf/3L88zFixc9z+zZs8fzzKFDhzzPSNL777/veSbN/ipBGuAGpgCAtEaEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpsB9mjFjhueZyspKzzOBQMDzTKLeeustzzMbN270PNPQ0OB5BpmDG5gCANIaEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gCBoYNG+Z5Zs2aNZ5nfvSjH3meSdT69es9z7zzzjueZ/773/96noENbmAKAEhrRAgAYMZzhPbu3atp06YpLy9PPp9PW7dujXt+7ty58vl8cdvYsWOTtV4AQBfiOUKtra0aMWKEKioqvnWfqVOnqqGhIbbt2LHjvhYJAOiaenodKCkpUUlJyR338fv9CoVCCS8KANA9pOQ9oerqauXk5GjIkCGaN2+empqavnXftrY2RaPRuA0A0D0kPUIlJSX66KOPtHv3bq1evVoHDx7Us88+q7
a2ttvuX15ermAwGNvy8/OTvSQAQJry/Om4u5k9e3bs18OGDdOoUaNUUFCg7du3a+bMmZ32X7p0qZYsWRL7OBqNEiI
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
2024-07-29 06:24:50 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"# Sanity check: the first training digit after rescaling.\n",
"plt.imshow(train_images[0].reshape((28, 28)), cmap='gray')"
2024-07-29 06:24:50 +08:00
]
},
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 17,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"def batch_iterate(batch_size: int, ipt: np.ndarray) -> Iterator[np.ndarray]:\n",
"    \"\"\"Yield shuffled mini-batches of `ipt` along its first axis.\n",
"\n",
"    A fresh permutation (from NumPy's global RNG) is drawn on every call, so\n",
"    each pass visits every row exactly once in random order. The last batch\n",
"    is smaller when len(ipt) is not a multiple of batch_size.\n",
"\n",
"    Args:\n",
"        batch_size: number of rows per batch; must be positive.\n",
"        ipt: array supporting fancy indexing with an integer array (e.g. a\n",
"            NumPy array) -- a plain list would not work here.\n",
"\n",
"    Yields:\n",
"        Arrays of shape (<= batch_size, *ipt.shape[1:]).\n",
"\n",
"    Raises:\n",
"        ValueError: if batch_size is not positive (raised on first iteration).\n",
"    \"\"\"\n",
"    if batch_size <= 0:\n",
"        raise ValueError(f'batch_size must be positive, got {batch_size}')\n",
"    perm = np.random.permutation(len(ipt))\n",
"    for start in range(0, len(ipt), batch_size):\n",
"        yield ipt[perm[start : start + batch_size]]"
]
},
{
"cell_type": "code",
2024-08-01 01:23:57 +08:00
"execution_count": 43,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-08-01 01:23:57 +08:00
"def show_images(epoch_num: int, imgs, num_imgs: int = 25) -> None:\n",
"    \"\"\"Plot up to `num_imgs` digits in a square grid, save it, then show it.\n",
"\n",
"    The figure is written to gen_images/img_<epoch_num>.png (the directory is\n",
"    created if missing). Nothing is drawn when there is nothing to show.\n",
"\n",
"    Args:\n",
"        epoch_num: epoch index used in the output filename.\n",
"        imgs: batch of flattened 28x28 images indexable along axis 0\n",
"            (mx.array or np.ndarray), e.g. generator outputs.\n",
"        num_imgs: maximum number of images to display. The grid side is\n",
"            ceil(sqrt(num_imgs)), so the default 25 gives a 5x5 grid.\n",
"    \"\"\"\n",
"    if imgs.shape[0] <= 0 or num_imgs <= 0:\n",
"        return\n",
"\n",
"    n_show = min(num_imgs, imgs.shape[0])\n",
"    side = math.ceil(math.sqrt(num_imgs))\n",
"    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.\n",
"    fig, axes = plt.subplots(side, side, figsize=(side, side), squeeze=False)\n",
"\n",
"    for i, ax in enumerate(axes.flat):\n",
"        # Short batches leave trailing cells empty instead of raising IndexError.\n",
"        if i < n_show:\n",
"            ax.imshow(mx.array(imgs[i]).reshape(28, 28), cmap='gray')\n",
"        ax.axis('off')\n",
"    plt.tight_layout()\n",
"    os.makedirs('gen_images', exist_ok=True)\n",
"    plt.savefig('gen_images/img_{}.png'.format(epoch_num))\n",
"    plt.show()"
2024-07-26 21:07:40 +08:00
]
},
2024-07-30 07:06:52 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training Cycle"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-08-01 01:23:57 +08:00
"execution_count": 39,
2024-07-27 06:09:51 +08:00
"metadata": {},
2024-07-29 06:24:50 +08:00
"outputs": [],
2024-07-27 06:09:51 +08:00
"source": [
2024-07-31 16:25:39 +08:00
"# Hyperparameters shared by both networks\n",
"lr = 2e-5  # Adam learning rate\n",
"z_dim = 128  # latent noise dimensionality fed to the generator\n",
"\n",
"# Generator and its optimizer (mx.eval materialises the lazy parameters now)\n",
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
"gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])\n",
"\n",
"# Discriminator and its optimizer, with identical Adam settings\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
"disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])"
2024-07-29 06:30:08 +08:00
]
},
2024-07-27 06:09:51 +08:00
{
"cell_type": "code",
2024-08-01 00:59:36 +08:00
"execution_count": 21,
2024-07-27 06:09:51 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-31 16:25:39 +08:00
" 0%| | 0/500 [00:00<?, ?it/s]"
2024-07-30 18:24:53 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-08-01 00:59:36 +08:00
"Epoch: 0, iteration: 468, Discriminator Loss:array(0.533901, dtype=float32), Generator Loss: array(0.672384, dtype=float32)\n"
2024-07-30 18:24:53 +08:00
]
},
{
"data": {
2024-08-01 00:59:36 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9Z2+cWXomjl+Vc44sspiDJEqt2CN3u8fjmdn1OMAGNrxYvzCw/hD7QRb7AfaFFwusvVhgvU5jeFK3OklqZVHMZJGsnHN+6veCvm6e4vSsqWIv/sAfPADRLYksVp3nnDtc13Xft240Go1wta7W1bpaV+tqfYdL///rN3C1rtbVulpX6///1pVzuVpX62pdrav1na8r53K1rtbVulpX6ztfV87lal2tq3W1rtZ3vq6cy9W6Wlfral2t73xdOZerdbWu1tW6Wt/5unIuV+tqXa2rdbW+83XlXK7W1bpaV+tqfefryrlcrat1ta7W1frOl/Gi32i1WqHX6zE1NYVgMIj19XX8yZ/8CYbDIf7yL/8SGxsbaLVaaLVaCAaD+P73vw+Xy4WtrS2cnJygXC4jlUphMBhAr9dDp9PBbrfDbrcDAEajEUajETqdDvr9PgwGA8xmM4xGI1wuF0wmE0qlEkqlEgwGA6xWK6xWK27fvo25uTns7u7i6dOnMBgMiMfjsNvtSCaTyGQyMJvNcLvd0Ov1aLfb6Pf7MJlMMJvNsNvtmJ+fh81mQyqVQjabRbfbRaPRgKZpMBgM0Ov1MBqNMJlM0Ol0yGaz773RH374oeyf3+9HKBTCysoKOp0Ofvazn+Hg4AAmkwkWiwUOhwPxeBxWqxX9fh/9fh+ZTAYbGxvQNA2BQED2zuFwYDQaod/vYzAY4OjoCIVCATabDW63G3a7HcvLy3C5XNjd3cXBwQH0er189k8++QQrKyt4+vQp/vZv/xYejwf//t//e8zMzODnP/85vvjiC+j1elgsFhiNRng8HthsNjidTng8HlitVgSDQZhMJmxsbGB3dxeNRkP2KBQKweFwyD4CwOvXr997/27dugWDwYDr169jYWEBLpcL0WgU7XYb//RP/ySfS6/Xw+VyYXV1Vc6WXq/H/v4+vvrqK3Q6Hfm+69ev49atW6jX69je3kaj0UC1WkWz2YTNZoPH44Hb7caDBw8QDAbx9ddf48mTJ9Dr9bBarXA6nfjJT36Cmzdv4uuvv8b/+T//B4FAAH/+53+OhYUF/NVf/RX+5m/+BhaLBYFAAAaDAZqmAQCCwSBisZjchdFohM3NTRwcHGA4HGIwGECn08HpdMJiscBiscBms028fwDwu7/7u9DpdHA4HLDZbIjFYrh79y56vR7+/u//Hru7u3A6nfB6vXC5XFhYWIDNZoPRaIRer8f29jZ+9atfod1uYzgcAgDm5+exsLCAdruNk5MTdDod1Go1tFotOJ1OBAIBeDwe3L9/H6FQCE+ePMGzZ88wGAzQ7/dht9vxp3/6p3j48CEePXqE//E//gc8Hg/+7M/+DLOzs/hf/+t/4W//9m9hNpvh9XphMBig0+kAANPT01hcXMRoNEK9Xke328Xu7i6Oj4/lM/NZmUwmmEwmWK1WAMDm5uZ779/du3dhNBpx7do1zM3Nod/vo9lsotPpYH9/H9VqFa1WC81mEwaDQe7K7du3sbS0hGaziWKxiGaziYODA1QqFQSDQQSDQfT7fdTrdfT7fbGFzWYT5XIZFosFa2tr8Hq9yGQySKfTsFgs8Pl8MJlMMBqNMBqNSCQSePv2LVwuF374wx8iEolgY2MD29vbck8NBgOazSYGgwHcbjd8Ph88Hg9u3rwJl8uFt2/fYnd3F+VyGYeHhxiNRojFYvB4PAgGg4hGo9DpdPgv/+W//Iv79V7ORafTod1uI5fLYTAYoF6vAwBOTk7ksHQ6HRSLRTx79gwWi0UOWqfTwWg0gslkgs/ng9VqRavVktfQ6/UwGAyIRCLwer
3yd71eD4VCAbVaDcPhcMyYapqG4+NjNJtN5HI5+btGoyHvRdM0DIdDdLtdGAwG2O12GI1nH3s0GuHo6Ag6nQ4ulwsrKyvyswaDAV6vF1arFScnJ3LxJ1l0zryo+XweqVQKvV4PmUxG3h8ADIdDtFot9Pt95PN5lMtleT82mw0LCwsIBAJIp9NIJBLiLADAYrHIgZ2dnYVer0ez2US9XsdoNEI4HMZgMJBAYHt7G6VSCfl8HuFwGHa7HeVy+fRwGI2Ym5uTCwoAlUoFhUIBg8EARqMR7XYb1WpVntfq6ira7Tai0Sg0TZODX6vVUKlUMGm3ocFggNFohHw+j+FwCJ1OB51Oh16vh2QyOeY0RqOROOtWq4VerycX3u/349atWwgGgygWi3j9+jV6vR7a7TYAwGQywW63w+v1IhKJwGg04vDwEIlEAsViEU6nE5qmQdM0tFotvHnzBvl8HsfHx9A0De12Gy9evMDx8TGKxSICgQAsFouc6UKhgGazKQ5bp9NhOBxiNBrB4/Hg9u3bcl55dofDodyry3RrKhQK0Ol0GAwG0DQN+/v7OD4+xmAwQCKRQL1eR7vdRrlchs1mQ6PRkABnOByiXC5D0zQ4nU7cunULgUAAqVQKW1tb0DRNnhHPgt1uRzgchtlsxuHhIY6Pj5FOpwGc3juewydPnshZ7vV6qNfrePToEbxeL7LZLKLRKEwmkzjXcrmMVqsFs9kMi8WC0WiEdruN0WiEaDSKmZkZ2TdN0yTwajQalzqDU1NTEgxUKhUJco1Go9wdq9UKm80mn4379ubNGzlbdrsds7OziEajMJvNMJvN6PV6ElRz2e12hEIhCcyazSbcbjc8Hg90Oh1MJhNGoxFqtRqq1SrMZjMWFhbkrjabTTnvg8EAw+FQbBv3T6/Xo9Pp4M2bN9DpdCiVSmg0GhgMBrLf3Lt2uy3P7yLrws6FD7HT6YgH3t7eHjPAg8EA3W4XvV4Pm5ubYvRMJhN6vR4AyPe7XC6k02kUCgX5e7PZjGAwiLm5OWiahtFohEajgUwmg0qlIg+m1+tJhsN/a7fbcrhpmHu9HkajEYbDIXq9nmRBTqcTg8EAg8EAvV4P6XQag8EAH3zwAeLxuBx8o9GIWCwGp9MJANjd3UW/37/w5p7fP51OJ5FXqVTCzs6O7AsAcVyapsnny2azSKVSktVYrVZMT09jamoKxWIR2WwWBoMBTqdT9tpms2Fqagpra2vodrt48eIFisUibDYbfD4f2u22RC9HR0dIp9Pi9M1ms0RQBoNBLpTVasVwOESlUkGtVhNHx6hxOBxiYWEBs7Oz6Ha7CAaDch540ZgNTrp4UZvNphhBOjBeep1Oh36/j263C+DUGdbrdcnUPB4P7t27h4WFBfz85z/H48ePJUMwGAwwmUzQ6/VwOp0IhUIYDAbY39+Xz8xzzv3b29vD8fGxBE/dbhdbW1uw2WzodDrwer0wm83ilPr9Pmq1GoxGo2RWPAPxeBzRaBSDwQCdTgeDwUDOdqfTQaVSudT+VatV6HQ6cfiVSkUCQ55LZk0WiwXtdhsmkwmNRgPdbhc6nQ56vR4OhwO3b9/G4uIi/uEf/gFffPGF3Gv1OVitVvj9fgkCG43Gr533wWAg0TKz9OFwiJcvX8oeBYNBsSWDwQD5fB6NRkPQB+DUAALAysoK5ufnxSH3+33Zw263i3q9PvEehkIh+f96vQ6bzQaHwwGTySQ2MBwOY2pqCs1mE0dHR2i1WigWi0gmkwgEApibm4PZbJYMgEG5wWBAu92W/QNOnYvb7YamaSgWi2i32wgGg4hEItA0Db1eD91uF+VyGY1GA2azGfF4HEajURyux+NBNBpFs9lENpuVM9hoNATN6fV6ODk5kWfMZ0MnRVvZbrfRaDQu7Jwv7Fx4WWkgeRF5IFutFrrdLobDIUwmE5xOp/zXarWKwdLpdPIh+Ge+LgC0Wi1UKhWJ8mnkCUO43W70+30YjUZomga73S4emMaZkZ
lOp4PNZpPDzg0fDAZwOBzwer3odDoolUqS4ubzeXlwfEh0LgsLCxNHPdVqVSLrfr+PVqsln5EXivvC32E0GjEzM4N
2024-07-30 18:24:53 +08:00
"text/plain": [
2024-08-01 00:59:36 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-31 01:01:14 +08:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-08-01 00:59:36 +08:00
" 20%|██ | 100/500 [09:40<40:52, 6.13s/it]"
2024-07-31 01:01:14 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-08-01 00:59:36 +08:00
"Epoch: 100, iteration: 468, Discriminator Loss:array(0.546182, dtype=float32), Generator Loss: array(0.691693, dtype=float32)\n"
2024-07-31 01:01:14 +08:00
]
},
{
"data": {
2024-08-01 00:59:36 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9WZNc2XUdvHKex8qqypqrgAIaDTTQ6IFEs9kcJUu2ZEqW5ZBNR8gOhewHP/nBD/4FfrbDDj84QuEIOxyyKZMaTLElkRTZZJPdzR7RjRkF1DznPM+Z30N9a9fOi5tVWZWJZvv7akcgAFRl3nvuuefsYe2197F0Op0OzuRMzuRMzuRMhijWX/YAzuRMzuRMzuT/e3JmXM7kTM7kTM5k6HJmXM7kTM7kTM5k6HJmXM7kTM7kTM5k6HJmXM7kTM7kTM5k6HJmXM7kTM7kTM5k6HJmXM7kTM7kTM5k6HJmXM7kTM7kTM5k6HJmXM7kTM7kTM5k6GLv94Nf//rX0W63sby8jM3NTejCfovFAovF0vWzTqcDi8WCYDAIr9eLWq2GUqmEdrstv3c6nXA6nWg2m6hUKmi323A4HLDZbGi1Wmg0GgAAq9Uq1+f3eV+XywW73Y5ms4larSb35e8BwGazweVyAQBqtRparRZsNhvsdjvsdjt8Pp/87Xa7USwWkUwm0Wg0UKvV0Gw2YbVaYbUe2GKO60QTbT+Y6na7DbOmCMb5M/6Of466hn5mAD0/oz/HZzLOLQC43W64XC60Wi1UKhWZW4vFgna7LZ/X77+fexrv04/o5zrJ5/tpQGGxWGSNtVqtvr4ziBz1ro3CcfF7wOnWn/6+2c9O+sw2m+2JdXBasdlssFqtaLfbaLVamJubw7/5N/8Gly5dwk9+8hO89dZbSKVSePjwIarVKpxOJ2w2G5rNpszFUWuQc6j3zWnesdvthsVigc/ng8fjER3F63Y6HcRiMYyPj6NSqWBjYwPValWeLxaL4fz58+h0Otjc3EShUJDfud1ujIyMwGKxYHt7G6lUCs1mE/V6vesZ+KzUV9Rtdrsd4+PjOH/+POr1Oh4/foxCoYBWq4VWq4V2uy3XstvtsFqtaDab8qdUKqHVasHj8cDlcom+43sBAJfLBbfbDQC4efPmsfPVt3F5/PgxOp0O8vk87Hb7Ey+x0+nIRHU6HdmktVpNFo1+CfwuH0ArLSoxp9MJAF0bXitDq9UKv98Pn8+HYrGIer3+xJiAA2WmFyHHCQBOpxMTExPwer0oFAqm1wEONgDHcxrhC9Ibqdlsypj4t9miN26IozZGP8rd6XTC7/fDarWiVquh0Wig1WrJ4qM0Gg0xOkYFYjbP/NtoCPkzzvmnIXosxp+ZffY4w3haMVPeJ7mPHtdJDaxR6OBwz9HB6nQ6KJfLsh6PE6vVKk5go9F4Yt2cVLReAIBMJoM/+7M/QzgcxtbWFnZ2dlCpVGQPU5foNXnU3mi327ImaRRPI06nUxQ7lXO1WoXdbsfU1BQCgQCq1aoYBrvdDq/XK06ax+NBoVBAu91GuVxGrVZDKBRCJBJBs9lEIpFAvV5HLpdDtVoFgC7daLVaEYlEEA6HAUD2balUQq1WAwBxmPm5jY0N7O3tdenPUCgEp9PZZUA4T7w/3zH1sM1mE0e7X+nbuGxsbMjgbTZb14bUA7fb7V0vvtFoPOFpGaMPM2Vkt9vhcDi6lBsnWN/P6/UiGAyaPrQ2Lvw9Fwbv6XA4EIvFEAwGUa/XkU6nexoXTvZpxDhH9Cj6VTTDUnxcNKFQCFarFYVCoef1OcZ+xmLmLfLnOuL5ZUg/kcIgHu1R9z1JlGImpzVKZmKz2QAc7mGHwwG32y1KpR/FofcN9/qgYjTs+XweP/7xj3t+3mxN8jpH3YN777TC/a8joXq9DovFglgshng8LsocONQ1gUAAXq8X7XYb1WoVjUYD1WoV9XodDocD4XAYhUIBW1tbYnSazaY4tNq4hE
IhzM7Ool6vI5vNolqtIpPJIJvNAoA43AsLC3C73dja2kIul5Pf2e12BIPBrnmgk8HxlUol0RM6YjqJYQFOYFyoINxuNzweD+r1Osrl8hOKw6gwbTab/KGxKJVKYnX5dy+oyGq1wuVyPWFkGAHUajXkcjnU63WBcJrNZtei5zXordHroNeVTqdRq9WQz+dRrVZhsVjg8XjQ6XRk82kDNYjoKO4oMfN4Twph9PKam80myuVyV+RyUiVxUkWnjcynLf2M9WlELZ+1nrCtVqsL+tN/jA6f1WoVxWY0PFxDXMu9REcKdrsdFovF1Nk8Sqg7eM9eeoKKkPrETIionFYajYbAYn6/HxaLBaVSCRaLBZlMRnSbw+HogqE8Ho/oEP4sHo+LA5ZOp1GtVkVPEX6z2+2i++jw+v1+BAIBAEAoFEK9XkehUEA2mxWITCMsGqXh3+VyGa1Wq+vnNIR+v19QKL4zjttut58IvTmRGbdYLAiFQojFYshms4LT8cXSWGhLy7yKx+NBOBxGs9nE1taWLDIuWu058t8aj7RarahUKhKGer1eAEChUEAymYTL5YLP50O73UY+n5eNxIUXCAQkZLRarSgWi6jVaiiXy1hfX4fdbke5XEalUoHP58PY2Jh8FgBSqRT29/cH9tSazeYTxrTXhtG/1+FrPxGPhqH0nHKhpNNpuVY/z2R2rZPKL0PZ9jNP/Xzuad3/JDKoYeZeM8KxwJMwosPhQCQSgc1mQyaTQbFYlM91Oh2BbY56PpvNBrfbDZvNJnnNXC4nnrSZaIgGgMBJjUZD9I3Zfbi/i8UiSqWS6bUHzQ9VKhXYbDZ4vV6MjY0hk8mgUCig0WhgbW0Na2trovwZWTSbTXg8Hvh8PnFgvV4vXnjhBYyPj+OTTz7BRx99JI67y+VCOByG1+uVZ282m0gmk6hWqxgbGxMYf2RkBPV6HXt7e9ja2oLT6UQgEIDH45HIhPqPjl2n00E6nUar1YLb7YbP54PL5UIoFILD4eiCzB0OB1qtFvb29lAsFhEMBuHxePqer76Ni1ZuR0EIZrkUPqQ2Ov3ABYxSet3HmMDj+PSmMW4go1I3XlsrG3p0hOKGkbzktY2in1MbBiMe3UsIeRyV7Nf310b9JIaqn+sfpQQ/S968jriN7934OYox+upnLj5Lzwz0DwHqvap/1uu7RsPQ676nGS+9euZPjUbGqG/MxsooaBBo27j+NWRE/aBzz0Q8NEJAx5e513q9LpECHXU9TkYxbrf7CVifUNn4+DgikQg8Hg/cbrd8lkgTo8hO5yBvzijQTH8a586YM+93/vo2Lj6fD8BBpKBhLeBwMfFvHW0wpNOTCRwkx/gyjBuaoXulUoHFYhGoSiv8ZrMpkx4MBrsS5UavvdFoIJfLwWI5YJcxOdVqteByuTAxMQGfz4dUKoVsNotOp4NkMinfB4BqtdqXYj2NGNlwVqsVHo9HojWG0pppp8XhcCAYDEpERmZXr8+fNPqw2WyyQMnq02JcbPR6eK9hMrCMzKJBhZvODAIyJoF5T25+AKjX6z3HobFsDUWZjQE4Xvl+WkaKa5HPxufjfgPwBESl1wjzCq1WS/ZurVaDxWI5Flo2ri3mHwKBAM6dOwer1Yr19XXJMXBcxvEbx9putxEOhzE+Pj4wsaRQKGB/fx+lUkmeh+vH5XIJlMXcajabRaFQEJir2WxieXkZiUQCOzs7aDabcDqdCIfDcDqdQgrg3nE6nRgdHRWDkc/nkcvlsLW1BZvNhldeeQV/7+/9PRSLRWQyGdhsNsRiMYlYCJVFo1G0Wi3cvn0bu7u7sj5JzOh0OigWixKlcs0TRqvVaigWi8M3LhwoE1FGDN1oHOgl6NwKqcL0RIDe0Ym28ITajHRgi8UCv98v9LhewkQV72O322UymXAj06NSqaBWq6FSqXRZ66dlWDhfeg40G4eG5ah705jbbDZUKhX5+bDGy9CaSsJs7Maf0bE4iZff71g0PDjotXg9Ph/fuY
606U3SodHGsxfpAUCXB9rL0H9WRec4jIw/jRpQdAJYr1kq39PSpzm/FosF0WgUVqtVEubG8WrR64Rr0O12C9R3GuF
2024-07-31 01:01:14 +08:00
"text/plain": [
2024-08-01 00:59:36 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-31 01:01:14 +08:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-08-01 00:59:36 +08:00
" 40%|████ | 200/500 [19:38<29:59, 6.00s/it]"
2024-07-31 16:50:32 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-08-01 00:59:36 +08:00
"Epoch: 200, iteration: 468, Discriminator Loss:array(0.556852, dtype=float32), Generator Loss: array(0.686229, dtype=float32)\n"
2024-07-31 16:50:32 +08:00
]
},
{
"data": {
2024-08-01 00:59:36 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz96ZNc53Ufjn963/dt9n0ADAiQAMGdtChRIi2Jshw5SiwnFacqLsdJKpU/Im+SypukKi/8zknZrpKdOF+VHcu0I9MGRZEEQRIEQKwzmH3pnt73vfv+XuB3Dp6+c3uml9szA7k/VSgAs9y+97nPc9bPOUcjSZKEIYYYYoghhlAR2pO+gSGGGGKIIX75MFQuQwwxxBBDqI6hchliiCGGGEJ1DJXLEEMMMcQQqmOoXIYYYoghhlAdQ+UyxBBDDDGE6hgqlyGGGGKIIVTHULkMMcQQQwyhOobKZYghhhhiCNWh7/QHNRrNIO9DEVrtY93XbDbbfj8YDMJutyObzSIej0OSJOh0Omg0GjSbTTSbTYhNCDQaDTQaDSRJgrw5gcFggMFgQKPRQLVahSRJij/fS1ODXtdPq9Xy57dbh5MArQvQ/v20w3Gu32mB/P77aYzR6+8aDAYA4HOh0Wj4jCmdB4LJZILBYECtVkOlUml7LkRoNBro9XoYjUZIkoRqtYpGo3HgZ3U6HcbHx+F2u5FMJhEOh9FsNg89w4SZmRksLS2hWCzi1q1bSKfTivdFay9e4x/aHqT3AQCNRqNvWdLJ+nWsXE4CRz2AJEnIZrMol8uo1WrQaDSwWCyYmZmBzWbDzs4OwuGw4jWVrm2z2WC321Eul5FKpVoOw0l0ydFqtRgfH4fP50M6ncbOzg7q9XrLz9BBOm4MuwZ1h9OwXqIiob+VhIzcGGs2m6jX6y1KhQQtXRN4YmSQQQQA1WqVrykKfvGz0uk0yuUyyuUy/z79EX9OrizS6TSWl5dRq9X4d+lntFotK9NarXaihpnBYIBer2ejVYROp4NOp4MkSbzGakJ8F41GA0D3xmCveOqVS7FYRLFYZCvMbDZjZmYGXq8XxWIRe3t7HV1Xo9HAbDbD5XJBq9UinU53fB+DglarRSAQwPz8PLa3txWfBRgqmCE6AwlrEjJAZ2eMvBxJkloEPl1Tq9W2XIe+1mg02BgSFYvck8jn88jn8/x1Ost0XfodEop0Lfo9uZKk+zQajQCeeGonBfLgarUaarXagbWiaInccFQDGo0GOp0OzWaz5b0fB05UuZjNZpjNZjQaDRQKhSM3gOjG08/KraBqtYpIJIJ8Po9sNtvV/ZTLZWSzWZRKpWMRnEeF/ZrNJtLpNHZ3d5FMJk9VWGyIpw8kXGhv63S6A9a90r4n4S0qGlFByK8rnk25t2Gz2WCz2VCtVpHL5dBoNNqGng8L1cmvT58h3gMJ65M+N41GA7VaTTEsSPfZ6T3Kle5RkRXx3R03NJ12RR5EvDEUCmFkZASFQgE7Ozvs2so/k26RciKim66UUzEajdBqtajVaodaA/Lr04s7ytJRI16r1Wqh1+uPdIf1ej271LVa7cC1nkbv4R9avFtt9PrO5WExi8UCl8sFAMhkMpxPOer6ZOTJc4FyRSL/mkajwdTUFGZmZpBOp/Ho0SOUy2VFT8dgMBy4vjzsJlrj8rMsV4BiWA5AT1Z8P3tQKe8jfq/Td2o0GmEwGFhuiDmp45QFpz7n0m+SvN01K5VKT7973O5zJ89cr9fbKkjxINGBo78NBgPMZnOLt0ffq1aryGaznNh7GhXUEN1DboSpDfGaYm5GDHeRsaTX63k/6nS6A6E1Jblw1D3LlYn8WieJduQCOp9HEXZojcjzq9frKBQKnMcZZMjrMMV4GAamXDq5IbKWRKtchPx3SVOrpZROctOJQr3b+5D/vNFohMfjgcFggMVigdFoxOzsLK5cuQKTyQSz2Qy9Xg+LxQKr1YqVlRX86Z/+KWKxGHK53A
GPUW2cVE6oE/R6cI4LopWu1j2SoGo2m8hmsxxOPky4kZIQQ7nkDeh0Ov4Z8WfNZjPsdnuL4jCZTMjn8yiXy9Dr9RwaNxqNKBaLHDqS51B0Oh3nUOr1+gHGUzuygJKCOynQM4jrbTabYbFYWpSFHAaDAT6fD2azGefOncOZM2eQSqVw79495HI5RKPRlhSAmnuZDAKdTneooauEE/VcKpVKV16GGN/9ZYCawsJms/FhNpvNmJ2dxYsvvsiWjsFggMPhgMPhgNvtxs9+9jMUCgWUSiVV7qEdTnM4S0kgnTaIwlGt+xStZaVQmJJ1Ld4LgYS36BnT3xqNBiaTCXa7veV3dDod53e0Wi10Oh1MJhOMRiPq9TorPSUDkujJSmHkduskkhDk93/cIEEtSRIb0zqdDmazGdVqFcViUfF36Hw7HA5MTExgaWkJkUgE+/v70Gg0SKVS/LP0TtTc08Ro6zaqMzDlIrJChmgPk8mEhYUFeDwe7OzsYHNz89CNQTklsiRqtRq8Xi9+5Vd+BR6Ph2sS5ubmMDU1xZ4MHWKTyYRAIIBLly7B7/fj5s2bKJVKAwuPnWbBrXZIdhAQmVpqXhN4Ej7S6XRsGWezWRZWYq0J/SzVoNjtduj1eg4l2+123m9k6JDyMBqN8Pv9MBqNSCQSSCQSzA6r1+twOBywWq1IJBKoVqtsIYtJa41G02JYykNJ8p89zrCYVquFy+WCyWRCoVBALpdT/Llms8n1c3RPtVoN+XxesfbEZDLBZrPB6/Xirbfewvj4ODNZm80mZmdnEQgEkMvlkMlkOAel5vOKuZ1uDfuBei6n+dCeFphMJjz//PNYXFzERx99hO3t7UNfItXyGI1GlEolNBoN+P1+vP3225iamuKfCwQCmJ2dhV6vP6DkQ6EQXnrpJUxOTiISiTDF+ZfJK+wUp32PqpmoFT0SMbxhNBoxNjaGQCCA9fV1puHT3hGZZGRJO51OWK1Wjj4Eg0G8/PLL8Pl8CAaDcLlcKJfLyOfzsFgseOaZZ+B2u3Hz5k3cvHmTBWStVoPP52PvOpvNolqtolKpHFBqJDhFeq34ffEZRQUz6HdMytntdjNTVekzm83mgUjNYdEbs9kMv9+PyclJfO9738OFCxfw2Wef4fPPP4fRaMSZM2dQrVaxsbHBNXCDoDP3es1TXeciB7ngRqMRDocDGo2GqcNqXV9OiRw0ms0mVyZnMpmOfodYNaFQCD6fD2fOnEEwGITb7eafsdlsHFsXIUkSTCYTQqEQJEliC5SuO8Q/HIgEkHK53MIYAw4KZWI4iixHh8OBUCiEQCAAr9cLl8sFr9cLr9eLSqUCm83GoRuysCVJ4jxCo9GA2WxmD8disfC/KVldqVQUGWjtzmi7eppBnWnKoRSLRcXccTfQaDSw2WwwmUwYHR3F4uIidzAwGAxwu90YHx+HwWCA0+lEtVqF1+uFx+NBoVA48YJREceqXI6q6zgKRqMRRqMRgUAAly9fhl6vx+eff4719fWeFILcoqfEFQnv41AwxWIRH330ET7//PO2CT0RkiShVCqhWq3iH//jf4zf+Z3fgdPpRCgUYgo2WXaUaBWh0Wjg8Xjw0ksvIZFI4OrVq7Db7SiVSgcKvIb45YLSfjcYDNBqtdje3oZGo2lRLtT1ggS10WiEzWaDVqtFtVpFrVbDM888g9deew02mw3BYBAWiwWzs7MIBoNM1InH4/g//+f/YH19HXa7nfOAExMT0Gg0TCxxOp0YHR1Fo9Hg3Es0GkUsFkOj0WhJ9h+WI6KviUWjg1QwjUYDkUiEQ9X9yCG9Xo+ZmRlMTEzghRdewD/6R/8IdrsdFosFWq0WS0tLOHPmDIDH8rRQKGBrawv1eh3b29usYE7DOe66t1ivN91t/oU2B1lLVH1vsVjg9XoxOjrKDKhe7o1qZogkQIm/4076UaHkURDviRSQ1+vF0tISTCYTx8fpGRqNBlt8tH7A4zXS6/VwOBxoNBoc7242mygUCv
wzQ/zygYwN0bijfVEulxVbC9HflJsxmUxckwIATqcTY2NjMJvNcDgcMJlMsFqtsFgsrJwAIJlMYm9vDyMjI7BarTA
2024-07-31 16:50:32 +08:00
"text/plain": [
2024-08-01 00:59:36 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-31 16:50:32 +08:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-08-01 00:59:36 +08:00
" 60%|██████ | 300/500 [29:21<19:26, 5.83s/it]"
2024-07-31 16:50:32 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-08-01 00:59:36 +08:00
"Epoch: 300, iteration: 468, Discriminator Loss:array(0.561531, dtype=float32), Generator Loss: array(0.66357, dtype=float32)\n"
2024-07-31 16:50:32 +08:00
]
},
{
"data": {
2024-08-01 00:59:36 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy953Ok2XUe/nTOOaEbQDcyJs/ObJhd7i65YhC5FCWSJUqkS6lky1a5rM+usr/Yf4JdKllWSR9cJZZlu0TJphjFJTdxlxtmJw9mMMi5c87x92F+58ztF91Ao9EAGrP9VKEwAzS63/e+9574nHNkjUajgQEGGGCAAQboIeQnfQEDDDDAAAM8fRgolwEGGGCAAXqOgXIZYIABBhig5xgolwEGGGCAAXqOgXIZYIABBhig5xgolwEGGGCAAXqOgXIZYIABBhig5xgolwEGGGCAAXqOgXIZYIABBhig51B2+kKZTHaU13HsUCgUUKlUaDQaqFQqqNfrLV8n3jc1M+imqcHTtn7A43vqZi0Os36tnkev0c19KRQKKJVK3k/7/b1CoYBarUaj0UC5XObX02c3Gg3IZDJotVoolUqUy2WUSiUA3d+3Uvn4uNfr9V3voVKpIJfLUa1WUavVunp/gkKhgFwuR71e7+i9juOZStEvZ9jr9WJ4eBjZbBarq6soFos9/4yjQCfr17FyedpQr9dRqVT43+0gk8kgk8n2fM2nFSfROUgUwnTYe3kdGo0GSqUS1WqVhfleEK+hWq12/Dm0/0iRSO9BLpdDLpdDrVZDpVKhXq83KaFuQHu41XvUajXU6/We7HN6Np0K41bXQ3/fam2eJqTTad4LJI+eFpyYcunW6u0VGo3GvlaVeEhO+noH2I1eW7wymQwqlQoajQalUqkjYS4aHwcRzO2UEXkscrmcvSGVSoVyuXzg+2n13u3Qa+PpIMqlFeRyOWQy2aG9qH5HLpdDLpc79Pv0o3w6MeXSbwvRCqLVdBqu99OKXlm5JPBlMhmq1WpH73VU+6LRaLC3Uq/XO76eTnEYr08mk0Gj0UChUKBcLqNSqUCn08Hj8UCpVCIejyOTyaBer+8p9OTy5pSv9LwNzlxrtHp2/bhWn9qwWKcYhMP6G6KFLOYS9hKee/2uVCp1FA4jdCoEDyLMyauu1Wqs7HopPMgrAJ7sb/H6OskXWSwWaLVapFIpJJNJ2Gw2fPazn4XBYMDbb7+NeDze5PVLz5FMJoNCoWjybkiJitc1QDMoXNpJ5OWkMVAuA5xKUPjoKN63n9CvFnyr6yJvC3isPNRqNQwGAxqNBorFImq1GgwGAwwGA+r1OkqlEucbeu2ZPc3oZE8cRT7yoJB1Os/laWQ7dYt+YZqcVvRq/aQ/E99X/N1pSgx34qUcli122HAvKQ6FQsGJaL1eD6/XC6VSiWg0ilQqhampKbz44ouo1WqYm5tDKpXCl7/8ZXzpS19CMBhkD2dpaQmRSATlchmFQqGre5Ne3373dprP8H6hRplMBqVSCblczoq73ftQaPKgXtCALTbAU41Ow1H9YMX1A1qFD7tBo9HYFTosFovY3NxkwSeXy2G1WjEzM4NKpYJoNAqZTIaLFy/i9ddfx9LSEra2trC5uYlQKIR4PM6C8TDX1i8K4CjRbn1or8vlcqaWkzfZKj8jno2jIAQMlMsATyX6PdnZDkd5rb3KY5DnQnUxRKGVeo7JZBJ3796FSqWC2Wzmr2q1yl5KsVjkr71CY6IgJCUkpYGfFu/0qEBKxWaz4XOf+xw8Hg+vSTQaxYMHD5DNZhGJRJDJZADsTU8/LAbKZYBTgW69j34SNnuF8dq9tpfX30vlotFooFKpUCgUdoVd6Nqj0Shu3LgBk8mEq1evwuPxwGazcUFoLpdDPp9HoVBAoVDoyCKnkI9CoeDfUbFmPz3rkwApF4/Hg29/+9t45plnoFQqoVQq8eDBA/yv//W/EAwGcfv2beTz+Y6LXLvFQL
k8JWhF65Ti0374TgqtaqVE5dGOSQXgWIXmfsqPBLzIVhItX6IekxKoVqvI5XKQyWSIRCKo1+t49OgR1Go11tbWEA6HEYvFUCqVDpQfUSgUUCgUXFs0YJY9hslkgs/ng9/vh81mg8FgYEWs1WqhUqmYoXccXt5AubTBaYjTi2ECtVrdJLxaWSWiIOvn+2qF03a9BDEGDjxJplORpMiaIsjlcmg0GgBgltVRg4SOuO+lik2pVEKr1QIAyuUyKxd6DV0ntX+hsFcikUAsFoNKpcJbb70Fo9GIYrGIaDSKSqWCfD7f0TXW63Vu26TRaHZ5PNLrP617pltMTU3hN3/zN+Hz+TA2NgaLxcLroFarudsDcDxGy0C59CHI9ZcynqR1AXSY6MCRRQk83jzUXqRarbKF92k8dL0CWczSzgDt1pRer1AooNFoIJPJUKlUmpRFK0q1NNF6XNjvM0V2kViTQpAWQNKek8lkKJVKTeEtsTC00/1IXpNKpYJarW5qn7PXdX8aIJPJYDAYMDIygqGhIej1eiiVyibPjgzOwxI6OkVPlYu0OOs0C7G9iu+OGhaLBZcuXYLJZOLN4HA4MDY2xpXR4iETC6rosBUKBWQyGWQyGXzyySeIRqNIp9NcOd3vBVj7QRR0tO+k99UqrNPq5/uBhNrU1BTGx8ebQgvZbBbFYrEpqUyHORAIwOfzwWg0wu12o9FoYGVlBbFYDKurq5ibm2tpQdZqNWZjHddzoj20V0udWq3GLXH2CkWJZ5++i0wwkQDQ6bNQq9Uwm83QarUIBAKw2WzY2NjgnA+t414U66dV+dDaOp1OnD9/Hna7HQaDgde7Xq8jnU7j/v372NzcRDKZPJbr6qlyISv6sDz6fsRheyUdBAaDAefOnYPL5eKDMzY2hmvXrkGn07GXotPpoNfrUavVOEFHa5/NZhGNRhEOh5FIJFh45HK5p+aZ0KEi4U8eWqu8xmEoyXK5HEqlEl6vF88++2xT199oNIpsNsuCV1RwV69exeXLl2G1WjE2NoZ6vY6PPvoIa2trqNVquH//fkvlIgpgEUe5//Zr1SK9rr2Ui1RpSCv1D6owyZM3mUwwGo1snZdKJSwvL+9rVNB7iN+fJtAZMJvNCAQCsFgsHP4CHq9HLpfD+vo61tfXj61BZs+VC4VmpBtsr/DBfmgXDxY/VxQoZN0To4U2ZKPRwPz8PMLh8IGv5TgEstvths/ng9frxdTUFBwOB987MW1UKhVUKhXTQSlJB4DbhdRqNWi1WpjNZpTLZT6IJpMJDoeDe1YBgNVqhU6nw87ODpaXl1lAnwaIApGUiujNAK079NJr92IZaTQariY/f/48HA4HZmdnMTs722SFZzIZzovQe9H6jY+Pw+fzwWAwQK/Xo16vw+v1QqFQYHl5GSqVapdC3AtHuQfFe2r3OQfN2ZHSNxgMuHjxIux2O9bW1ljAkUG0H+i80+upZiadTu963V7XfhS1HCcNhUIBn88Hi8UCn8/Hhifw+J7j8TgikQh2dnZQKpWOlSByZJ6LGP8HwIcPONghkclkLEBps9ZqNQ5FiIeiXq9DqVTCZrNBrVbDbrfDYrFgcnISX/7yl9FoNPCXf/mXiMfjLWPG++EoH4pMJsPMzAy+9KUvwe124+rVq7BYLLBYLDAYDDz/Q2qB0b1rtVrU63UUCgWUy2WoVCpotVpotVqcO3cONpsN1WqVv0qlEhQKBWZnZ+H1evGLX/wCf/u3f8ueTb8eQpF1JTUwqIuwWKEs0laB5kQ19RFrFeYxGo0YHR2F3+/Hn/3Zn+HixYtsrIjMvFYhIALRQOna6vU6Lly4gGKxiNXVVWi1Wn4eJw1ar72u5yD7ggxNtVoNj8eD3//938fFixfxgx/8AD/84Q+RTqexvb3dcbfncrnMpABqjplMJpsMjP2urV/39GGgUqlw4cIFnDlzBufPn+f5P8Djvbm2toY7d+7g/v37yGazx9pmp+cJfbIQSOjTfIxyuYx8Pt
+0CcRkE/2MlIhGo+GklMFgaGLbEMWRhColrhqNBlQqFaxWK7RaLWw2G8xmM4aHh+F2u1GtVqHVancpvn6BSqVia1m
2024-07-31 16:50:32 +08:00
"text/plain": [
2024-08-01 00:59:36 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-31 16:50:32 +08:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-08-01 01:23:57 +08:00
" 80%|████████ | 400/500 [48:13<14:39, 8.79s/it] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 400, iteration: 468, Discriminator Loss:array(0.566545, dtype=float32), Generator Loss: array(0.67279, dtype=float32)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9V5Bk53kejj+dcw7TMz15Zmc2JyywAISMAimRImkGyaayS7qwVOVw4Tv7xre+kdOF7LJL/pVMipLKpCTCTAABIoddpM07OXXOOXf/L/b/vvt1b8/shJ6Z07PnqZpaYLen+5yvv/O96XmfV9FqtVqQIUOGDBkyegjlYV+ADBkyZMg4epCNiwwZMmTI6Dlk4yJDhgwZMnoO2bjIkCFDhoyeQzYuMmTIkCGj55CNiwwZMmTI6Dlk4yJDhgwZMnoO2bjIkCFDhoyeQzYuMmTIkCGj51Bv94UKhWI/r6MvQGvQbDZ3/buHic2uQRRpUCqVMBgMUKlUqFQqqFQqD32/TpGHsbExPPfcc1AoFHj77bexsrKy6eft9dofRexWVENew/uQ8h5UKBRQKpVotVq7OmsOAttZP8V25V+2WlilUgmFQoFms7nrjd9P6MXGpA0E4FDXTbyuVqsFhUIBlUrF10bfa71e3/Y1Wq1WDA0NQalUIhAIIJvNtv3uQT3YCoWizQAelb0pG5e9Y7/3oEKhgEajgUKhQKPR4Gd8u5+rUCgkvV+3c23bjly2AhkXAGg0Gr14yyMPOsSBhx983Tbadv9uu9dCv9dqtVCv1wEAWq0WGo0GjUYDjUZj2++dzWaRzWb5vQ8D9LlkxJvNprw3ZRwYlEolNBoNlEolarUa6vX6jpxIKRuW7aInxqXZbB66pVUqlXxY00Eo9S9IPNC387rd/B0drgqFAgaDAQaDAbVaDdls9qEhN0UsD3soaN27vY6ioYMGfS6lFqS+F2QcHYiONjk1j+Ie7JlxOWyo1WoYDAa0Wi1UKhX+QqVwbd1wUMaPwnO1Wo3BwUEMDQ0hnU5jbm4O5XJ5y9+t1+sPjVhED61arXLUI+KwHqp+cDBkHC2Qk9tqtVCtVh/pPdgT4yIFUH5dqVTCZDIBAEqlEh+g/fQF96JWQO8hRi5iBKFWq6FWq9s8qu1GQzv5dxkyHjXQMytVx/ag0JOCvhSg0Wig1WphtVpx/vx52O12XL9+HXfv3m2LYHpxGO53QV+tVsNoNEKpVKJUKjFja6vUI70XXZ9areaCIhXoqYZChqbVaiGXy6FUKnEKbDdQqVRc+N/OAyVlps5hQPw+trM2ckF/79jPPbgZi/Io4cAK+rtBr74A0TMH7hkZv9+PgYEBBIPBbRfNDxOda6FQKKDVaqFSqVCtVh94bbdCvvgetCZqtbotgimVSsjn89Dr9XA6nVAqlSiXy6hUKm01ip3iUffQ9grRuMjof8jf4z0cmnHZ6xegVCqh1+vx/PPP4/jx4wDuedD1eh3pdBqLi4tIJBJMoZXKF95J/aU/xetrNpsolUrMctJoNBwViAVy8fc7D/hGo8GGiQ4veo9yuYx0Os0GZ6dMlk6I1yQekgqFAjqdDiqVCrVa7QFD+SiBvgOKIhuNRlukKJX9KePoQqlUQqlUwmw2Q6/Xo1gsIpfLodVqtTnhnWfJbgOBvqy5kGduNpvx7W9/G9/+9rd5UVZXV/Hnf/7nuHPnDmKxGBekDzslIB6+QPcvkdBoNFAoFLgYT3RgsaGxm4ERQcX4btdQr9dRKpW6/t5uIRoU8fMMBgN0Oh2KxSJqtdojeYjSflUqldDpdFCr1ahUKmxctssalCFjt6A9qNFo4Ha7YbPZEIvFkM/nAdxPbXdj2orO6U72aF8aF7VaDavVCofDAbPZDIPBwJQ/tVqNcrmMfD7fdphJ5cHdzMsndPackBHabupJrVazFyK+R7efXqPTgNJmrdVqR67HhNKOIyMjcDgcXXu9isUiisUi9w4pFAr2Gp
PJJCKRyJGmqZrNZrjdbqhUKk7RdoIi6Wq1ilKp9EDDrYzu0Gq1sFgs0Gq1sNlsMBqN/G+NRgOZTAb1ep2JO2q1GmazGVqtFgMDA7BarVhbW0O9Xke1WkWxWNz0Gd3t99GXxsViseDEiRMYHByEy+Vq629pNBqIx+MIhUKSSsN0dqh3K+h3+7tudODNjANFcwaDgf+u2Wzy4V6pVPZtTUSDItaAisVi2/0cBahUKuh0OtjtdvzJn/wJXn75ZahUKmg0Gqag1mo13L17F4uLiygUCojFYgCA4eFh2Gw2fP7553jrrbfYEdotmULKmJ6exle/+lWYzWbYbDZotdo2NQ862JaWlhCLxbCwsIBPP/30kY1wdwKHw4FLly7B7Xbj6aefxvT0NJ8f2WwWn3zyCZLJJKxWK6xWK2w2GyYmJmAwGGA2m6HT6XD16lW8/vrriMVi+Pzzz5FKpbqeLbt1RvvOuFAe3+PxwO12Q6/Xt1Fu6eGuVCqSLTR3HsCdf78dQyL+DjHBKKIjL4bqTcViEfV6nQ+wXhePuxkW+oyddPZLFXR/xLajniqHw4GJiQmcOnWK2XnUZ0UGvVarIZfL8Xu53W7Y7XbYbDbodDqOaI4S6Ps3mUwYGBiAzWaD0+lkzTqqJdZqNVQqFZRKJbRaLUQikSO3Fr2GVqvlaMXn88Hr9WJqagonTpzgsyCTySCRSMBkMsHhcPD6Hzt2DHq9HkajERqNBpFIBF6vlyOcXqOvjAvRjcfGxvDNb34TIyMjmJqa4s3aCSkeagaDAVqttk2vS9Qd2okGkU6ng06nw+DgIJ577jk4nU4MDAzAbrcjnU4jFAohn89jfn4e6XQa8Xgc8XicPcZeGF+RrafRaKBSqZhMICUixW5BfVNarRbPPfccnnvuOTYuBoMBFy9ebPPIAfA6DA0NAQDS6TQ0Gg2KxSLi8TjW1taQSCRgt9uh0WhQKBRQq9UO8zZ7AqorWSwWdvrW19eRy+Wg1+vZAdJqtUwioYiP2JGycekOIoM8//zzeOGFF+BwOHDs2DGYzWaOhul1Op0OTz75JEqlEnQ6HbRaLUfaKpWK19lut2N8fJzTlr12BPvKuKjVauj1evh8PjzzzDMYHx9/wGsmSPFQI4qxyWTionqj0dhVwyQd5kajEUNDQ3j55ZcxMjKC4eFhOJ1OBINB3L59G4lEAtVqFRqNBtVqFblcDgqFoieHmRgxUg1Cq9VKPnLcCUgl2mg04sknn8Q//+f/nB9WsbeIoFAo2At0uVzQarVIJpMoFovIZDJYXV3FxsYGyuUyLBYLHxr9DtoHKpUKJpMJVqsVKpUKsVgMtVoNw8PDvNdFQUc67OhHRnfQPjl79iy++93vwmQywel0do04DAYDG5tOiA6syWTC0NAQyuXyAw3VvUDfGBelUomTJ0/i7NmzOHXqFEwmU5sxSSQSWF5exsLCAvL5vCSNS6vVQq1WY6NCueXdXqter4fdbofdbofb7Ybb7eaQ12w2w+fzwWAwYGZmBk6nEz6fDxMTE8jn8wiFQiiXy4jH47xe27mOTs9SjLgoEtuJgrJUQYel1WrF5cuX4ff7MTMz0+b5PQx02Ipeo8fjQbVaRSQSQSgUYip4v4OMqqioXa/Xkc/n0Wg0cPv2baytrcHtdsPlcnENsFarIZPJMLFB7vl5EAaDAbOzs3C73ZienobRaIROp+vaH/cwiE54Op3GnTt32NnpNfrCuNDGfeWVV/Bv/s2/gV6vh8ViaXvN8vIy/v7v/x7BYBCxWEyyXjNJ0myXxdZtw9DvWK1WDA8PY2xsDBMTExgZGeEowuVywWg0olwuw263M3uuVqshFovh008/RSKRwCeffILl5eW2PpqtQOkfkUpNLBP6UzRUO+0+lwqItun1evH7v//7eOKJJ2C1WmEwGLrWy7qBvHTRQ8zn8zAajcjlcohEIiiVSkcmJUY0a9oj5Ly0Wi3cunUL9XodExMTmJiY4NoUcJ/hSJGNbGDaYb
PZ8I1vfAOnTp3C8ePHuQGaasz0IxKCtgL93sbGBt544w0kEgmk0+men5mSNy5qtRpOpxNGoxEDAwNwuVxtoSDlbul
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 500/500 [58:22<00:00, 7.00s/it]\n"
2024-07-29 00:18:35 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"source": [
"# Training hyperparameters\n",
"n_epochs = 500\n",
"display_step = 100  # log losses and plot generated samples every `display_step` epochs\n",
"batch_size = 128\n",
"\n",
"# nn.value_and_grad returns a function computing (loss, grads w.r.t. the given model's trainable parameters).\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"\n",
"for epoch in tqdm(range(n_epochs)):\n",
"\n",
"    for idx, real in enumerate(batch_iterate(batch_size, train_images)):\n",
"\n",
"        # Discriminator step: loss and gradients w.r.t. the discriminator's parameters\n",
"        D_loss, D_grads = D_loss_grad(gen, disc, mx.array(real), batch_size, z_dim)\n",
"        # Apply the gradients to the discriminator's parameters\n",
"        disc_opt.update(disc, D_grads)\n",
"        # MLX is lazy: force evaluation of the updated parameters and optimizer state\n",
"        mx.eval(disc.parameters(), disc_opt.state)\n",
"\n",
"        # Generator step: loss and gradients w.r.t. the generator's parameters\n",
"        G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
"        # Apply the gradients to the generator's parameters\n",
"        gen_opt.update(gen, G_grads)\n",
"        # Force evaluation of the updated generator parameters and optimizer state\n",
"        mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
"    # Periodically (and after the final epoch) report the last batch's losses and plot samples\n",
"    if epoch % display_step == 0 or epoch == n_epochs - 1:\n",
"        print(\"Epoch: {}, iteration: {}, Discriminator Loss: {:.4f}, Generator Loss: {:.4f}\".format(epoch, idx, D_loss.item(), G_loss.item()))\n",
"        fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
"        fake = gen(fake_noise)\n",
"        show_images(epoch, fake)"
]
2024-07-26 21:07:40 +08:00
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}