mlx-examples/gan/playground.ipynb

631 lines
147 KiB
Plaintext
Raw Normal View History

2024-07-26 21:07:40 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Libraries"
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 63,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 64,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
2024-07-30 07:17:12 +08:00
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 65,
2024-07-30 07:17:12 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-31 16:25:39 +08:00
"# Optional: pin all computation to the GPU\n",
"# mx.set_default_device(mx.gpu)\n",
"\n",
"# Seed both RNGs: MLX (get_noise sampling) and NumPy (batch shuffling in batch_iterate)\n",
"mx.random.seed(42)\n",
"np.random.seed(42)"
2024-07-30 07:17:12 +08:00
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 66,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim:int, out_dim:int, bn_eps:float = 0.8):\n",
"    \"\"\"Linear -> BatchNorm -> LeakyReLU(0.2) block used by the generator.\n",
"\n",
"    `bn_eps` is forwarded to BatchNorm as `eps`. The default of 0.8 reproduces the\n",
"    previous positional call `nn.BatchNorm(out_dim, 0.8)`, which set eps (not momentum).\n",
"    NOTE(review): eps=0.8 is unusually large and may originally have been meant as a\n",
"    momentum value - confirm before tuning.\n",
"    \"\"\"\n",
"    return nn.Sequential(\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.BatchNorm(out_dim, eps=bn_eps),\n",
"        nn.LeakyReLU(0.2),\n",
"    )"
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 67,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"    \"\"\"MLP generator: noise of width `z_dim` -> flattened image of width `im_dim`.\n",
"\n",
"    Three GenBlocks widen the features (hidden_dim -> 2x -> 4x), a final Linear\n",
"    projects to `im_dim`, and tanh bounds every output value to (-1, 1).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, z_dim:int = 32, im_dim:int = 784, hidden_dim: int = 256):\n",
"        super().__init__()\n",
"\n",
"        widths = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4]\n",
"        blocks = [GenBlock(d_in, d_out) for d_in, d_out in zip(widths, widths[1:])]\n",
"        self.gen = nn.Sequential(*blocks, nn.Linear(widths[-1], im_dim))\n",
"\n",
"    def __call__(self, noise):\n",
"        return mx.tanh(self.gen(noise))"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 68,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=100, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
" (layers.1): BatchNorm(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
" (layers.1): BatchNorm(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-31 00:50:02 +08:00
" (layers.3): Linear(input_dims=1024, output_dims=784, bias=True)\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-07-31 16:25:39 +08:00
"execution_count": 68,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gen = Generator(z_dim=100)\n",
"gen"
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 69,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# Sample 2-D standard-normal noise with shape (n_samples, z_dim)\n",
"def get_noise(n_samples: int, z_dim: int) -> mx.array:\n",
"    \"\"\"Return an mx.array of shape (n_samples, z_dim) drawn from N(0, 1).\"\"\"\n",
"    return mx.random.normal(shape=(n_samples, z_dim))"
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 70,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-31 16:25:39 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWNklEQVR4nO3ca2zW9d3H8U9LobRFQCiM80GtdMgKioJYg8VJBIwRmIsYcCw76FTCxsZmZjZM3IxzU4kYNZqFLZ5ZHJODGkCshIMoIudxGEUqZwQpbS1toVz3s29yP7quz+9OvO/ceb8eX+//tdSWz/5PvnmZTCYjAAAk5f9v/w8AAPzfwSgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgFOT6wYceesh+eLt27eymrKzMbiRp165ddjNq1Ci72bFjh90UFxfbzdatW+1Gkr773e/aTWNjo93cfPPNdvPWW2/ZjSQdPnzYbu666y67+c9//mM3Q4YMsZtXXnnFbiRpwIABdnPmzBm7qaystJv6+nq7+eijj+xGSvu36KuvvrKblP9OtbW1diNJ06dPt5uU39c//vGPWT/DmwIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIOR/E69+/v/3w8vJyu3n33XftRpIaGhrspq2tzW66d+9uNx9//PE38j2S9K1vfctutmzZYjeHDh2ymw4dOtiNlHZQMOVnfv3119tNYWGh3XTu3NluJOnkyZN2k3JobenSpXZTUlJiN8OHD7cbSZozZ47djBw50m4GDx5sN3369LEbSdq0aZPd3HrrrUnflQ1vCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACDkfBCvqqrKfvjTTz9tN6WlpXaT6uzZs3aTyWTsZvTo0XbT1NRkN5L0zDPP2M3tt99uNymH9woKcv51+29OnDhhN3v37rWblKNua9assZuysjK7kaRu3brZTcphxZSfd0VFhd2k/K5K0g033GA3AwcOtJtjx47ZTcq/KVLa8dCUI3q5/K3zpgAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACDmfrUy5GDhkyBC7OX78uN1IaZdI6+rq7Gbw4MF209LSYjepV1JTrtmmXN88fPiw3RQVFdmNJHXt2tVuLly4YDdr1661m5TLr9u2bbMbKe3SZ8rvXsr3nDt3zm5Srp1K0m9/+1u7efDBB+1m2rRpdnP99dfbjSRdeeWVdpN6ZTYb3hQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAyPmaV3V1tf3wUaNG2U27du3sRpI+++wzu5k6dardrFixwm5uuukmuzl9+rTdSNKUKVPsZt26dXbT3NxsN3l5eXYjSYMGDbKblENwmUzGbn7/+9/bzbx58+xGki5evGg3KYcBd+/ebTc/+clP7Gbu3Ll2I0nvvPOO3aQcsuzdu7fdNDY22o0kPfXUU3ZTUlKS9F3Z8KYAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAQs4H8S6//HL74Rs3brSbDh062I2UdjRt586ddnPJJZfYzYIFC+xmzpw5diNJZ86csZtNmzbZTcrhvRdeeMFuJOncuXN206tXL7spKMj5zyFceumldrNv3z67kaSOHTvaTcoxxm7dutnN888/bzePPPKI3UjSww8/bDdVVVV2s3z5cru544477EZK+/dr//79Sd+VDW8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIOR8ASzl8NeYMWPspqWlxW4k6eTJk3YzceJEu3niiSfs5tprr7Wb3b
t3240klZaW2k1zc7PddOnSxW7Ky8vtRpJOnTplNyk/h5QDjj/84Q/tpqamxm4kafLkyXbz6KOP2s1zzz1nN3V1dXazZs0au5GkDz/80G5aW1vtJuUA4QcffGA3kjR37ly7uf/++5O+KxveFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEDI+SDenj177IfX19fbTVFRkd1I0qeffmo3e/futZsNGzbYTcr/tq1bt9qNJHXr1s1utm/fbjfV1dV2M3jwYLuRpGHDhtnNokWL7KaystJuVqxYYTedO3e2G0k6ePCg3SxYsMBu3n77bbspKSmxm+LiYruRpC1bttjN7373O7tp37693aT8/UnSn/70J7upra1N+q5seFMAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAIS8TCaTyeWDP//5z+2Hf//737ebzz//3G4kaeTIkXazc+dOuzlx4oTdpFxb/OKLL+xGklpbW+2mqanJbsaPH283K1eutBtJuvvuu+1m8+bNdnPkyBG7ue222+xm/fr1diNJHTt2tJsvv/zSblJ+3ilXaVOunUrS448/bjfvv/++3bS0tNhN6gXcrl272s0///lPu1myZEnWz/CmAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAELOB/E2btxoP/yJJ56wmwkTJtiNlHb4K+UIVYqUQ2YbNmxI+q6U43t9+vT5RpqUY4KS1K5dO7vp1KmT3XTp0sVu3njjDbtJOegmScuXL7ebmTNn2s2dd95pN++8847dVFVV2Y0kzZ49225SDiT++Mc/tpuUI5uS1NzcbDcXL160m1mzZmX9DG8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIBTk+sGFCxfaD29sbLSbhoYGu5GkgQMH2k1ZWZnd1NXV2c3LL79sNylH4CTpmmuusZv+/fvbzdq1a+1m0KBBdiOlHS6sr6+3m2PHjtnNpEmT7CY/P+3/i1VUVNjNX/7yF7t56qmn7Obo0aN2M2fOHLuRpKlTp9rNlVdeaTcjRoywm2XLltmNlPbv3rBhw5K+KxveFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEDI+SBeU1OT/fCnn37abv71r3/ZjSQVFhbazZIlS+ymb9++djNx4kS7ufrqq+1Gkv7whz/YzW9+8xu76dGjh9106dLFbiRp8eLFdjN79my7Wbdund2k/N59/vnndiNJw4cPt5s9e/bYzeuvv243d999t9289tprdiNJjz/+uN1MmTLFbjp37mw3o0ePthsp7Xfv3//+d9J3ZcObAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAg5HwltaioyH74Aw88YDcp1xYl6eWXX7abq666ym5aW1vtprq62m5qa2vtRpImT55sNy+99JLdlJeX203KlU9J2r59u92kXLjs16+f3ZSUlNhNp06d7EZKu6TZ3NxsNyl/gymXVW+66Sa7kaRp06bZzaZNm+xm7ty5dnPvvffajSTdcsstdpPyO54L3hQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAyPkg3tixY/2HF+T8+P+xyspKu+nWrZvd7N27126GDRtmN01NTXYjpf3M29ra7KZ9+/bfSCNJFRUVdvP888/bzYMPPmg3b7/9tt1MmTLFblK/q6yszG42b95sN9dee63drFy50m4kad68eXaT8jf4ox/9yG4OHTpkN1LaAcyjR4/azahRo7J+hjcFAEBgFAAAgVEAAARGAQAQGAUAQG
AUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEHK+nta7d2/74ePGjbObMWPG2I0kvfDCC3Zz5MgRu7niiivs5rP
2024-07-30 00:44:16 +08:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Visualise one 28x28 draw of the noise distribution\n",
"img = get_noise(28,28)\n",
"fig, ax = plt.subplots()\n",
"ax.imshow(img, cmap='gray')\n",
"ax.axis('off')\n",
"plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 71,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim:int,out_dim:int):\n",
"    \"\"\"Linear -> LeakyReLU(0.2) -> Dropout(0.3) block used by the discriminator.\"\"\"\n",
"    layers = [\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.LeakyReLU(negative_slope=0.2),\n",
"        nn.Dropout(0.3),\n",
"    ]\n",
"    return nn.Sequential(*layers)"
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 72,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"    \"\"\"MLP discriminator scoring flattened images.\n",
"\n",
"    Returns raw logits (last dimension 1), not probabilities: the loss functions\n",
"    call nn.losses.binary_cross_entropy(..., with_logits=True), which applies the\n",
"    sigmoid internally.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,im_dim:int = 784, hidden_dim:int = 256):\n",
"        super().__init__()\n",
"\n",
"        self.disc = nn.Sequential(\n",
"            DisBlock(im_dim, hidden_dim * 4),\n",
"            DisBlock(hidden_dim * 4, hidden_dim * 2),\n",
"            DisBlock(hidden_dim * 2, hidden_dim),\n",
"\n",
"            # No final Sigmoid: a Sigmoid here plus with_logits=True in the losses\n",
"            # would squash the score twice and cap how far the losses can move.\n",
"            nn.Linear(hidden_dim,1),\n",
"        )\n",
"\n",
"    def __call__(self, noise):\n",
"        # `noise` is the batch of (real or generated) flattened images to score.\n",
"        return self.disc(noise)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 73,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=784, output_dims=1024, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=1024, output_dims=512, bias=True)\n",
2024-07-30 07:37:09 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-30 07:37:09 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=512, output_dims=256, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-31 00:50:02 +08:00
" (layers.3): Linear(input_dims=256, output_dims=1, bias=True)\n",
2024-07-30 18:21:38 +08:00
" (layers.4): Sigmoid()\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-07-31 16:25:39 +08:00
"execution_count": 73,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"disc = Discriminator(im_dim=784, hidden_dim=256)\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
2024-07-30 00:44:16 +08:00
"cell_type": "markdown",
2024-07-26 21:07:40 +08:00
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### Losses"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"#### Discriminator Loss"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 74,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
"    \"\"\"Discriminator BCE loss: mean of the fake-batch and real-batch losses.\n",
"\n",
"    `disc` must return logits, because the BCE below uses with_logits=True.\n",
"    Fakes are labelled 0 and real images 1.\n",
"    \"\"\"\n",
"    # get_noise and the model calls already return mx.arrays; no re-wrapping needed.\n",
"    noise = get_noise(num_images, z_dim)\n",
"    fake_images = gen(noise)\n",
"\n",
"    fake_disc = disc(fake_images)\n",
"    fake_labels = mx.zeros((fake_images.shape[0], 1))\n",
"    fake_loss = mx.mean(nn.losses.binary_cross_entropy(fake_disc, fake_labels, with_logits=True))\n",
"\n",
"    real_disc = disc(real)\n",
"    real_labels = mx.ones((real.shape[0], 1))\n",
"    real_loss = mx.mean(nn.losses.binary_cross_entropy(real_disc, real_labels, with_logits=True))\n",
"\n",
"    return (fake_loss + real_loss) / 2.0"
]
},
2024-07-30 00:44:16 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generator Loss"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 75,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def gen_loss(gen, disc, num_images, z_dim):\n",
"    \"\"\"Non-saturating generator loss: BCE of disc(fake) against 'real' labels.\n",
"\n",
"    `disc` must return logits, because the BCE below uses with_logits=True.\n",
"    \"\"\"\n",
"    noise = get_noise(num_images, z_dim)\n",
"    fake_images = gen(noise)\n",
"    fake_disc = disc(fake_images)\n",
"\n",
"    # Label the fakes as real (1): the generator is rewarded for fooling disc.\n",
"    fake_labels = mx.ones((fake_images.shape[0], 1))\n",
"\n",
"    loss = nn.losses.binary_cross_entropy(fake_disc, fake_labels, with_logits=True)\n",
"    return mx.mean(loss)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 76,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"# Keep the first two arrays from mnist.mnist() (presumably training images and labels); discard the rest\n",
"train_images, train_labels, *_ = (np.array(arr) for arr in mnist.mnist())"
2024-07-30 00:44:16 +08:00
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 77,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# Rescale pixels to [-1, 1] to match the generator's tanh range (assumes inputs are in [0, 1])\n",
"train_images = 2.0 * train_images - 1.0"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 78,
2024-07-29 06:24:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-07-31 16:25:39 +08:00
"<matplotlib.image.AxesImage at 0x10ed09990>"
2024-07-29 06:24:50 +08:00
]
},
2024-07-31 16:25:39 +08:00
"execution_count": 78,
2024-07-29 06:24:50 +08:00
"metadata": {},
"output_type": "execute_result"
2024-07-30 00:44:16 +08:00
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaI0lEQVR4nO3df2jU9x3H8dfVH1d1lytBk7vUmGVF202dpWrVYP3R1cxApf4oWMtGZEPa+YOJ/cGsDNNBjdgpRdI6V0amW239Y9a6KdUMTXRkijpdRYtYjDOdCcFM72LUSMxnf4hHz1j1e975vkueD/iCufu+vY/ffuvTby75xueccwIAwMBD1gsAAHRfRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjpab2AW3V0dOjcuXMKBALy+XzWywEAeOScU0tLi/Ly8vTQQ3e+1km7CJ07d075+fnWywAA3Kf6+noNHDjwjvuk3afjAoGA9RIAAElwL3+fpyxCH3zwgQoLC/Xwww9r5MiR2rdv3z3N8Sk4AOga7uXv85REaPPmzVq8eLGWLVumI0eO6JlnnlFJSYnOnj2bipcDAGQoXyruoj1mzBg99dRTWrduXeyx73//+5o+fbrKy8vvOBuNRhUMBpO9JADAAxaJRJSVlXXHfZJ+JXTt2jUdPnxYxcXFcY8XFxertra20/5tbW2KRqNxGwCge0h6hM6fP6/r168rNzc37vHc3Fw1NjZ22r+8vFzBYDC28ZVxANB9pOwLE259Q8o5d9s3qZYuXapIJBLb6uvrU7UkAECaSfr3CfXv3189evTodNXT1NTU6epIkvx+v/x+f7KXAQDIAEm/Eurdu7dGjhypqqqquMerqqpUVFSU7JcDAGSwlNwxYcmSJfrpT3+qUaNGady4cfr973+vs2fP6tVXX03FywEAMlRKIjR79mw1NzfrN7/5jRoaGjRs2DDt2LFDBQUFqXg5AECGSsn3Cd0Pvk8IALoGk+8TAgDgXhEhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmelovAEgnPXr08DwTDAZTsJLkWLhwYUJzffv29Tzz+OOPe55ZsGCB55nf/va3nmfmzJnjeUaSrl696nlm5cqVnmfefvttzzNdBVdCAAAzRAgAYCbpESorK5PP54vbQqFQsl8GANAFpOQ9oaFDh+rvf/977ONEPs8OAOj6UhKhnj17cvUDALirlLwndOrUKeXl5amwsFAvvfSSTp8+/a37trW1KRqNxm0AgO4h6REaM2aMNm7cqJ07d+rDDz9UY2OjioqK1NzcfNv9y8vLFQwGY1t+fn6ylwQASFNJj1BJSYlmzZql4cOH67nnntP27dslSRs2bLjt/kuXLlUkEolt9fX1yV4SACBNpfybVfv166fhw4fr1KlTt33e7/fL7/enehkAgDSU8u8Tamtr05dffqlwOJzqlwIAZJikR+j1119XTU2N6urqdODAAb344ouKRqMqLS1N9ksBADJc0j8d9/XXX2vOnDk6f/68BgwYoLFjx2r//v0qKChI9ksBADJc0iP0ySefJPu3RJoaNGiQ55nevXt7nikqKvI8M378eM8zkvTII494npk1a1ZCr9XVfP31155n1q5d63lmxowZnmdaWlo8z0jSv//9b88zNTU1Cb1Wd8W94wAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQ
DMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAMz7nnLNexDdFo1EFg0HrZXQrTz75ZEJzu3fv9jzDf9vM0NHR4XnmZz/7meeZS5cueZ5JRENDQ0JzFy5c8Dxz8uTJhF6rK4pEIsrKyrrjPlwJAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwExP6wXA3tmzZxOaa25u9jzDXbRvOHDggOeZixcvep6ZPHmy5xlJunbtmueZP/3pTwm9Fro3roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT63//+l9DcG2+84Xnm+eef9zxz5MgRzzNr1671PJOoo0ePep6ZMmWK55nW1lbPM0OHDvU8I0m//OUvE5oDvOJKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAw43POOetFfFM0GlUwGLReBlIkKyvL80xLS4vnmfXr13uekaSf//znnmd+8pOfeJ75+OOPPc8AmSYSidz1/3muhAAAZogQAMCM5wjt3btX06ZNU15ennw+n7Zu3Rr3vHNOZWVlysvLU58+fTRp0iQdP348WesFAHQhniPU2tqqESNGqKKi4rbPr1q1SmvWrFFFRYUOHjyoUCikKVOmJPR5fQBA1+b5J6uWlJSopKTkts855/Tee+9p2bJlmjlzpiRpw4YNys3N1aZNm/TKK6/c32oBAF1KUt8TqqurU2Njo4qLi2OP+f1+TZw4UbW1tbedaWtrUzQajdsAAN1DUiPU2NgoScrNzY17PDc3N/bcrcrLyxUMBmNbfn5+MpcEAEhjKfnqOJ/PF/exc67TYzctXbpUkUgkttXX16diSQCANOT5PaE7CYVCkm5cEYXD4djjTU1Nna6ObvL7/fL7/clcBgAgQyT1SqiwsFChUEhVVVWxx65du6aamhoVFRUl86UAAF2A5yuhS5cu6auvvop9XFdXp6NHjyo7O1uDBg3S4sWLtWLFCg0ePFiDBw/WihUr1LdvX7388stJXTgAIPN5jtChQ4c0efLk2MdLliyRJJWWluqPf/yj3nzzTV25ckXz58/XhQsXNGbMGO3atUuBQCB5qwYAdAncwBRd0rvvvpvQ3M1/VHlRU1Pjeea5557zPNPR0eF5BrDEDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNrqkfv36JTT317/+1fPMxIkTPc+UlJR4ntm1a5fnGcASd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzAFvuGxxx7zPPOvf/3L88zFixc9z+zZs8fzzKFDhzzPSNL777/veSbN/ipBGuAGpgCAtEaEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpsB9mjFjhueZyspKzzOBQMDzTKLeeustzzMbN270PNPQ0OB5BpmDG5gCANIaEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gCBoYNG+Z5Zs2aNZ5nfvSjH3meSdT69es9z7zzzjueZ/773/96noENbmAKAEhrRAgAYMZzhPbu3atp06YpLy9PPp9PW7dujXt+7ty58vl8cdvYsWOTtV4AQBfiOUKtra0aMWKEKioqvnWfqVOnqqGhIbbt2LHjvhYJAOiaenodKCkpUUlJyR338fv9CoVCCS8KANA9pOQ9oerqauXk5GjIkCGaN2+empqavnXftrY2RaPRuA0A0D0kPUIlJSX66KOPtHv3bq1evVoHDx7Us88+q7
a2ttvuX15ermAwGNvy8/OTvSQAQJry/Om4u5k9e3bs18OGDdOoUaNUUFCg7du3a+bMmZ32X7p0qZYsWRL7OBqNEiI
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
2024-07-29 06:24:50 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"plt.imshow(train_images[0].reshape(28,28), cmap='gray');"
2024-07-29 06:24:50 +08:00
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 79,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"def batch_iterate(batch_size: int, ipt: np.ndarray):\n",
"    \"\"\"Yield shuffled mini-batches of rows from `ipt`.\n",
"\n",
"    A fresh permutation is drawn on each call, so one pass over the returned\n",
"    generator visits every row exactly once; the final batch may be smaller\n",
"    than `batch_size`.\n",
"    \"\"\"\n",
"    perm = np.random.permutation(len(ipt))\n",
"    for s in range(0, len(ipt), batch_size):\n",
"        ids = perm[s : s + batch_size]\n",
"        yield ipt[ids]"
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 80,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"def show_images(imgs, num_imgs:int = 25):\n",
"    \"\"\"Plot up to `num_imgs` flattened 28x28 images from `imgs` in a square grid.\"\"\"\n",
"    n = min(num_imgs, imgs.shape[0])  # never index past the end of a short batch\n",
"    if n > 0:\n",
"        side = int(np.ceil(np.sqrt(n)))  # 25 -> 5x5, same as the old fixed grid\n",
"        # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid\n",
"        fig, axes = plt.subplots(side, side, figsize=(5, 5), squeeze=False)\n",
"        for i, ax in enumerate(axes.flat):\n",
"            ax.axis('off')\n",
"            if i < n:\n",
"                img = mx.array(imgs[i]).reshape(28,28)\n",
"                ax.imshow(img, cmap='gray')\n",
"        plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### Show the first batch of images"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 81,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-31 16:25:39 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADjS0lEQVR4nOyd53Od13ben9N77w29ESBYJJIiKUpULze6N+7O2J547Em+ZSZ/SvI140kmju1xVXyv46ubq3LVSEpiJ0Cit9N7e0/v+aCszQMSJEESBHCA/ZvBUCIKz9l433ftvdazniXqdDodcDgcDoezg4j3+gVwOBwO5+DBgwuHw+FwdhweXDgcDoez4/DgwuFwOJwdhwcXDofD4ew4PLhwOBwOZ8fhwYXD4XA4Ow4PLhwOh8PZcXhw4XA4HM6OI93uF4pEohf5OnqKZzE14Ot3H75+z8ezmmrwNbwPvwafj+2sHz+5cDgcDmfH4cGFw+FwODvOttNiHA6Hs9OIRCJoNBooFAo4nU4MDQ2h2WwiHA6jVCohk8kgm83u9cvkPAM8uHA4nD1DIpHA4XDAarXirbfewp/92Z+hVCrhX//1X+H3+3H9+nUeXHoUHlw4HM6uIxKJIJFIoFAoYDKZ4HQ64Xa74XQ6USwWYTQakclkIJfLIRKJnlnEwNk7eHDhcDi7jkKhgM1mg8FgwIcffoiXXnoJ/f39UCqVqFar0Ov1MJlMUCqVe/1SOc8IDy4cDmfXkUql0Ol0MBqNGBkZwfHjx6HT6SCVStmJRqlUQiKR7PVL3bc8rzT6RZ8GeXDhcDi7jkKhgN1uh8PhgNlshsFggFwuZ58Xi8Us0HAexmKx4NixYzAYDLDb7dDr9Q99zVbpxFarhWQyiWKxiFQqhVgshkqlglQqhVqttqOvkQcXDoez6yiVSrhcLng8HtjtdpjNZvY5kUgEqVQKqVQKsZh3S2yFw+HARx99BK/XixMnTqCvr2/T5+lU82BwqdVqmJ2dRTQaxfz8PK5du4ZMJoNCocCDy7NAxUOxWAydTgeFQgGRSASxWIxms4lsNot6vY5Op3PgCod0o4rFYojFYrYTFIlEEIlEUKvVkMvl7Cbuvpm7b3CpVIp2u41IJIJ8Po92u41Wq7Un7+lZkMlk0Gg0kMvlsNvt0Gg0SKfTyGazaDQaKBaLaLfbT/1zxWIxRCIRZDIZlEolZDIZTCYT5HI5JBIJJBIJisUiotEo6vU6Go3GM/07Bw2JRAKNRgONRgOZTMb+vtPpoNlsIpfLIZVKoVwu7+Gr3B9IJBIYjUbI5XJIpVLIZDL09fXBbDbDaDRCrVZDoVBs+p7ulFn3M00sFsNsNqPVaqFUKkEQBKRSKSaeKJVKqFQqO/K6D0VwkUql0Gg0UKlUOH78ODweD+RyOZRKJTKZDL7++mskEgnU63U0m829frk7ilgshl6vh0KhgFqthlarZZ+Ty+UYHx+H0+mEVCplQRf48eI0Go3Q6/VQq9WwWCwolUr4i7/4C1y5cgXVahWlUmmv3tZTYzKZMDY2BpfLhT/8wz/E6OgofvOb3+DTTz9FMpnE3NzcUz/IRCIRFAoFZDIZrFYrvF4vLBYL3nzzTTidTmi1WqjVaszOzuJ//s//iXg8jkwmwx+Y+PHac7lccLvdUKvVAMA2LIVCAbOzs5ifn0ckEjlwG76nRaPR4PTp03A6nTCZTDCZTLDb7Thy5Aj0ev2me5p41JpJJBL09fXB4/FgbGwMr776KiKRCPR6Pfx+P+bm5rCysrIjr3vfBhc6bQBgf7ZaLbRarae62GjnrtFooNVq4XA44PV6oVAooFKpNhUOD8IRnHbStH4ymQw6nQ5qtRo6nQ46nY4FELlcDrfbDa/Xy3ZE3dhsNpjNZmi1WtjtdgiCAKPRyHLhvSARpfeqUqlgtVrhdDoxNjaGqakpLC0tQafToVQqbft3T2tLJ1+1Wg2VSsVueIfDgcHBQfh8Pmi1Wm
g0GuTzeahUqk0nxMMKrZtCoYBGo4FarWbXXafTQavVQr1ehyAIyGQyqFare/yKdx+ZTAapVMquM71eD4fDAY/HA6vVCovFsmnjJ5FI2H3YnX3pdDrsZwCbsxXAj0HLZDJBLBbDbrejWCzuqDpv3wYXk8kEn88HlUoFj8cDlUqF+fl5LC4uotlsolarPfHBRqcVr9eL1157DRaLBdPT03C73ewiX19fx6VLl5DJZNiF3WtQWkYikbBgYDKZ4PV6odPpcOTIERiNRqhUKiiVyk2pMp1OB6VSCUEQkEgk0Gw2UalU0Gq1IJfLYbPZWEqs2Wyi3W73TPqQbiSlUomJiQn89Kc/hd1uh9FoRLVaRTqdxtraGnK5HBqNxhN/nlKpZB8WiwVarRYXLlzA2NgYtFotjEYjNBoN+vv7odFoUK1WUSgUUCqVeErs/2M2m2GxWDA2NoZjx47B4/HAZDIBACqVCnK5HBKJBAKBADY2NnrqdLwTKBQKXLx4EVNTU1Cr1TCZTFCr1RgYGGCbREqDGQwGlnEAgGq1imKxuOkeVqlUbGOjVqs3bW7oOaBSqeB2u9FqtbC0tLRj72XfBhedToeBgQEYjUYcPXoUBoMB9Xodfr8ftVqN1Ugeh0KhgF6vR19fHy5cuACXy4XR0VHYbDZ2CiJJpEKh2LFc424jFoshk8kgk8lgsVhgsVjQ19eHY8eOwWKx4MyZM7Db7exr6EQDgD34wuEw8vk8Op0OGo0Gms0mZDIZU6G0Wi202+1NwaUXAoxSqYROp4PP58OZM2dgsVigVCpRr9eRz+cRj8dRLpe3lQ6Vy+XQaDTQ6XTo6+uDxWLB+++/j3PnzrG1JTqdDhKJBCqVCqrVKprNJprNZk+s2YtCJBJBp9PB7XbD5/NhaGgILpcLKpUKANBoNFAoFJDNZpFMJpFIJPb4Fe8+UqkUx48fxwcffACz2QyPxwOZTAaFQsFqpo86/dI1XavVIAgCGo0GjEYjDAYDCzLd0ElGoVDAYrGgXq+zU82OvJcd+0k7gEgkgsFggEajwdDQEJPaDQ8Psx14Op1GOp3G4uLilsFALBazwurk5CQmJibQ398Pr9cLk8nEony9XkehUEA+n0e5XEa5XN7W7nWvoAebQqGAw+GARqNhF4dMJoPRaIRCoYDL5WKFPpfLBYVCgUQigWw2i0qlgkqlAolEwi6iVCrFZImhUAgikQhmsxlqtZrlwHO5HNbW1pBMJhGLxdjDcr8jFoths9ng8Xjgcrmg1Wohk8lQLBbZBoV2c2KxGK1Wi/23TqeDzWZjN7JIJILVamUnQ5fLBZ1OB4fDAZlMxr6OCtKNRgMrKyu4c+cOVlZWkMvl2G7ysCISieB0OnHs2DGMjo5uKua3223EYjFcvXoVfr8fhUJhj1/t7mC1WmG1WiGRSNgpwuPxQK/XMwEKBRSxWIxqtfpQdoVS4X6/H7OzsyiXy0ilUqjX6xgaGmLFf71ev6vS7n0VXCQSCbxeL7xeL15++WV89NFH0Ov1sNlsUCgUkMvlsFgsWFxcRCgU2jK4SKVS9gB499138du//dvQarWw2Wxs1w78uGOPRCKIRqNIpVLIZrP7elep0Wjg8/lgNptx8eJF+Hw+pkbSaDTwer0s96/ValGv11GpVJDP53H16lXEYjGsra1hY2MDcrmcpSJWVlaQSCRQq9VQqVRgNpvx05/+FH19fWi1Wmg2mwgGg/jHf/xHRKNRLCwsoFgs7uu1IiQSCQYHB3Hq1ClMTU3BbDZDLBYjEAggnU6jXC7DYDBAq9Wy90M388jICM6ePct6L0QiEfr6+tDX1we5XA6dTgeJRMI2MkS73WZB/Ntvv8Xf/u3fMrXYYU+LiUQiTExM4Gc/+xk7YSuVSnadLS4u4m/+5m8OzalFJBJhaGgI586dg1wuZ6eLY8eOwel0QqVSsc0P1TcLhQKSyeSmn0Gbm+vXr+Pv//7vkc1mEQ6HUavVcPHiRZw9exZDQ0Po7+9/SF
X2ItkXwYWKzxQ8XC4Xy41rtVqoVKpNgQF4WA1BP0OpVDLdvMPhYBcwKaGq1SparRbS6TRCoRBisRhqtdq+u+npVKL
2024-07-26 21:07:40 +08:00
"text/plain": [
2024-07-29 06:24:50 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-26 21:07:40 +08:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
2024-07-30 00:44:16 +08:00
"# Draw one shuffled batch of 25 training images and display it\n",
"X = batch_iterate(25, train_images)\n",
"x = next(X, None)\n",
"if x is not None:\n",
"    show_images(x)"
]
},
2024-07-30 07:06:52 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training Cycle"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 82,
2024-07-27 06:09:51 +08:00
"metadata": {},
2024-07-29 06:24:50 +08:00
"outputs": [],
2024-07-27 06:09:51 +08:00
"source": [
2024-07-31 16:25:39 +08:00
"# Hyperparameters\n",
"lr = 2e-5\n",
"z_dim = 128\n",
"\n",
"# Generator + optimizer; mx.eval forces the lazily built parameters to materialize now\n",
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
"gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])\n",
"\n",
"# Discriminator + optimizer (same Adam settings as the generator)\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
"disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])"
2024-07-29 06:30:08 +08:00
]
},
2024-07-27 06:09:51 +08:00
{
"cell_type": "code",
2024-07-31 16:25:39 +08:00
"execution_count": 83,
2024-07-27 06:09:51 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-31 16:25:39 +08:00
" 0%| | 0/500 [00:00<?, ?it/s]"
2024-07-30 18:24:53 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-07-31 16:25:39 +08:00
"Epoch: 0, iteration: 468, Discriminator Loss:array(0.527415, dtype=float32), Generator Loss: array(0.672361, dtype=float32)\n"
2024-07-30 18:24:53 +08:00
]
},
{
"data": {
2024-07-31 16:25:39 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAApSklEQVR4nO3dfXiU9Z3v8c/kaQgQBgIkmZAQIgRBQVSgPCwIUs2arZwq2qKe7YHd1tUK7HKo6xbptbLdLnFtoWyXilc9FnErlXar1l2oGIuEWoqLiEIRMUiAIIRIhEyemJDMff7gkLORAPMdE355eL+ua66LTO4P9y/33Mknk5n5js/zPE8AADgQ53oBAIDuixICADhDCQEAnKGEAADOUEIAAGcoIQCAM5QQAMAZSggA4EyC6wV8ViQS0bFjx5SSkiKfz+d6OQAAI8/zVF1drczMTMXFXfq+TocroWPHjik7O9v1MgAAn1NZWZmysrIuuU2HK6GUlBRJ0hR9SQm+xKhz8Sm9zfvy9e1jzkiSV1NnztSNzTVneu0/Yc409bd/TXH1Z80ZSVJFpT0Tw5QoX89kcyZyusqckSRveI45U5ZvP+Y5v6owZ84OtJ/jivGvCXHhxiuTqbZ/L4Vz+pszdelJ5owk9Xn5PXMmbsilf+i2xiu3nw/eVbH9sh5/KmQPNdhu28ZIg7ZUPNv88/xS2q2EnnzySX3/+9/X8ePHde2112rlypWaOnXqZXPn/wSX4Eu0lZDPfpL54vzmjCR5cfZvtoTEHvZMDOvzxdszcfExPjQYF8M3diSGEorhOERiOB8kyYvh+MX7Y7htY9iPl2DfT8wl1BRDoTTaf5mJi2syZ5piOA7xibGdD5afQefFxXLbxnC+xnKuSlJ8LD/3LvMntYuJ5iGVdnliwvr167Vw4UItWbJEu3bt0tSpU1VQUKAjR460x+4AAJ1Uu5TQihUr9PWvf13f+MY3NHLkSK1cuVLZ2dlavXp1e+wOANBJtXkJNTQ0aOfOncrPz29xfX5+vrZt23bB9uFwWKFQqMUFANA9tHkJnTx5Uk1NTUpPT29xfXp6usrLyy/YvrCwUIFAoPnCM+MAoPtotxerfvYBKc/zWn2QavHixaqqqmq+lJWVtdeSAAAdTJs/O27AgAGKj4+/4F5PRUXFBfeOJMnv98vvj+1ZHgCAzq3N7wklJSVp7NixKioqanF9UVGRJk+e3Na7AwB0Yu3yOqFFixbpa1/7msaNG6dJkybpJz/5iY4cOaIHH3ywPXYHAOik2qWEZs+ercrKSn33u9/V8ePHNWrUKG3cuFE5OfZXowMAui6f58UwR6UdhUIhBQIB3ZIzzzQxIFJ5yrwvX6yTAgbax4Y0HSg1ZxIyg+aMIhF75NRp+34kRa4fbg9t322ONM4Ya8743y4xZySpYewwcyZpz2H7jmL4touEaswZ39VXmTOS5O0/aM7Epw+07yiGiQ6fTrGPxam8LrbJEbm/rjVnfLv22zN59rFex74X29eU/MuAOdNvr+2lM41NYW3e/c+qqqpSnz6XHmvFWzkAAJyhhAAAzlBCAABnKCEAgDOUEADAGUoIAOAMJQQAcIYSAgA4QwkBAJyhhAAAzlBCAABnKCEAgDPtMkW7LTQMSlUkoUfU28cfPWbeh+/qoeaMJDXtOxBTzsrrm2IPNZw1R2onDLbvR1Lv39uHsirPPlAzrrLOnPH1TDZnJMl/oMKcaaqqNmdOzrEPZR243T6kt35wDOeQJA0eY470eG2XOeNLtt9OXpx9gGnuo9vNGUmK759qztRPH23OVGcnmjO1B2KbPT3onUpzxjv8sTHQEPWm3BMCADhDCQEAnKGEAADOUEIAAGcoIQCAM5QQAMAZSggA4AwlBABwhhICADhDCQEAnKGEAADOUEIAAGcoIQCAMx12inbSR+VKiEuKPpCeZt5H49795owkJWSkmzNNwQHmjO9klTnjVYXMmabrB5ozkqQ0+4
ThSLJ9WvDZgN+cSSo5bM5IkvJyzBFfv97mzIB/e8ecOXnfjeZMXYbPnJGk5JP2Cc3H/2G8OdMnhoH0ldfb15a6sa99R5IiORnmTM8/2if6xzUGzZn0zfZp2JJ0cop9X2fybT8jmsJnpNUvRLUt94QAAM5QQgAAZyghAIAzlBAAwBlKCADgDCUEAHCGEgIAOEMJAQCcoYQAAM5QQgAAZyghAIAzlBAAwJkOO8DU69dHXrxhcOXxT8z78I0fbc5IUuRskz0UQ91HTtsHmMYNsA8V7ftf9oGLUmzDUgdstA/UrHwo05y5/e0yc0aSNtxsH0Za9rVh5kzkdvvtlL3Jfj7kfSO22/ZP+tkni/7se1+y78g+i1RZI07YQ5n2AceSFHe43Jypv36IOePfts+ciQwbbM5IUv+NH5ozvt69TNs3RsKK9ivinhAAwBlKCADgDCUEAHCGEgIAOEMJAQCcoYQAAM5QQgAAZyghAIAzlBAAwBlKCADgDCUEAHCGEgIAONNhB5jq2AnJlxT15o2jcs27OJuSaM5IUvmk6Nd13pCXTtl3FImYI42lh+27mXqDOSNJR7+Ybc7UFjaaM5/m20/TtU/cbs5I0sQNb5szH26x305fmLTfnPnSnN3mzPry8eaMJC3f/GfmzIh3PzVnrlprP19f/82N5kz9BHNEktT3o57mTCzDSH29bANCJSkSH2/OSFLtnww1Z5JCtu/bxsYzUpQ3LfeEAADOUEIAAGfavISWLl0qn8/X4pKRkdHWuwEAdAHt8pjQtddeq9dff7354/gY/3YJAOja2qWEEhISuPcDALisdnlMqKSkRJmZmcrNzdU999yjgwcPXnTbcDisUCjU4gIA6B7avIQmTJig5557Tps2bdLTTz+t8vJyTZ48WZWVla1uX1hYqEAg0HzJzrY/7RcA0Dm1eQkVFBTorrvu0ujRo3XLLbdow4YNkqS1a9e2uv3ixYtVVVXVfCkrK2vrJQEAOqh2f7Fqr169NHr0aJWUlLT6eb/fL7/f397LAAB0QO3+OqFwOKx9+/YpGAy2964AAJ1Mm5fQww8/rOLiYpWWluqtt97S3XffrVAopDlz5rT1rgAAnVyb/znu6NGjuvfee3Xy5EkNHDhQEydO1Pbt25WTk9PWuwIAdHJtXkIvvPBCm/w/vtS+8sVF/1hRXL19MGaP9w6YM5J01Xu97SG/feipNyTLvp999q8p8dM6+34kDfn+h+ZM9W2jzJkb7vyjOTOtr31tkvTCMfvAz9999QfmTFHdEHOmuqmHORNujO1b/F9va/2JRJeS8SX7yyv6xjWYM3sn2V+D2HN12JyRpEhNrTnzwb9eY84M/7p9cK7vVAxDkSWVf9V+jg9b/pFp+7hI9Lcrs+MAAM5QQgAAZyghAIAzlBAAwBlKCADgDCUEAHCGEgIAOEMJAQCcoYQAAM5QQgAAZyghAIAzlBAAwJl2f1O7WB36Sqbi/dEPbBzy0knzPk7PHG3OSFJgw15zxldXb99RDMMTI1OvM2cSDpwwZySp6lfp5kxgwSfmzM7/sA89ffPqYeaMJP3ypqfMmX88cYs584PgVnNm9C/+2pwZeeNhc0aS/mH/THOm4dWB5syf/eWb5kwg6Yw5E07vb85Iki8SMWeuHfqxOfPhC/bv29yV5sg5vhhz7YR7QgAAZyghAIAzlBAAwBlKCADgDCUEAHCGEgIAOEMJAQCcoYQAAM5QQgAAZyghAIAzlBAAwBlKCADgDCUEAHCmw07RTn+nQQkJ0Xekl2T/Uvq+Hdv06I++bZ/q7MUwuXbo+tPmTOIH9gm+BxYMNWckaej/tk8urxve15yJJJkjenrKWntI0tWJ9qnJv3nXPo399d/eYM6kv+uZMzW/zTJnJOnkn9p/P43Psq8v1Jhszux78ypzJrtv2JyRJP8n9p8rJ9bmmDPD3qo0Z8ofj20c9uAfxvBjv28f2/ZNYSnKL4l7QgAAZyghAIAzlBAAwBlKCADgDCUEAH
CGEgIAOEMJAQCcoYQAAM5QQgAAZyghAIAzlBAAwBlKCADgTIcdYHr49jjFJUffkSNXnGnH1bSU++8hcyac3tOc8X1
2024-07-30 18:24:53 +08:00
"text/plain": [
2024-07-31 16:25:39 +08:00
"<Figure size 640x480 with 1 Axes>"
2024-07-30 18:24:53 +08:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-31 16:25:39 +08:00
" 20%|██ | 100/500 [09:47<38:39, 5.80s/it]"
2024-07-31 01:01:14 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-07-31 16:25:39 +08:00
"Epoch: 100, iteration: 468, Discriminator Loss:array(0.535585, dtype=float32), Generator Loss: array(0.68229, dtype=float32)\n"
2024-07-31 01:01:14 +08:00
]
},
{
"data": {
2024-07-31 16:25:39 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAoSElEQVR4nO3df3TU9Z3v8dd3JsnkB2EkQn5BiFHAKuHaKgpy/AG2Zs1uXRV7irbXhW3XqxXcy8Ee71Juj2x7r1h7ZN29VHrW21JYpdI9VeteXDVdJKhIixQLRcUoQUIhRCIkgYRJMvO5f7CkjVCY99fET0Kej3PmHDP5vvx+8s0388qXzLwncM45AQDgQcT3AgAAQxclBADwhhICAHhDCQEAvKGEAADeUEIAAG8oIQCAN5QQAMCbDN8L+LhUKqV9+/YpPz9fQRD4Xg4AwMg5p7a2NpWWlioSOf21zoAroX379qmsrMz3MgAAn1BDQ4PGjBlz2m0GXAnl5+dLkq7N/7Iygqy0cy6ZNO+rc/IEc0aS2oszzZkR/1FvzqQOHzZngpxs+36OHDVnJCk1tdKcif7mXXMmEh9uziQPfmTOSFL7X3zOnMn9f1vMmSAr/XP7Dxn7eRf2e+um2L+3kTfeNmeCjKg5o6g9E+Tk2vcjKfnRIXMmUjnevqN3Qzw+TLzAvh9JkY5u+77eed+0fbfr0qvu33oez0+n30roscce0/e//33t379fEydO1KOPPqqrr776jLkT/wSXEWTZSiiwl1Aqw/6ALUnREA8GGRH7g04qsO8nMByzP+yn05yRwh2/aIivKRLi2AUh9iNJGZn2rykj1PdpYH9vXYjvbSTU1xTiISgIUUIhziEp3PcpEo2F2FGI723Ix69ItMu+rzA/T05p/UmlX56YsGbNGs2fP1+LFi3S1q1bdfXVV6u6ulp79uzpj90BAAapfimhpUuX6utf/7r+5m/+RhdddJEeffRRlZWVafny5f2xOwDAINXnJdTZ2aktW7aoqqqq1/1VVVXauHHjSdsnEgm1trb2ugEAhoY+L6GDBw8qmUyqqKio1/1FRUVqbGw8afslS5YoHo/33HhmHAAMHf32YtWP/0HKOXfKP1ItXLhQLS0tPbeGhob+WhIAYIDp82fHjRw5UtFo9KSrnqamppOujiQpFospFgvxbBIAwKDX51dCWVlZuuyyy1RTU9Pr/pqaGk2bNq2vdwcAGMT65XVCCxYs0B133KHJkyfryiuv1D//8z9rz549uvvuu/tjdwCAQapfSmjWrFlqbm7Wd77zHe3fv1+VlZV6/vnnVV5e3h+7AwAMUoFzzvlexB9rbW1VPB7X9OBm0yvRM8aMNu8r1RxutEuYEUFhREeNNGeSjQfMmTBjcSRJRfb1af+H5kj7tHHmTPZLvzVnpHCjcXSGAY2n0nWZfbRL1vbd5kwy5DkeycuzZwpGmDOuxf6SjM5L7edDtHarOSNJGcUn/x37TNyxhH1HEfuw5jAjhSQpGuLn3XXapix0u06ta39KLS0tGj789PvjrRwAAN5QQgAAbyghAIA3lBAAwBtKCADgDSUEAPCGEgIAeEMJAQC8oYQAAN5QQgAAbyghAIA3lBAAwJt+maLdF6LnxBUNstLePtQw0lTKnpHkurrNmca/nWLOFP/j6+ZMGO9+a0Ko3PjFO8yZMANC90+1n6bn/Xu4IbOuw56LDh9mzmRuqTNnNLbUHIlmhPsRT3102Jxpu8y+vpxn95ozGa9uM2cihaPMGSnckNBITrY5447ZH1MiId8MNBliaGzw2YtN27tkQkpzhjBXQgAAbyghAIA3lBAAwBtKCADgDSUEAPCGEgIAeEMJAQC8oYQAAN5QQgAAbyghAIA3lBAAwBtKCADgDSUEAPBmwE7RTrV3KBWkP1k2CDMtOBq1ZyQpdcwcGfOvu80ZN+ECcyZot69t9MvhpomHkTzUYs6M++EH5kwqO+SE4c+Ot+9ryzvmTC
Q315xxu+0Tp4OQ57jr7jJn8mrsU9XDnHnNd1xuzozcaj/vpHC/pYea6B98etcDYc691JtvmbZ3Lv3zhyshAIA3lBAAwBtKCADgDSUEAPCGEgIAeEMJAQC8oYQAAN5QQgAAbyghAIA3lBAAwBtKCADgDSUEAPBmwA4wVTJpGurnUs68i0hejjlzPJdnzrhEwpxJXFhizmS+1mDOZB0eac5IOv49Mork2YcnhhHkhvveRrfvMmdSnZ3mTDJEJsyQy0hWpn0/IaXa2+2hIDBHCn6yyZxxYYcVZ2WZI5HzysyZ5Pu7zZlowQhzRpJSLa3mTCTH9vMUcVEpzdOBKyEAgDeUEADAG0oIAOANJQQA8IYSAgB4QwkBALyhhAAA3lBCAABvKCEAgDeUEADAG0oIAOANJQQA8GbADjCNjDtPkWgs/UBXt3kfbt8Bc0aSghDDEIPh+eZM/U324ZOfaSw3Z7qywg13DMpKzZmd37Yfhwv/vsWcaZlmHyIpSTkHu8yZjMP24bT6XZ05Eh1ZYM6kDh02ZySFGpYaRpgBq87ZhxVHR5xjzkiSMuwPkfbVSdHCUeZM50VjQuxJyjhqP8ddKmXbPpmQtqa3LVdCAABvKCEAgDd9XkKLFy9WEAS9bsXFxX29GwDAWaBf/iY0ceJE/fKXv+z5OBr2DaUAAGe1fimhjIwMrn4AAGfUL38TqqurU2lpqSoqKnTbbbdp164//ZbJiURCra2tvW4AgKGhz0toypQpWrVqlV588UU9/vjjamxs1LRp09Tc3HzK7ZcsWaJ4PN5zKysL99RaAMDg0+clVF1drVtvvVWTJk3SF77wBa1du1aStHLlylNuv3DhQrW0tPTcGhoa+npJAIABqt9frJqXl6dJkyapru7UL86LxWKKxQwvSgUAnDX6/XVCiURCb7/9tkpKSvp7VwCAQabPS+ib3/ymamtrVV9fr1/96lf60pe+pNbWVs2ePbuvdwUAGOT6/J/j9u7dq9tvv10HDx7UqFGjNHXqVG3atEnl5faZZgCAs1ufl9BTTz3VJ/+f5Lv1CgL7cEOLprunhMoV//g35sx7/83+z5EXLvytOdP655PMmWgizMhFqfUR+4X09JH2wZ0jf3bEnPn2qNXmjCS9kcg1Z+7e8lVzJj/3PHPmprHbzJkfv3KtOSNJw3bZX2Be9Of2JxVlzLGfe4kLCs2ZD++zn0OS1NpmPx/GlzSZM8nrPjRnslLhfm67x9kHD2c0t5u2jyTTH+rL7DgAgDeUEADAG0oIAOANJQQA8IYSAgB4QwkBALyhhAAA3lBCAABvKCEAgDeUEADAG0oIAOANJQQA8Kbf39QutM9dKGVkp715tL7RvIv8fd3mjCSlEukP5zth3ENvmTNHrq80Z+K/+r058/ubx5ozktTSNNycuW78f5gzr7WON2e6XMqckaSkAnOmsmS/OfNXxRvNmauzD5ozX/3iG+aMJGXaD4OyA3vomn+4y5wJAvvPX26Ir0eSPjfWPpR187YLzJnhf3ueOZPTHO4cL9jSbM4E7cds26cYYAoAGAQoIQCAN5QQAMAbSggA4A0lBADwhhICAHhDCQEAvKGEAADeUEIAAG8oIQCAN5QQAMAbSggA4A0lBADwZuBO0f7N21KQmfbmqUsvNu8i59lfmzOSFMnNtYcCe98P27TbnEmNKjBnRtR1mTOSFGtJ//tzwvK1XzJnfrx0qTmTNCeOK422mTP5mfapzn///dnmTPU9r5ozXz5nszkjSaMC+4T5Xd1Z5sz5I+0TnXcdPNecOfr6SHNGkuZ9/WfmzMJ/sk99z2myTamWpMzNO80ZKdy7AAQXj7PtI5klpTnQnyshAIA3lBAAwBtKCADgDSUEAPCGEgIAeEMJAQC8oYQAAN5QQgAAbyghAIA3lBAAwBtKCADgDSUEAPBmwA4wjeRkKxKkPxAx9Zu3zPsIMu0DFyUpdSzEAMAu+0DIIDtmzzR+aM7k7N1vzkhSdoivqeWLk8yZ3M
CZM5lB1JyRpAPJPHPmzVX2r2nU746aM2tevMqcya3uNGckqbnLfhwuHfaBObP36QpzJhhmjuiWL79iD0n62ut/bc7
2024-07-31 01:01:14 +08:00
"text/plain": [
2024-07-31 16:25:39 +08:00
"<Figure size 640x480 with 1 Axes>"
2024-07-31 01:01:14 +08:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-31 16:25:39 +08:00
" 40%|████ | 200/500 [19:30<29:44, 5.95s/it]"
2024-07-31 01:01:14 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-07-31 16:25:39 +08:00
"Epoch: 200, iteration: 468, Discriminator Loss:array(0.550168, dtype=float32), Generator Loss: array(0.68665, dtype=float32)\n"
2024-07-31 01:01:14 +08:00
]
},
{
"data": {
2024-07-31 16:25:39 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAg1UlEQVR4nO3df3TU9b3n8dd38mMIOBlFSGYiMU0t1BYoPUUKchADW3PNPeWo2C7qbi/sth6twC4HPZ5S9tZszy7x2iOH3qXSraeXwqlU7t1V6h64YrqYoBfpIsVK0dK4BIklMSXFTAgwYTKf/YMyZwOI+XyZmc/M5Pk4Z84xM9+X3w/ffJMXX2bmPZ4xxggAAAcCrhcAABi5KCEAgDOUEADAGUoIAOAMJQQAcIYSAgA4QwkBAJyhhAAAzhS7XsDFksmkjh8/rlAoJM/zXC8HAGDJGKO+vj5VVVUpELjytU7OldDx48dVXV3tehkAgKvU0dGhCRMmXHGbnCuhUCgkSZqjv1axSjK6r8CooK9cMj5gH/J8/MtnctDHfnxcPfqd3JSlfRVVjLfODHb/yTqT6z672/7H9fDcRAZWgispurbcOjP4Ucx+R4Ei+4zk7/eKpYTO6XXtSP0+v5KMldAzzzyjH/zgB+rs7NTkyZO1bt063XbbbZ+Yu/BPcMUqUbGX4RLySn3lkp6PX9p+SshXxs8/YWaxhHzsqyhg/33yMnzuuFB6jf2PazH/pJ11RT5+r/g6Xz2fJeTn94qtv/yYD+cplYysZuvWrVqxYoVWr16tAwcO6LbbblNDQ4OOHTuWid0BAPJURkpo7dq1+uY3v6lvfetb+tznPqd169apurpaGzZsyMTuAAB5Ku0lNDAwoP3796u+vn7I/fX19dqzZ88l28fjccVisSE3AMDIkPYSOnHihAYHB1VZWTnk/srKSnV1dV2yfVNTk8LhcOrGK+MAYOTI2DNUFz8hZYy57JNUq1atUm9vb+rW0dGRqSUBAHJM2l8dN27cOBUVFV1y1dPd3X3J1ZEkBYNBBYP+XioNAMhvab8SKi0t1fTp09Xc3Dzk/ubmZs2ePTvduwMA5LGMvE9o5cqV+sY3vqFbbrlFt956q37yk5/o2LFjevjhhzOxOwBAnspICS1atEg9PT36/ve/r87OTk2ZMkU7duxQTU1NJnYHAMhTnjF+Z7ZkRiwWUzgcVp3uyvjEBL+Kq688C+lyEh0fZGAl+ScwjDEeF0v29WVgJWmUzVFJtvxOTMjl9WXxV1Zg9GjrTPL06Qys5DKyOQ3D8pgnzDm16Jfq7e1VefmVxxjxUQ4AAGcoIQCAM5QQAMAZSggA4AwlBABwhhICADhDCQEAnKGEAADOUEIAAGcoIQCAM5QQAMAZSggA4ExGpmgXulweRprrA0JzfhipH34GamZrcGduzSe+VI6vL2vDSP3I8WM3XFwJAQCcoYQAAM5QQgAAZyghAIAzlBAAwBlKCADgDCUEAHCGEgIAOEMJAQCcoYQAAM5QQgAAZyghAIAzlBAAwBmmaPtQfEOVdSbxx+MZWMmlCnJKdTYFiuwzyUH7TIFMQAauFldCAABnKCEAgDOUEADAGUoIAOAMJQQAcIYSAgA4QwkBAJyhhAAAzlBCAABnKCEAgDOUEADAGUoIAOAMA0x9yNYwUnmedaS4ssI6k+j60DqTTd4tU6wzH84q97Wvqf/md9aZ196dZJ25+e9PW2cS4aB1pvQPndYZSUr2xuwzZ87Y74hBrpIkL2j/vfWKfAzblZQ8bX/uZRJXQgAAZyghAIAzlBAAwBlKCADgDCUEAHCGEgIAOEMJAQCcoYQAAM5QQgAAZyghAIAzlBAAwBlKCADgDANMffCK7Q+bSSSsM0XXXmudyfVhpH6Gsv5xddI68+jn/od1RpLmjz5inQlX77TfUb19xI+fx+yHq0rShs0LrDNVu/utM31/e8o6U97wf60zuc7E4/aZgL8BprmGKyEAgD
OUEADAmbSXUGNjozzPG3KLRCLp3g0AoABk5DmhyZMn61e/+lXq6yKfH74EAChsGSmh4uJirn4AAJ8oI88JtbW1qaqqSrW1tbrvvvt05MjHv+IoHo8rFosNuQEARoa0l9DMmTO1efNm7dy5U88++6y6uro0e/Zs9fT0XHb7pqYmhcPh1K26ujrdSwIA5Ki0l1BDQ4PuvfdeTZ06VV/5yle0fft2SdKmTZsuu/2qVavU29ubunV0dKR7SQCAHJXxN6uOGTNGU6dOVVtb22UfDwaDCgaDmV4GACAHZfx9QvF4XO+++66i0WimdwUAyDNpL6HHHntMra2tam9v169//Wt97WtfUywW0+LFi9O9KwBAnkv7P8d98MEHuv/++3XixAmNHz9es2bN0t69e1VTU5PuXQEA8lzaS+j5559Pz//I8+yGXRqTnv0OZ1c+hpHKx7DBwZMn7fdTgCY02n9vI/+z19e+2s6FrTPXBs5YZ77+2sPWGe9EqXXm5vVd1hlJunHUCetM25LrrTPTy+1fiNRXY/8K2sT7BfiCp+Sgv5yPIcKZ/P3K7DgAgDOUEADAGUoIAOAMJQQAcIYSAgA4QwkBAJyhhAAAzlBCAABnKCEAgDOUEADAGUoIAOAMJQQAcCbjH2rnmzGSsjeUNOP8DhssMIHRo60zf1hsP1T0b39/l3VGkk50XGudmfTIPuvMRPMb64wfCT/DKiUl53zROhP8s/2+ziZKrDMDnxpnnSnpP22dkaTBEz2+crYCY8bYhwb9/U5Jnj3rK5cpXAkBAJyhhAAAzlBCAABnKCEAgDOUEADAGUoIAOAMJQQAcIYSAgA4QwkBAJyhhAAAzlBCAABnKCEAgDOUEADAmdydom3Lz7RgU0BTugvYzT/8o3XG9MZ87ev6U0fs95XD51HRzZ/xlfuodpR1ZsbdB60zP6lusc5MeeTfWWduahxrnZEkM2mCdSb26TLrTPdM+3No4vJfW2dyEVdCAABnKCEAgDOUEADAGUoIAOAMJQQAcIYSAgA4QwkBAJyhhAAAzlBCAABnKCEAgDOUEADAGUoIAOBMzg4w9UpK5Xklw97enBvI4GqQLl5pqX0oYP93Je+aa+z3I8nETvnKZcOpr8+0zox6qNPXvv5p0g+sM+FAkY892Z8Pv5uz0Toz+Xv/3jrj1z/O+qF1prn/89aZXylknclFXAkBAJyhhAAAzlBCAABnKCEAgDOUEADAGUoIAOAMJQQAcIYSAgA4QwkBAJyhhAAAzlBCAABnKCEAgDM5O8DUnBuQ8YzrZSDNvHL7waKmLGidee+hKuuMJP39135lnfn2rr+xzlz72+EP570gcOcJ68yc649ZZyRp0MeP3unkoHUmXGw/9PRYwn7IbOKcn+Gq0meXtllnvv6fVlhnPv1iv3VmsM7fr+/if/mddcYbZfczGDADUt8wt7VeDQAAaUIJAQCcsS6h3bt3a8GCBaqqqpLnedq2bduQx40xamxsVFVVlcrKylRXV6dDhw6la70AgAJiXUL9/f2aNm2a1q9ff9nHn3rqKa1du1br16/Xvn37FIlEdMcdd6ivb5j/QAgAGDGsn9lqaGhQQ0PDZR8zxmjdunVavXq1Fi5cKEnatGmTKisrtWXLFj300ENXt1oAQEFJ63NC7e3t6urqUn19feq+YDCo22+/XXv27LlsJh6PKxaLDbkBAEaGtJZQV1eXJKmysnLI/ZWVlanHLtbU1KRwOJy6VVdXp3NJAIAclpFXx3meN+RrY8wl912watUq9fb2pm4dHR2ZWBIAIAel9c2qkUhE0vkromg0mrq/u7v7kqujC4LBoIJB+zcjAgDyX1qvhGpraxWJRNTc3Jy6b2BgQK2trZo9e3Y6dwUAKADWV0KnTp3Se++9l/q6vb1db731lsaOHasbb7xRK1as0Jo1azRx4kRNnDhRa9as0ejRo/XAAw+kdeEAgPxnXUJvvvmm5s2bl/p65cqVkqTFixfrZz/7mR5//HGdOX
NGjzzyiE6ePKmZM2fqlVdeUSgUSt+qAQAFwTPG5NSU0FgspnA4rDrdpWLPfsgjclvR9WOtM/Na37fO/Ifrfm+dkaR
2024-07-31 01:01:14 +08:00
"text/plain": [
2024-07-31 16:25:39 +08:00
"<Figure size 640x480 with 1 Axes>"
2024-07-31 01:01:14 +08:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-31 16:25:39 +08:00
" 49%|████▉ | 246/500 [24:09<26:17, 6.21s/it]"
2024-07-29 00:18:35 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"# Set your parameters\n",
2024-07-31 16:25:39 +08:00
"n_epochs = 500\n",
2024-07-30 00:44:16 +08:00
"display_step = 5000\n",
2024-07-28 06:10:19 +08:00
"cur_step = 0\n",
2024-07-30 00:44:16 +08:00
"\n",
2024-07-31 16:25:39 +08:00
"batch_size = 128 # 128\n",
2024-07-26 21:07:40 +08:00
"\n",
2024-07-28 06:10:19 +08:00
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"\n",
2024-07-30 00:44:16 +08:00
"for epoch in tqdm(range(n_epochs)):\n",
2024-07-28 06:10:19 +08:00
"\n",
2024-07-30 18:21:38 +08:00
" for idx,real in enumerate(batch_iterate(batch_size, train_images)):\n",
2024-07-28 06:10:19 +08:00
" \n",
2024-07-30 00:44:16 +08:00
" # TODO Train Discriminator\n",
2024-07-30 07:06:52 +08:00
" D_loss,D_grads = D_loss_grad(gen, disc,mx.array(real), batch_size, z_dim)\n",
2024-07-28 06:10:19 +08:00
"\n",
" # Update optimizer\n",
" disc_opt.update(disc, D_grads)\n",
" \n",
" # Update gradients\n",
" mx.eval(disc.parameters(), disc_opt.state)\n",
2024-07-26 21:07:40 +08:00
"\n",
2024-07-30 00:44:16 +08:00
" # TODO Train Generator\n",
" G_loss,G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
2024-07-28 06:10:19 +08:00
" \n",
" # Update optimizer\n",
" gen_opt.update(gen, G_grads)\n",
" \n",
" # Update gradients\n",
2024-07-30 07:37:09 +08:00
" mx.eval(gen.parameters(), gen_opt.state) \n",
2024-07-30 07:44:41 +08:00
" \n",
2024-07-30 18:21:38 +08:00
" # if (cur_step + 1) % display_step == 0:\n",
" # print(f\"Step {epoch}: Generator loss: {G_loss}, discriminator loss: {D_loss}\")\n",
" # fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
" # fake = gen(fake_noise)\n",
" # show_images(fake)\n",
" # show_images(real)\n",
" # cur_step += 1\n",
" \n",
2024-07-31 16:25:39 +08:00
" if epoch%100==0:\n",
2024-07-30 18:21:38 +08:00
" print(\"Epoch: {}, iteration: {}, Discriminator Loss:{}, Generator Loss: {}\".format(epoch,idx,D_loss,G_loss))\n",
2024-07-30 07:44:41 +08:00
" fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
" fake = gen(fake_noise)\n",
2024-07-31 16:25:39 +08:00
" # show_images(fake)\n",
" # show_images(real)\n",
" plt.imshow(fake[0].reshape(28,28))\n",
" plt.show()\n",
" # show_images(fake[:25])\n",
2024-07-30 18:21:38 +08:00
" \n",
" # print('Losses D={0} G={1}'.format(D_loss,G_loss))"
2024-07-27 05:19:08 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}