mlx-examples/gan/playground.ipynb

651 lines
938 KiB
Plaintext
Raw Normal View History

2024-07-26 21:07:40 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Library"
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 538,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 539,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
2024-07-30 07:17:12 +08:00
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 540,
2024-07-30 07:17:12 +08:00
"metadata": {},
"outputs": [],
"source": [
"# mx.set_default_device(mx.gpu)"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 541,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim:int,out_dim:int):\n",
" \n",
" return nn.Sequential(\n",
" nn.Linear(in_dim,out_dim),\n",
2024-07-30 07:06:52 +08:00
" nn.BatchNorm(out_dim, 0.8),\n",
" nn.LeakyReLU(0.2)\n",
2024-07-26 21:07:40 +08:00
" )"
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 542,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"\n",
2024-07-30 18:45:09 +08:00
" def __init__(self, z_dim:int = 32, im_dim:int = 784, hidden_dim: int = 128):\n",
2024-07-26 21:07:40 +08:00
" super(Generator, self).__init__()\n",
2024-07-30 18:21:38 +08:00
"\n",
2024-07-26 21:07:40 +08:00
" self.gen = nn.Sequential(\n",
" GenBlock(z_dim, hidden_dim),\n",
" GenBlock(hidden_dim, hidden_dim * 2),\n",
" GenBlock(hidden_dim * 2, hidden_dim * 4),\n",
"\n",
2024-07-30 07:56:13 +08:00
" nn.Linear(hidden_dim * 4,im_dim),\n",
2024-07-26 21:07:40 +08:00
" )\n",
" \n",
" def __call__(self, noise):\n",
2024-07-30 07:06:52 +08:00
" x = self.gen(noise)\n",
" return mx.tanh(x)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 543,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-30 18:45:09 +08:00
" (layers.0): Linear(input_dims=100, output_dims=128, bias=True)\n",
" (layers.1): BatchNorm(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-30 18:45:09 +08:00
" (layers.0): Linear(input_dims=128, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-30 18:45:09 +08:00
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
" (layers.1): BatchNorm(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-30 18:45:09 +08:00
" (layers.3): Linear(input_dims=512, output_dims=784, bias=True)\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-07-30 18:45:09 +08:00
"execution_count": 543,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gen = Generator(100)\n",
"gen"
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 544,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# make 2D noise with shape n_samples x z_dim\n",
"def get_noise(n_samples:list[int], z_dim:int)->list[int]:\n",
" return mx.random.normal(shape=(n_samples, z_dim))"
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 545,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-30 18:45:09 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWA0lEQVR4nO3cXWzW9d3H8U+h0Ada2kJbhBZqC4MioOAm2mQgA0Q63XDLJGZZNCNLXLajbUcSZ5aYLLotWZZtzmxjHJixGocxU3E+IHNKpwPkQQQqbYFCsaWuLX2gD1Cu++ybmPug1+eX3N537rxfx9f7umof+Pg/+eZkMpmMAACQNOV/+wsAAPzfwSgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAg5Gb7wt/85jf2mw8MDNhNaWmp3UhSf3+/3eTn59tNd3e33cyePdtuPvroI7uRpI6ODrv59re/bTetra1209DQYDeStH//fru566677OYf//iH3QwODtrNsmXL7EaSRkdH7WZ4eNhu5s2bZzeffPKJ3RQVFdmNJI2MjNhNV1eX3dx0001209fXZzeSVFVVZTc5OTl209jYOOlreFIAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAIeuDeCmH6s6ePWs3ublZf0mfknKQa/369XazePFiu5k1a5bdHDlyxG4kafXq1XZz6tQpu7l+/brdNDc3240ktbS02E1lZaXd1NfX283ChQvt5oUXXrAbKe1vMOXo3K9+9Su7uf/+++0m5WckSRcuXLCblO9dU1OT3Xz3u9+1G0n6+OOP7SblbzAbPCkAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkJPJZDLZvPCJJ56w3/zSpUt2M3XqVLuRpCz/Mz5lyZIldjM2NmY3HR0ddlNRUWE3krRs2TK7aW9vt5uUg30zZsywG0n66le/ajfvvfee3Vy7ds1u9u7dazep34eVK1fazQcffGA3hYWFdpPytfX29tqNJNXU1NjN+++/n/RZrrq6uqQu5d+VTZs22c2aNWsmfQ1PCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkJvtC8vLy+03P3HihN184xvfsBsp7SJrW1ub3Zw7d85u5syZYzeVlZV2I0kDAwN2Mzw8bDdLly61m66uLrtJdeONN9pNys/23nvvtZv77rvPbiTp9ddft5tVq1bZTWdnp92kmDdvXlJXVFRkN5s3b7abf/7zn3ZTXV1tN5K0b98+u3n55ZfthiupAAALowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgJCTyWQy2bzwkUcesd+8trbWboaGhuxGkpqbm+3m6aeftputW7fazQ9/+EO7ycvLsxtJ2rt3r93cfvvtdtPU1GQ3t9xyi91IaUf++vv77eZzn/uc3ZSVldlNb2+v3UjSxMSE3YyNjdnN6Oio3eTn59vN2bNn7UaS7rzzzqTO1dLSYjd33HFH0melHPnbuXOn3TzzzDOTvoYnBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABBys33hzTffbL95yiGzAwcO2I0kLVu2zG5OnjxpNynHuA4ePGg3ixYtshtJyvK+4afs3r3bbhoaGuzm2rVrdiNJra2tdrN69Wq76evrs5urV6/azcjIiN1I0vj4eFLnOnr0qN3U19fbTervw9KlS+2moqLCbrq7u+3myJEjdiOlHbLctm1b0mdNhicFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAELI+iFdcXGy/+Te/+U27+fOf/2w3kvTGG2
98Js3MmTPtZtq0aXZz8eJFu5HSDn+dO3fObi5dumQ35eXldiNJRUVFdjN//ny7STnQlvJzKiwstBtJGh0dtZvBwUG7WbFihd0cOnTIbhYvXmw3knT48GG7OX/+vN0sWLDAblIOA0pSXl6e3aQcY8wGTwoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgZH0Q76233rLf/NFHH7WbsbExu5GkBx980G7OnDljN6+88ordrFy50m4uXLhgN5K0bt06uzl69KjdpPw3pR75GxgYsJuUo25/+tOf7KasrMxuqqqq7EZKO4CWn59vN52dnXbzox/9yG62b99uN5J0/fp1uxkaGrKblL+ln//853YjSdu2bbObN998M+mzJsOTAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgZH0ldcGCBfabP/vss3aTk5NjN5JUUFBgN3/5y1/sprGx0W5SLlU2NzfbjSRNnz7dbhYtWmQ3g4ODdlNeXm43krR161a7+elPf2o3d9xxh9309PTYTXt7u91IUm5u1n+uIeUScGFhod10dXXZzVe+8hW7kaTKykq7+eIXv2g3v/zlL+1m1apVdiOlXc7Ny8tL+qzJ8KQAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAQtYXtjKZjP3m3//+9+0m9SDea6+9ZjePPfaY3Tz66KN285Of/MRu5s+fbzeS9NRTT9nN5z//ebt57rnn7KakpMRupLQjY21tbXazePFiu5kzZ47dzJ49226ktON2a9as+Uw+56OPPrKb1tZWu5GktWvX2k1HR4fdLFy40G5GR0ftRpL+8Ic/2E3q39NkeFIAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAIeuDeIWFhfab7969224aGxvtRpK2b99uN7t27bKbdevW2c2TTz5pN1/+8pftRkr7/lVUVNhNyu9Dfn6+3UjSvn377Oaee+6xm66uLrspLi62m2vXrtmNJJ0+fdpuUg7Vbdy40W5SpBw6lNKOc+7fv99uNmzYYDfvvvuu3UjSxMSE3ZSXlyd91mR4UgAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAAAh64N477zzjv3mKcfjCgoK7EaSnnrqKbupqamxm56eHrspLS21m/fee89uJOnKlSt2U19fbzd1dXV2M2VK2v+DVFZW2k3KQcEDBw7YzYkTJ+zm1KlTdiOlfR8+q6NuKX+3Q0NDdiNJt956q92kHGN8/PHH7eaxxx6zGyntd2LWrFlJnzUZnhQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAACEnk8lksnnhX//6V/vNe3t77WZ8fNxuJGn9+vV289prr9lNyjXDuXPn2s2KFSvsRpL27NljN4sWLbKbkZERuyksLLQbSZo6dardlJSU2E2WfwqfcvHiRbtpamqyG0l6+OGH7WZ4eNhuUq6xnj9/3m5mz55tN5J055132s0zzzxjN1/4whfs5s0337QbKe3q8LJly+xmy5Ytk76GJwUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQcrN9YcpxuylT/M05dOiQ3UjSPffcYzfd3d12U1NTYzcpx8KuXr1qN5K0bt06uzlw4IDdrFmzxm76+/vtRpLa29vtJi8vz24uX75sN5cuXbKbN954w24k6ZVXXrGb0tJSuzl9+rTdLF++3G46OzvtRpJ6enrsJuX3IeUIaMq/eZL08ccf2838+fOTPmsyPCkAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAA
iMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkPVBvNtuu81+8+eee85uUo6SSdKOHTvspqWlxW5uueUWu9m8ebPdNDU
2024-07-30 00:44:16 +08:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"img = get_noise(28,28)\n",
"plt.imshow(img, cmap='gray')\n",
"plt.axis('off')\n",
"plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 546,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim:int,out_dim:int):\n",
" return nn.Sequential(\n",
" nn.Linear(in_dim,out_dim),\n",
2024-07-30 00:44:16 +08:00
" nn.LeakyReLU(negative_slope=0.2),\n",
2024-07-30 18:21:38 +08:00
" nn.Dropout(0.3),\n",
2024-07-26 21:07:40 +08:00
" )"
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 547,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"\n",
2024-07-30 18:45:09 +08:00
" def __init__(self,im_dim:int = 784, hidden_dim:int = 128):\n",
2024-07-26 21:07:40 +08:00
" super(Discriminator, self).__init__()\n",
"\n",
" self.disc = nn.Sequential(\n",
2024-07-30 07:37:09 +08:00
" DisBlock(im_dim, hidden_dim * 4),\n",
" DisBlock(hidden_dim * 4, hidden_dim * 2),\n",
2024-07-26 21:07:40 +08:00
" DisBlock(hidden_dim * 2, hidden_dim),\n",
2024-07-30 07:37:09 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" nn.Linear(hidden_dim,1),\n",
2024-07-30 18:21:38 +08:00
" nn.Sigmoid()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" \n",
2024-07-30 18:21:38 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" def __call__(self, noise):\n",
2024-07-30 18:21:38 +08:00
" return self.disc(noise)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 548,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-30 18:45:09 +08:00
" (layers.0): Linear(input_dims=784, output_dims=512, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-30 18:45:09 +08:00
" (layers.0): Linear(input_dims=512, output_dims=256, bias=True)\n",
2024-07-30 07:37:09 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-30 07:37:09 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-30 18:45:09 +08:00
" (layers.0): Linear(input_dims=256, output_dims=128, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-30 18:45:09 +08:00
" (layers.3): Linear(input_dims=128, output_dims=1, bias=True)\n",
2024-07-30 18:21:38 +08:00
" (layers.4): Sigmoid()\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-07-30 18:45:09 +08:00
"execution_count": 548,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"disc = Discriminator()\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
2024-07-30 00:44:16 +08:00
"cell_type": "markdown",
2024-07-26 21:07:40 +08:00
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### Losses"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"#### Discriminator Loss"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 549,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
2024-07-30 00:44:16 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" noise = mx.array(get_noise(num_images, z_dim))\n",
" fake_images = gen(noise)\n",
2024-07-27 06:09:51 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" fake_disc = disc(fake_images)\n",
" \n",
2024-07-27 05:19:08 +08:00
" fake_labels = mx.zeros((fake_images.shape[0],1))\n",
2024-07-30 18:21:38 +08:00
" \n",
" fake_loss = mx.mean(nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True))\n",
2024-07-30 00:44:16 +08:00
" \n",
2024-07-30 18:21:38 +08:00
" real_disc = mx.array(disc(real))\n",
2024-07-27 05:19:08 +08:00
" real_labels = mx.ones((real.shape[0],1))\n",
2024-07-30 18:21:38 +08:00
" \n",
" real_loss = mx.mean(nn.losses.binary_cross_entropy(real_disc,real_labels,with_logits=True))\n",
" \n",
" disc_loss = (fake_loss + real_loss)\n",
2024-07-26 21:07:40 +08:00
"\n",
" return disc_loss"
]
},
2024-07-30 00:44:16 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generator Loss"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 550,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def gen_loss(gen, disc, num_images, z_dim):\n",
2024-07-26 21:07:40 +08:00
"\n",
" noise = mx.array(get_noise(num_images, z_dim))\n",
2024-07-30 07:06:52 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" fake_images = gen(noise)\n",
2024-07-30 18:21:38 +08:00
" fake_disc = mx.array(disc(fake_images))\n",
2024-07-26 21:07:40 +08:00
"\n",
2024-07-27 05:19:08 +08:00
" fake_labels = mx.ones((fake_images.shape[0],1))\n",
2024-07-30 00:44:16 +08:00
" \n",
2024-07-30 18:21:38 +08:00
" gen_loss = nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True)\n",
" \n",
" return mx.mean(gen_loss)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 551,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"# Get only the training images\n",
2024-07-30 18:21:38 +08:00
"train_images,train_labels,*_ = map(np.array, mnist.mnist())"
2024-07-30 00:44:16 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 552,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# Normalize the images to fall between -1,1\n",
"train_images = train_images * 2.0 - 1.0"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 553,
2024-07-29 06:24:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-07-30 18:45:09 +08:00
"<matplotlib.image.AxesImage at 0x119f89a80>"
2024-07-29 06:24:50 +08:00
]
},
2024-07-30 18:45:09 +08:00
"execution_count": 553,
2024-07-29 06:24:50 +08:00
"metadata": {},
"output_type": "execute_result"
2024-07-30 00:44:16 +08:00
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaI0lEQVR4nO3df2jU9x3H8dfVH1d1lytBk7vUmGVF202dpWrVYP3R1cxApf4oWMtGZEPa+YOJ/cGsDNNBjdgpRdI6V0amW239Y9a6KdUMTXRkijpdRYtYjDOdCcFM72LUSMxnf4hHz1j1e975vkueD/iCufu+vY/ffuvTby75xueccwIAwMBD1gsAAHRfRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjpab2AW3V0dOjcuXMKBALy+XzWywEAeOScU0tLi/Ly8vTQQ3e+1km7CJ07d075+fnWywAA3Kf6+noNHDjwjvuk3afjAoGA9RIAAElwL3+fpyxCH3zwgQoLC/Xwww9r5MiR2rdv3z3N8Sk4AOga7uXv85REaPPmzVq8eLGWLVumI0eO6JlnnlFJSYnOnj2bipcDAGQoXyruoj1mzBg99dRTWrduXeyx73//+5o+fbrKy8vvOBuNRhUMBpO9JADAAxaJRJSVlXXHfZJ+JXTt2jUdPnxYxcXFcY8XFxertra20/5tbW2KRqNxGwCge0h6hM6fP6/r168rNzc37vHc3Fw1NjZ22r+8vFzBYDC28ZVxANB9pOwLE259Q8o5d9s3qZYuXapIJBLb6uvrU7UkAECaSfr3CfXv3189evTodNXT1NTU6epIkvx+v/x+f7KXAQDIAEm/Eurdu7dGjhypqqqquMerqqpUVFSU7JcDAGSwlNwxYcmSJfrpT3+qUaNGady4cfr973+vs2fP6tVXX03FywEAMlRKIjR79mw1NzfrN7/5jRoaGjRs2DDt2LFDBQUFqXg5AECGSsn3Cd0Pvk8IALoGk+8TAgDgXhEhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmelovAEgnPXr08DwTDAZTsJLkWLhwYUJzffv29Tzz+OOPe55ZsGCB55nf/va3nmfmzJnjeUaSrl696nlm5cqVnmfefvttzzNdBVdCAAAzRAgAYCbpESorK5PP54vbQqFQsl8GANAFpOQ9oaFDh+rvf/977ONEPs8OAOj6UhKhnj17cvUDALirlLwndOrUKeXl5amwsFAvvfSSTp8+/a37trW1KRqNxm0AgO4h6REaM2aMNm7cqJ07d+rDDz9UY2OjioqK1NzcfNv9y8vLFQwGY1t+fn6ylwQASFNJj1BJSYlmzZql4cOH67nnntP27dslSRs2bLjt/kuXLlUkEolt9fX1yV4SACBNpfybVfv166fhw4fr1KlTt33e7/fL7/enehkAgDSU8u8Tamtr05dffqlwOJzqlwIAZJikR+j1119XTU2N6urqdODAAb344ouKRqMqLS1N9ksBADJc0j8d9/XXX2vOnDk6f/68BgwYoLFjx2r//v0qKChI9ksBADJc0iP0ySefJPu3RJoaNGiQ55nevXt7nikqKvI8M378eM8zkvTII494npk1a1ZCr9XVfP31155n1q5d63lmxowZnmdaWlo8z0jSv//9b88zNTU1Cb1Wd8W94wAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQ
DMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAMz7nnLNexDdFo1EFg0HrZXQrTz75ZEJzu3fv9jzDf9vM0NHR4XnmZz/7meeZS5cueZ5JRENDQ0JzFy5c8Dxz8uTJhF6rK4pEIsrKyrrjPlwJAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwExP6wXA3tmzZxOaa25u9jzDXbRvOHDggOeZixcvep6ZPHmy5xlJunbtmueZP/3pTwm9Fro3roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT63//+l9DcG2+84Xnm+eef9zxz5MgRzzNr1671PJOoo0ePep6ZMmWK55nW1lbPM0OHDvU8I0m//OUvE5oDvOJKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAw43POOetFfFM0GlUwGLReBlIkKyvL80xLS4vnmfXr13uekaSf//znnmd+8pOfeJ75+OOPPc8AmSYSidz1/3muhAAAZogQAMCM5wjt3btX06ZNU15ennw+n7Zu3Rr3vHNOZWVlysvLU58+fTRp0iQdP348WesFAHQhniPU2tqqESNGqKKi4rbPr1q1SmvWrFFFRYUOHjyoUCikKVOmJPR5fQBA1+b5J6uWlJSopKTkts855/Tee+9p2bJlmjlzpiRpw4YNys3N1aZNm/TKK6/c32oBAF1KUt8TqqurU2Njo4qLi2OP+f1+TZw4UbW1tbedaWtrUzQajdsAAN1DUiPU2NgoScrNzY17PDc3N/bcrcrLyxUMBmNbfn5+MpcEAEhjKfnqOJ/PF/exc67TYzctXbpUkUgkttXX16diSQCANOT5PaE7CYVCkm5cEYXD4djjTU1Nna6ObvL7/fL7/clcBgAgQyT1SqiwsFChUEhVVVWxx65du6aamhoVFRUl86UAAF2A5yuhS5cu6auvvop9XFdXp6NHjyo7O1uDBg3S4sWLtWLFCg0ePFiDBw/WihUr1LdvX7388stJXTgAIPN5jtChQ4c0efLk2MdLliyRJJWWluqPf/yj3nzzTV25ckXz58/XhQsXNGbMGO3atUuBQCB5qwYAdAncwBRd0rvvvpvQ3M1/VHlRU1Pjeea5557zPNPR0eF5BrDEDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNrqkfv36JTT317/+1fPMxIkTPc+UlJR4ntm1a5fnGcASd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzAFvuGxxx7zPPOvf/3L88zFixc9z+zZs8fzzKFDhzzPSNL777/veSbN/ipBGuAGpgCAtEaEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpsB9mjFjhueZyspKzzOBQMDzTKLeeustzzMbN270PNPQ0OB5BpmDG5gCANIaEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gCBoYNG+Z5Zs2aNZ5nfvSjH3meSdT69es9z7zzzjueZ/773/96noENbmAKAEhrRAgAYMZzhPbu3atp06YpLy9PPp9PW7dujXt+7ty58vl8cdvYsWOTtV4AQBfiOUKtra0aMWKEKioqvnWfqVOnqqGhIbbt2LHjvhYJAOiaenodKCkpUUlJyR338fv9CoVCCS8KANA9pOQ9oerqauXk5GjIkCGaN2+empqavnXftrY2RaPRuA0A0D0kPUIlJSX66KOPtHv3bq1evVoHDx7Us88+q7
a2ttvuX15ermAwGNvy8/OTvSQAQJry/Om4u5k9e3bs18OGDdOoUaNUUFCg7du3a+bMmZ32X7p0qZYsWRL7OBqNEiI
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
2024-07-29 06:24:50 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"plt.imshow(train_images[0].reshape(28,28),cmap='gray')"
2024-07-29 06:24:50 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 554,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"def batch_iterate(batch_size: int, ipt: list[int])-> list[int]:\n",
2024-07-29 06:24:50 +08:00
" perm = np.random.permutation(len(ipt))\n",
" for s in range(0, len(ipt), batch_size):\n",
2024-07-26 21:07:40 +08:00
" ids = perm[s : s + batch_size]\n",
2024-07-30 00:44:16 +08:00
" yield ipt[ids]"
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 555,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"def show_images(imgs:list[int],num_imgs:int = 25):\n",
" if (imgs.shape[0] > 0): \n",
" fig,axes = plt.subplots(5, 5, figsize=(5, 5))\n",
" \n",
" for i, ax in enumerate(axes.flat):\n",
" img = mx.array(imgs[i]).reshape(28,28)\n",
" ax.imshow(img,cmap='gray')\n",
" ax.axis('off')\n",
" plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### show first batch of images"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 556,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-30 18:45:09 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADeTUlEQVR4nOy953Oc55XmfXXOOSegkUECIEVCFGlJFmVLJef12LszrpqasF/2P9oPU1tTWzM1oXZqdj2e8ThKtigqUBQjCBJERqNzzjm+H/ieowYJZoBoAM+vikXZbADdN+7nPvdJ1xH1er0eBAQEBAQE9hDxQb8BAQEBAYGjh2BcBAQEBAT2HMG4CAgICAjsOYJxERAQEBDYcwTjIiAgICCw5wjGRUBAQEBgzxGMi4CAgIDAniMYFwEBAQGBPUcwLgICAgICe470WV8oEon2830cKl5E1EBYv68R1u/leFFRDWENv0bYgy/Hs6yf4LkICAgICOw5gnEREBAQENhznjksJiAgIPAqEIvFUKvVkMlkGBoagtfrRbfbRbPZRKPRwMbGBtLpNLrdLjqdzkG/XYHHIBgXAQGBgUImk8Fut0Ov1+NnP/sZfvzjH6PZbCKXyyGTyeBv//Zv8cUXX6DZbArGZYARjIuAgMBAIBKJIJFIoFQqYbPZYDab4XK54HK50Gw2oVQqIZfLYbFYYDQaUSqV0Gg0XrjA4SggEomgUqkgl8shk8mgUCjQ6XRQqVTQbrfRbDbRbrcP5L0JxkVAQOCJiMViiMVidLtddLvdffs5Go0GJpMJbrcbf/EXf4Hx8XGMjo5CrVZDqVRCpVJBp9Phrbfegl6vx+LiIr788ssDOzwHAblcjjfeeANjY2Pw+/2Ym5tDoVDA73//e0QiEWxubiIYDB7IexOMi4CAwGMRiURsXADsq3GRyWTQ6XSwWq2Ym5vD7Ows514AQKFQQCQSwefzoVarIR6P8/s6rkilUng8Hpw8eRKzs7O4ePEiUqkUAoEAut0uUqnUwb23A/vJAgICAw0ZFq1WC41Gg2q1ikKhsG8GRqlUwmq1wmq1QqPRQKFQQCp9cETV63VkMhnk83ncvHkTt2/fxtbW1rHLuZChNxqNGB8fh8lkwoULF3Dq1CmYzWaUSiWk02msr6/j/v37SKfTB/ZeBeMiICCwK2RcDAYDLBYLstksSqXSvhkXtVoNu90Om80GvV4PjUbD/1ar1RAOh5FIJPDll1/i888/P5YJfYlEAplMBofDgYsXL8LtduPdd9/FiRMnUCqVkM1mkUgkcP/+fdy5c+dA81GCcREQENiVXq+HXq+HZrOJarWKZrO5L4eVUqmEVCqF1WqF3++H1+uFQqEAADQaDTQaDSSTSaytrSEejyObzR47wyIWiyESiWCz2WC32zE8PIyhoSHY7XYAQKFQQDKZRCAQwNbWFsrl8r6GMJ8FwbgICAjsSq/XQ7vdRjqdRj6fR7vd3vMDXSKRwG63w2Qy4fz58/hv/+2/wWQywWKxoNfrIZlMIhqN4u7du/iHf/gHJJNJxGIxNBqNPX0fg4xIJIJcLodcLsfbb7+NDz74AHa7HadOnYJCoUAkEsGNGzdw9+5dXLp0CdlsFvF4/KDftmBc+hGJRKwfRInCTqdzqEod+5Ov9Hno1kN/A1/fStvtNrrdLnq93rG6CQo8O61WC61Wa0+/J+1NqVQKnU4Hs9kMm80Gt9sNrVYLhUKBXq+HSqWCdDqNeDyOYDCIVCp17MqPxWIxV8vZ7XaMjo7CbDbDYrFAIpFga2uLDUogEECxWES9Xj/oty0Yl350Oh00Gg1UKhVMJhM6nQ4CgQDy+fxBv7UnQv0BEokETqcTRqMRKpWKH1K73c7xbIfDgWaziUKhgHq9jnv37iEcDiOfzyMWi6HdbqPVah24Sy1wNKHLj0ql4tzKBx98gNnZWYyMjMBoNEImk6Hb7aLVamFhYQG/+93vEIlEkM/n0Wg0jt0lSK
fT4f3334fP58Obb76J8fFxiEQi5PN51Ot1bGxsYGNjA1tbW4jH46jVant+GXgRBOPy/0PNSCaTCXq9HkNDQ2i1WhwSGHTEYjFkMhlsNhu8Xi+XdOp0OoyPj8NoNGJqagoTExOo1WpIJBIoFApsTEOhEHK5HBqNxrHuGxDYX8RiMSQSCd/CbTYbzp8/j7fffhsqlYqT+M1mE81mE+vr67h8+TIqlQpKpdKxMyzAg0KHs2fPYm5uDpOTk/B4PKhUKohEIigUCojFYgiFQojFYsjlcgNhWIBjYlzI/RaLxdBoNJDL5dDr9TCbzZBIJJDL5ZBKpTCZTDAYDGi1WpzAHOQbvEKhgFKphE6nw+joKHQ6HSYnJ+F2u6FUKmEwGKBQKGCz2aBSqSCRSFAsFtHtdqFSqSASiXDixAloNBoMDw/D6XSiUChgZWWFb4mHMbYtk8lgNBqhUCjgdDphMpkgl8uh0Wh2lU3vdDpsVHO5HMrl8jP9nHa7jWQyiWKxuOu/U+im1+vxPur/+ccttCMSiWA0GmE0GuF0OvGNb3yDL0NKpRISiQSdTge1Wg2rq6vIZDIIBAIol8vHLhQGAGazGT6fD263Gz6fDw6HAwqFAvV6HfF4HJ9++inS6TSWlpYQiUSQSqUG6rw6FsaFJCUUCgW8Xi9MJhMmJydx+vRpqNVqmEwmPqgVCgVWVlbw61//GtVqdaBv8VqtFjabDSMjI/jpT38Kp9OJ8fFxOJ1Ofph7vR6HuajxjEISFosFVqsVnU4HmUwG4XAY4XAY//iP/4jNzU2k0+lDaVzUajUmJiZgNpvx7rvv4vTp0zAajfB4PJBKpTvyTsDXPRS1Wg2Li4sIBAJP/RmdTgf1eh2XL1/G2trarq+hvFan00Gr1UKv19vxe6GD4KgfmnS5k0gkcLvdmJiYwIkTJ/BXf/VXsNvtUCqVkMlkHJJNpVL49a9/jfX1dSwsLCCTyXBe8Djh9/vxox/9CC6XC/Pz83C5XGi32ygWi7h//z7+1//6XwiFQqjVaizzMkie3ZEzLhQeEovFkEqlkEqlUCqVfHv1er0wGo0wm83sxYhEIr4xNRoNlEolVCoVVKvVgfplPYxGo4HVaoXdbuc/lGdpNBool8v8udrtNkqlEsrlMlQqFdrtNuRyOdRqNf9ttVrRaDRgNpuRyWRQKpUO+iO+EHK5HA6HA06nE263Gw6HAwaDATabDRKJ5BHvodFoQCKRoFarwel0PvVCQTIo9Xodw8PDj309FUs0m01UKhV0u13IZDJIJBJW9CUj039wtlotlMvlI3GYkmHR6/VQqVRwOp3wer1wuVwcgiY6nQ7K5TIKhQJSqRTi8TiKxeJAP4P7ATWPGgwGOJ1OWK1WSCQStNtt5PN5DoVls1nk8/mBzZEeGeNC1VBKpRI+nw9arRZOpxM2mw0Oh4O9FHq4c7kckskkKpUKgsEgqtUqMpkMh0Xi8TiazeYzh0heNSKRCNPT0/je974Hp9OJubk56HQ6FAoF7s799NNPUalUUC6XuVehWq1Co9HA5XJBp9Ph9ddfx8jICMxmMzweD+RyOc6ePQuz2Yxut4tYLHboDjmHw4E///M/x8TEBCwWC3Q6HV84gEc9BZlMxgedWq3G7OzsU38GGYWLFy+iUqns+u/VahW1Wg3ZbBarq6totVowGo1Qq9VoNBqo1Wockus/QAOBAOcZDjN0wTMajXj33Xfh8Xhw5swZnD59GjqdDjqdbsfrk8kk7t27h2AwiOvXr2NjY+PQr8HzIpVKMTIyApfLhQsXLuAb3/gGFAoF97F89dVXuHbtGuLxONLp9MAaFuAIGRdKFCoUCpjNZphMJvh8Pvh8Pni9Xpw/fx4ajYbzKKurqwgGgygUCggEAsjlcgiFQojH4ztCFoOKWCyGxWLB+Pg4bDYbrFYrFAoFUqkUx6q//PJLlMtlFItFNJtN1Go11Ot1qNVquN1uGAwG9uDUajU0Gg3a7TbsdjsajQa0Wu1Bf8
wXQq1WY3JyErOzs2xIHg6F9SMSibhpT61WP/X7P5y3eZzxLZVKKJVKSCaTkMlkaDabsFqt0Gq1qNVqqFQq6HQ6j3j
2024-07-26 21:07:40 +08:00
"text/plain": [
2024-07-29 06:24:50 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-26 21:07:40 +08:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
2024-07-30 00:44:16 +08:00
"X = batch_iterate(25, train_images)\n",
"for x in X: \n",
" show_images(x)\n",
2024-07-26 21:07:40 +08:00
" break"
]
},
2024-07-30 07:06:52 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training Cycle"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 557,
2024-07-27 06:09:51 +08:00
"metadata": {},
2024-07-29 06:24:50 +08:00
"outputs": [],
2024-07-27 06:09:51 +08:00
"source": [
2024-07-30 18:21:38 +08:00
"lr = 2e-6\n",
"z_dim = 128\n",
2024-07-30 07:06:52 +08:00
"\n",
2024-07-27 06:09:51 +08:00
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
2024-07-30 07:44:41 +08:00
"gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999]) #,betas=[0.5, 0.9]\n",
2024-07-27 06:09:51 +08:00
"\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
2024-07-30 07:44:41 +08:00
"disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])"
2024-07-29 06:30:08 +08:00
]
},
2024-07-27 06:09:51 +08:00
{
"cell_type": "code",
2024-07-30 18:45:09 +08:00
"execution_count": 558,
2024-07-27 06:09:51 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-30 18:45:09 +08:00
" 0%| | 0/200 [00:00<?, ?it/s]"
2024-07-30 18:24:53 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-07-30 18:45:09 +08:00
"Epoch: 0, iteration: 468, Discriminator Loss:array(1.31356, dtype=float32), Generator Loss: array(0.465569, dtype=float32)\n"
2024-07-30 18:24:53 +08:00
]
},
{
"data": {
2024-07-30 18:45:09 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9x4/kWXYdjp/w3nuXGZlZaSrLdVeb4YxmRFIURVFcC5QWMhsBWmgrB+hvELTVRhBACgK0ESEBIjgYkc1xzTbVXV0mK72NDO+9+cQnvovkufkip0lmRbXwA37IBzS6ujpNfN7nvXvPPffcew2z2WyGu3W37tbdult36ztcxv9ff4C7dbfu1t26W///t+6cy926W3frbt2t73zdOZe7dbfu1t26W9/5unMud+tu3a27dbe+83XnXO7W3bpbd+tufefrzrncrbt1t+7W3frO151zuVt3627drbv1na8753K37tbdult36ztfd87lbt2tu3W37tZ3vsy3/cLf+Z3fgcFggNPphMvlQiQSwfb2NsbjMT755BOcnJzA5/PB6/XCbrcjFovBZrPBbrfDarWi2Wwil8vBbDZjbW0Nfr8fFosFVqsVnU4HJycnGA6HsNlssFgs6PV6aDQaAACn0wmz+eqj6roOv9+P9fV12Gw2TKdTzGYzzGYz6LqO8XiMarWK0Wgkn13TNAyHQ/lZFosFo9EI/X4fRqMRdrsdZrNZPutoNEKr1YKmaZhMJtA0DZ1OB/V6HQDwp3/6p2+90f/hP/wH6LqOZrOJTqcDl8uFcDgMs9ksn79SqaBYLMLj8WB7exsul0t+v9lshtVqhdVqRTAYhN1uR7fbRbvdRqfTwdnZGTRNQyaTQTgcRqPRwOXlJex2O548eQK/34+f/exn+PnPfw6/348HDx7Abrej3W5jOBzC4/EgEAjIfnEvR6MRNE1Dv9/HbDaDyWSCwWBAr9dDu93GaDRCo9GA0WjEv/yX/xK///u/j3w+j88//xzD4RBerxc2mw0vX77EF198AU3T8D//5/986/375//8nwMAxuMxptMpbDYb3G43vF4vfuM3fgPpdBp/9Ed/hP/23/4bACAYDMLpdOLDDz/ExsYGyuUyDg4OYDAYkEwm4XK55PlMJhPsdjtMJhMcDgesViuq1SrOzs6g6zosFguMRiMMBgMAwOFwIBKJwGQyodVqod/vy1mfTCYoFArodru4uLhAoVCAx+NBNpuF2+3GxsYGotEoOp0Oms0m+v0+crkcRqMRstkskskkJpMJer0eRqMRarUaer2efL2u6/jiiy/eev8A4N/9u38Ho9GIbDaLRCKBs7MzfPHFF5hMJshkMvB4PKhUKigUCjAYDLDZbDCbzchkMohEIjAajTCZTJjNZhgMBvIe7HY7er0ezs/PMRwO5TxbrVa5u7zvpVIJxWIRkUgET58+hdPpxHA4lPvYbDYxHA5RLpcxGo2QTCYRi8VQr9exu7uLyWSCaDQKj8cDq9UKu92O2WwGTdMAAOl0GvF4HOPxGP1+H9PpFEbjFYYul8vI5XLQdR3/8T/+x7fev3/9r/81TCYTUqkUYrEY9vb28Cd/8icYDoeIRCJwuVzymex2OyKRCCwWCwBgNpuh0Wggn89jNpvB6/XCYrHAbrfDZrOhXq/j+fPnGI/H+Oijj3Dv3j10Oh1Uq1W43W786Ec/QiwWw49//GP85Cc/gclkgsfjgcViQTAYhMvlgsFggMlkgqZpaLVaGI/H6Ha76Pf7YssAwGw2w2AwzH3WeDwOp9OJR48eYX19HcfHx/jxj3+M8XiMe/fuIRgMIp/P4/z8HNPpFP/9v//3v3G/bu1ceMHi8Tji8TgMBgPK5TI0TUMymYTf75fLNxwOsbe3B13XsbKygkQiAbvdjkwmg8lkgnw+j7OzM3g8Hvh8PsxmM7jdbrjdblgsFphMJlgsFkynU3FoJpNJDonD4YCmaTAYDGD3GtXJuN1uOB
wO+ez9fh/9fh+6rsvmcqNNJhMCgQBsNht0XZdL4fV65wxLv99HNBq97Xb9ylJ/tsvlQiAQQDqdhtVqhclkks/V6/XgcDig6zqm0ymi0Sj8fr8cFE3TUCgUMJlMYDAYYDAYoOs6gsEgdF3HYDDAxcUFLBYL0uk0LBaLHDaHw4FsNgufz4dUKiVOttPpYDabwWazAQAMBgOm0ymGwyEmkwmm06l8frPZDJPJhFAohFQqhclkglKpBE3TUCqV8JOf/AT1eh2np6eYTCawWCwCDD788MOF9y+RSGA2m8llmU6nqNfr6PV6eP78OXK5HC4vL2G1WmVfptMparWafBa/3w9d1zEajcSxDIdDOW8mkwk+nw9OpxO6rsPj8cBkMiEcDsNut8tnMRgMMBqN0HVdziWdMcGP2+2G0WiEy+USxzSdTuF2uxEKhdDtdlEqlTCZTGCz2eT8tVotTCYTcdyFQgHtdhterxfZbFbu2CIrEokAADqdjrzz+/fvYzabzd0lGmq+P4/Hg+XlZXS7XVSrVei6DrPZDIvFgsFggHq9jslkIn9PkMb3YLVa4ff7xYE7nU44nU5xooPBQPbO6XTCZrPBaDRiOp3CZDKh0+kIELRYLLJf/X4fxWIRFosF4XAYNpsNvV4PhUJBAAMANBoNDIdDWCwW3Lt3b+E91DQN0+kUvV4P9Xod0+kUiUQCk8kEbrcbVqsVPp8PwWAQ4/EY5XIZ0+kUsVgMwWAQk8kETqcT4/EYxWIRo9EIgUAAwWAQg8FAgG+r1cLJyQnMZjNcLhfsdjtqtRrG4zHsdrsAw0gkAoPBgFwuh1KpBAByRwOBgABSo9EITdMwGAxgNBqRSCTg9XrR7/fR7XZhNBrlvF1cXGA2m6HT6SCTyUDTNBiNRnS7XVgsFrmHt1m3di5msxlmsxmhUAiZTAatVktQQDgchsPhQK/XQ6/XQ6VSwfn5OXq9Hnw+HyKRiGz8YDDA6ekpyuUywuEwotGoIEEaWl5GRh/cdB5WGszZbCYOR0UvDodj7rLoug5d18UhqSiUCMDpdKLf78shNJvNYhwYzRCVLbJU4+xwOOD1ehGNRmG328WB9Xo95PN5WK1W+czhcBjLy8toNpvI5/Po9XqoVqtotVpwOp3wer0wGo3ipIvFIprNJmKxGJaWlmAymcSQ2mw2JJNJeDweRKNRGI1GXF5eiqMlyqITBq4v1HQ6lb8zGo0IBAJysXjR6/U6Pv/8cwwGAzQaDXkfALC+vo7t7W15X2+7wuHw3Ptm1AZcvcN8Po9SqQSr1QoA8v6bzSaAqzPk8/kAXBlX1bkQ8dFxcg+cTiesVqtcRoKXyWSCbreLyWQiZ4nOxWg0imNhVExgMJvN4HA44Pf7cX5+jlqtBgDw+/2wWq2YzWZyBhlRVioVNJtNeDweJBKJhfcPAAKBAGazGQqFAqrVKsLhMNbX12E0GlEqldDr9QBc3aXpdCpRs8vlQiKRQLFYRKFQECdpNpvRbDZRqVQAQCIbl8sFl8slaNlqtcLj8chdcrlcACD/n86FjgmARCTtdhvdblf2lg7GarWiXq+jWCzCbrcjGAzCZDJhMBig3+/D6XTC4XCIYWw2m0gmk0in0/KO33YR7NLR8X5Op1M5O8FgEMlkEo1GA4eHh+j1eggGg3A4HBgOh7Db7ZhOp6hWq2g2m9A0DVarVe6R2WxGp9PBcDhEMBiEz+eD1WpFq9XCYDCA1WrF2toa3G43UqkUdF1HqVRCvV4XG+dyueSe67qOyWSC8XgMTdNgsVgkQCiXy/L3g8EAw+EQxWIR0+kUFosF0WgUs9lMfrfZbEY4HL71ft3auTAy6fV6ODs7w2w2g91uFyNNw22xWOB0OhEIBASpjMfjuVDfZDJJpOJyuQTNmc1mlEolCel6vZ4gHZPJhEqlgmq1Cr/fj62tLfn5RqNRNknXdQyHQwnZLRaLGCV+rdFohMfjkdDUZrPBYDBgOByi1WrBZrPB5/PNoVK73Q632/
32J/IvFyMpRgIMRy0WC7rdrtBPpKZGo5EcQpvNhmaziUKhIKjbYDCIw5/NZuj1etB1HQaDQagdOlL+7mazidFoBKf
2024-07-30 18:24:53 +08:00
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-30 18:45:09 +08:00
" 25%|██▌ | 50/200 [03:59<11:45, 4.70s/it]"
2024-07-30 18:24:53 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-07-30 18:45:09 +08:00
"Epoch: 50, iteration: 468, Discriminator Loss:array(1.02203, dtype=float32), Generator Loss: array(0.689288, dtype=float32)\n"
2024-07-30 18:24:53 +08:00
]
},
{
"data": {
2024-07-30 18:45:09 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz92XPjaXbeiT/Y940ASHBfc6t9U6m75e5oyW7LDkd4fOvbufaV/5D5F+Zi5mIiRqGZibA1lq2RWm6ru6urq7oqK7NyYZLJfQOIHSB2zAXnc/iCXZIymeX4RfyCb0RGViVB4Iv3Pe9ZnvOcczzj8Xis23W7btftul2363tc3v9fP8Dtul2363bdrv//W7fG5Xbdrtt1u27X975ujcvtul2363bdru993RqX23W7btftul3f+7o1Lrfrdt2u23W7vvd1a1xu1+26Xbfrdn3v69a43K7bdbtu1+363tetcbldt+t23a7b9b2vW+Nyu27X7bpdt+t7X/5XfeEf/uEfajQa6eLiQt1uV9PT07p3756Gw6F+85vf6OTkRH/4h3+on/zkJxoOh6pUKup0OioWi6pWq0qlUpqdnVW/39fz589VqVT06aef6oc//KFOTk70V3/1Vzo/P5fP55PH49Hs7Kzu3r2rdrutzz77TKenpwqFQgoGg5Kk8XisQCCgd999V0tLS3rx4oW++OIL9Xo99ft9SdLS0pLm5+cVDAYVi8UUCASUSCQUDod1cXGhWq2maDSqlZUVRaNR9ft9DQYDlUolPXv2TK1WSwcHB6rVavL7/QoEApKkzc3N197ocDgsSRoOhxqNRpqbm9MHH3wgj8ejly9fqlarKRKJKBqNKhgMKp1Oy+v16uDgQKVSSYFAQKFQSH6/X7FYTH6/X9VqVZVKRcFgUKlUSj6fT/V6XZ1OR/1+X91uV5Lk9Xrl8Xi0uLio+fl51Wo1vXjxQsPhUMvLy8rn8zo+PtaLFy/k8/k0MzOjcDisarWqer2uQCCgeDwuSapWq7q4uNB4PNZ4PJbH41EoFJLP51MkElEoFFK/39fFxYUk2TNfXFyo1WppPB6r1+u99v798R//sTwej5LJpO1RPB6Xz+ezz9/f39fOzo5isZjW1tYUiUTUbrfV6XQUDAYViUTk8/kUjUbl9/u1s7Ojra0tpVIp3b17197P6/Vqc3NTv/rVrxSLxfSzn/3M9q1Wq6lYLOrhw4dqtVrq9XoaDodaXV3Ve++9p3g8rvn5eYXDYT169EhPnjxRt9tVs9nUYDBQs9m0/Wk2mwoGg5qenlY4HNZoNNJ4PFYmk9HGxobtncfjUaVSUalU0mg00l//9V+/9v5J0r/9t/9Ww+FQW1tb2t/fVyqV0vz8vMbjsQ4ODtRoNLS+vq779+8rGAwqFAppNBrpiy++0IsXLzQzM6MHDx7I4/Ho7OxMrVZLqVRKyWRSnU5HR0dHGgwGymazSiQS9rndbld7e3tqNptKJpNKJpPyeDwmOx988IGWlpb05Zdf6j/9p/+kbrer0WgkSZqenlYul1MoFFIikVAsFtMPf/hDrays6Ouvv9YvfvELhUIh3b17V8lkUoFAQD6fT6enp3r48KE6nY683ksfejweazAYSJJ++9vfvvb+ZbNZeTwe5XI5pdNpDYdDDQYDjcdju9eBQMD0hCT5fD6trKyoUCjo5OREz54902AwMD03PT2tfD6v0Whkuuvo6EjValXBYFDRaFSS1Ol0NBqN5PV65fP5NB6PbY8Gg4FGo5FyuZwWFhYkSRcXF/ZeJycnCoVCymQy8ng8KpfLarfbCgQCCgaD8nq9ikaj8nq96vf76vf78nq98vv9Go1GajQa6na79npJ2tvb+0f365WNSyaTsQ3odDpqtVqqVqsKhUL68MMP5fV61e129ctf/lLBYFD5fF4+n08+n0+BQECtVkvPnj2Tx+NROBzWwsKCTk9P9R/+w39Qr9fTYDBQOBxWs9lUp9PR3N
ycFhYW1Gw25fF4dHFxoV6vZ8oEIzEYDFSv11Wv19VoNOT1ejU3N6dIJKJIJGIH1u12NR6PtbW1pV6vp0QioXQ6rVarpb29PY1GI01PTyubzarX6ymdTiscDqtWq9l3bjQaumm3HJ/PN/H3aDRSqVSSz+czoWy326pWq5qamtKdO3cUjUZVKpXU7Xbl8XgUiUQkSY1GQ6PRyBRnr9dTt9uV1+u15xuNRhoOhxqPxybIZ2dnajabGo1G8vv98vv9KpfLqlarkqRCoWDvy14mEgl5PB4TZp/PZ4onHA6bEHo8HqXTacXjcbXbbZVKJfX7ffV6PfV6PfuON10ohYuLC7tU5+fnCoVC2tjYUCaTUbPZVK1Wk8fjUbFYnNgPv9+vRCKh4XCo7e1tU/bxeFzD4VBPnz6Vz+dTLpdTIpFQvV43Ze/z+cxAnp2dqV6vy+PxmCIZj8f288FgoEajIY/Ho+FwqJWVFTvX0WikSCSiQCCgWq2mcrmscDislZUVRSIR7e7u6ujoSMlkUsvLywqHw2q32xoMBgqFQkqn0zeWP+lSyQ+HQ3W7XfV6PTWbTZ2entr+BgIBVatVPX36VH6/X9FoVB6PR4PBQFNTUwoEAjo/P7e9TKfTqtVqevnypXw+nxmUXq+nYrGo6elpra6uqtPpqFwum4LE2HOHW62WTk5O1Gg0TEYCgYDJ1mAwkNfrNeP8t3/7t/pv/+2/qVqtqlQqyev1qlarKRgMamlpSbOzs2q1WqbwUciNRkPVavXGe4giTqfTmp6eVrfbVb1el9fr1dTUlOmjbrdrMuPxeDQajXRycqJqtarBYGB/JNmd9Hg8kmTOl8/nU7fbNZ3DXcZR5B6y79FoVMPhUDs7OwoGg1pYWFAkEjFHtdfrqd1uazwea2lpSeFw2Bzm4XCoTqej4XCoer2uwWCgdDqt1dVVjUYjPXr0SKenpxqNRhqNRvas/9h6ZePCxSwWi+r1euZ5SdIHH3yg6elp/epXv9Jnn32mZDKpWCymRCIhr9drl+n4+FiBQEB3795VOp3W5uamNjc3FYlEtLCwYN4ylxOPTrq03Hg7eHfBYFD9fl+tVsv+hMNh5XI5ZbNZtdttU0bD4VC9Xk/7+/uqVCpaWlrS/fv31W639e2336rZbOqdd94xIUomkwqHw4pGo2q1Wmq326rVajcWTLwnr9drSg/B5AJ0u11TmOl0WplMxiKBQCBg373dbqvf76vT6ZjR5NARqOFwqOFwaJ/ver+hUEjZbFZer9cikWw2q/n5efX7fe3t7anVaimdTiuRSGgwGKjf75twBwIBxWIxi644l3w+r0wmo1qtpk6nM/HHjfxusriMeLW9Xk+tVkuRSER3795VIpGwPyiz4XBo0S4XsNPp6Pj4WAcHB5qentbMzIza7bZ2dnY0GAy0tram6enp31NOfr9f3W7XlKTH4zEDzZl2u111Oh1TIkSKrVbLnIpCoaBkMqlyuayTkxNFIhE9ePBAkUhE3W5XZ2dnisViKhQKCofD5mXG43GToZuufr+v4XBoEXq73Z74fijgWq02EeFxHySZEs9kMkokEqpUKjo6OlI8HtfU1JRCoZBOTk5Ur9c1MzOj+fl5XVxc6MmTJybrvV5Pfr9f4XBYoVBIFxcXGg6Harfb5ghFo1GLIl3j0uv19OjRIxWLRUWjUcXjcY1GIzWbTYukkQHkHxnl/G56h9mjeDyubDarZrOpbrcrv9+v2dlZZTIZNRoN1et1i+RxdCqViil35Hk8HqvT6ahSqZixYI88Ho96vZ7K5fKEQudvUBhQmampKZ2dnen4+FixWEwbGxvKZrMKhUJKpVKq1Wra3d3VeDzW7Oys8vm8yS/OAKhUs9lUIpHQxsaGRqORXr58aUgFf15pv151Y4vFom2GJFPiqVTKPFg3UuFSc9GlyxDX5/MZXNbpdCzUQtAWFhY0MzOjeDyunZ0dNRoNu8yxWMwECugmnU4rEoloampKhULB3guBkq4Uez
gc1vz8vPL5vHlMeGXxeFyJREKBQEC9Xk/1et0MV7/f19TUlObm5l51u35vIegYAjwJQlzpMqqJxWLyeDza39/X+fm
2024-07-30 18:24:53 +08:00
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-30 18:45:09 +08:00
" 50%|█████ | 100/200 [07:56<07:49, 4.70s/it]"
2024-07-30 18:24:53 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-07-30 18:45:09 +08:00
"Epoch: 100, iteration: 468, Discriminator Loss:array(1.01599, dtype=float32), Generator Loss: array(0.691249, dtype=float32)\n"
2024-07-30 18:24:53 +08:00
]
},
{
"data": {
2024-07-30 18:45:09 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9aY9l53UdvO48z0PdmoeunpvdZJMUKcoxTSkUHScBAgOJDRjINyM/IX8kCPIpQYB8sWNBiWXJckSZpERRHJo9sMfqmodbt+48z9P7obJ27Xu6mqy+1cYLvG9toMFm9a1zz3nO8+xh7bX3Ng2HwyHO5EzO5EzO5Exeopj/376BMzmTMzmTM/n/npwZlzM5kzM5kzN56XJmXM7kTM7kTM7kpcuZcTmTMzmTMzmTly5nxuVMzuRMzuRMXrqcGZczOZMzOZMzeelyZlzO5EzO5EzO5KXLmXE5kzM5kzM5k5cuZ8blTM7kTM7kTF66WE/6QafTCZPJBK/XC7fbDZ/Ph3g8DofDAa/XC5vNhlQqhb29PZjNZrjdbgBAtVpFo9GAxWKB1WqFxWKB0+mE1WpFrVZDrVZDv99Hu93GcDiEz+eDy+VCr9dDp9PBcDhEr9fDYDBAp9NBt9tFOBzGtWvX4HQ6kcvlUKvVYDabYTKZ0O12kcvl0Gq10O120ev1YDabYbVaYbVaEQqF4Ha7Ua/XUalU4HQ6MT8/D5fLhVQqhWw2CzYtGAwGch3+vslkQrVafeGFfu211zAYDJBOp1EoFGCxWOBwOOByuXD16lXEYjHkcjkcHByg1+uh1Wqh3+/DZDIBALrdLprNJkwmE3w+H+x2O3q9nqxTuVxGv9+H0+mEzWZDJBLBwsICOp0Onjx5glKpBK/XK+8vkUjA4/HgwoULmJqawq1bt/D3f//36Ha7cDqdMJvNaLVaaLVaMJlMMJlMMJvN8Pv9cLlccLvdCAaDaLVa2NvbQ7PZhM1mg9Vqhc1mg8PhkPffbrfl5wCQyWReeP3+/b//9zCbzVhaWsL09LQ892AwAAAMh0NYrVb5XgAwmUwIh8Pw+/2yB/r9PsrlMtrtNrLZLDKZDCwWCzweDywWCwaDAQaDAXK5HLa2tuB2u/H9738fsVgMv//97/H73/8ekUgE169fh91ux8HBAUqlEpaWlvDaa6/BarViMBhgOBzK/rPb7QgEAjCbzej3+xgMBlhZWcHt27dht9uxtLQEr9cLs9kMs9mMTCaDhw8fAgAuXbqEaDQqzwsA//E//scXXj8AsjY8T3qdbDYbLBaL7Cktz2viYTKZ4HA44HQ60ev10Gg0MBgMYLVaYTabMRwO5XfNZrNci+vDP9zj+nt4nvnZ4777uJ9zD/b7fXS7Xbm+3sMAZC1fRP7Lf/kv6HQ6+NnPfobf/va3cv8WiwXBYBBOpxNerxd+vx+NRgN7e3tyjvnc/X4fNpsNExMTcLlcsNvtsNlsqNfr2NnZQafTQTweh9/vRzQalTP88OFDFItFuN1ueDwetNttlMtlmEwmLC4uIhqNYnt7G48ePYLFYsHU1BRcLhey2SwKhYI8g9VqxdTUlOxHq9WKZrOJ7e1tNJtNLC8vY25uDtlsFisrKxgOh1heXkY0GkUul0M2m8VgMMD+/v53rteJjQuNxXA4lAVrt9tyIACg1Wqh0WgAOFQqFosF586dw+TkJNLpNFZWVtDr9dDv92G320VRdjoddDod9Pt9ubbP58Pk5CS63S52d3flularFa1WC2tra7BarbLBut0uOp0O7HY7lpeX4XA4sLOzg729PTidTkQiEdnwrVZLjA6vbzabUa/XRYFzc7vdbpjNZgSDQUSj0ZMu1zNSKBQwHA7RbrdhMplgt9sRDAZhsViwu7uLZDKJZrOJer0um9ZsNmNubg6xWAzpdBpra2sADg+Q0+lErVaTdaFyonE+d+4c3nzzTdTrdRQKBbTbbfj9fgSDQfR6Pezt7QEASqUS/H
4/UqkUhsMhbDYbAoEA7Ha7/B43odlsxmAwkJ9VKhV0Oh202210u105wBaLBRaLBSaTCR6PB06nE8FgEKFQCBaLZaz149rzPXU6HdRqNVitVpw7dw6hUAiNRgP1eh0A5HsKhQL29/fh8/kQiUTQ6/WwurqKcrmMRCKBV199FfV6HVtbW2g2m3A4HLBYLLDb7ZicnITVakW1WkWv18NwOEQikYDb7Zb9Ew6HEQqFMBgM8NVXX4mz4vV6YTKZxOA5nU5YLBYxtoPBQNbVZrPBZrOhUqmgXq+j3+9jcXFRlH6z2US73Uaz2Xyuoj+JaKPB/U2hUaUyN5vN8g4p/X4f/X5/5JraMeT75zW0UucfGjauDa9hNHbcb91ud+Q79bWosI33o42aPsvRaBSJRGLs9fvpT3+Kfr+P3d1dOJ1OOJ1OBAIBDAYDVCoVFItFNBoNVCoV2O12XLhwAVarFYVCQc5qsVjEcDhEv9+X98F19/v96Pf7cLlc4gifP38eABAOh1Gr1ZBOp5FOp+FwODA1NQUAqNfrqFarKBaLoltCoRB8Ph9qtZrsNRoym80Gu92OWq02coYHgwGy2Sza7TZ6vZ44Zd1uF4VCAV6vFxMTEyN74tvkxMaFXk+320W73Uar1UK5XBYPbTAYSHTCaMNutyMej+PmzZt48OAB7t+/j3a7LR6i0+mEy+WSlz8cDtHpdNDr9RAKhTA5OYl2u41kMolOpyNeSafTwe7uLkwmE/x+P5xOJ1qtFqrVKoLBIGZmZhCLxdBoNJBMJmWxrVarbABu8F6vh1KpJC+am5GenMPhgMPhQCKRwMLCghjSF5VKpQLgyGOy2Wzwer0YDodIJpPy7zw4VqtVPN7FxUX0+32srKzIOjudTlSrVbRaLdhsNvj9fnlOr9eLxcVFvPHGGyiVSvj888+RzWbh8Xjg9/tRqVSQyWTEe2cUNBwOYbfb4fV64XQ60Wg0RtaDxpmRFd89PXQaFSoWi8UCl8sFAIjFYpidnR17/QKBgNxLs9lEs9lEqVSC2+1GJBLB3NwcMpmMGDm73Y7BYIBUKoWDgwNEo1G43W60223s7u4inU4jkUjgwoULyGQyePz4sRhaRtaRSAQA0Gw2xYhzH1E5MBLe2dnBo0eP4PF4EA6HxSmxWCyw2WxwuVwwmUyoVCpyXhjJc6+1Wi3k83l4vV5MTU2JIW2322g0GqjVamOtHYWKmA4h9xrwbDQDYCQCoWhFTyWp/53nmAZKf4+OULg/aIy08N8Y6fE7+fu8rtHQ8Rl5PX6e3x0MBrG4uChG7UXlo48+AnDoaDscDvh8PkxNTaHX66FcLqNSqcgZiMVimJubQyAQwM7ODjKZDMxmM0qlkqzbYDCQdR8MBvB4PAAgDo7f78fs7CxsNhvi8Tg6nQ6+/vprpFIpcZhNJhM2NzeRy+XQ7XblvQUCAdnLNC40Kna7XfRoJpMZMcjFYhGlUgkejwfxeBwWi0V0fjgcxvLy8okdxBOvMjcQb5AKhj/jZ3hoCN2USiU8fvwYyWRSHt7n88Hr9cqD2mw2NBoN8YB7vR6azSYymcyI0uNGpefOTcPw1+FwwGQy4eDgAPV6HaVSSeC0crkMi8UiFtrj8cDr9aLb7eLg4ECUpX5WKk8AKBaL8pzjCDe6x+MRb5VKhi+fn6NHa7FYUCwWsb29Ld9PA1yr1WCz2RCLxWSzdjodtFotWCwWpFIp3Lp1SzwpRhyNRgO9Xg8ejwcOh2MEfjCZTOj3+6jVagKJ8Z60cdGwoclkgsvlknuj0mV0w88XCoURWOJFxe/3AwDa7bZEf4FAAE6nE/V6HdlsViBWq9Uq+8Pv9wtkVy6XMRgMkEgkEAwG4fF4UCgUUC6XxbNj9EzIktfiIcvn87BarfJvvV4P9XpdnCl6tA6HQ95ts9lEKpUSD5fReTAYFKiVkDAdGqfTCeBQmVNh2e32sfefFp5hvqvnRUMavgKOjIKGsowOpf6sjix0tAFA1ltHOvoPla
/R8FC+K4LThoXXrFarSKVSY+9BnhWXyyUQM3UCYXO/349AIACfzyfQVaVSQblclmhmOByi2Wyi1+vB5/MhEAig3W4
2024-07-30 18:24:53 +08:00
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-30 18:45:09 +08:00
" 75%|███████▌ | 150/200 [11:57<03:58, 4.78s/it]"
2024-07-30 18:24:53 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-07-30 18:45:09 +08:00
"Epoch: 150, iteration: 468, Discriminator Loss:array(1.02523, dtype=float32), Generator Loss: array(0.686512, dtype=float32)\n"
2024-07-30 18:24:53 +08:00
]
},
{
"data": {
2024-07-30 18:45:09 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9Z5Ok53Uejl+dc87Tk8POJuwCC4AgSEIMMiXZClXWC1WpymW/8Gv7vausT+Ev4FSWq1S2LEu2fpRAkQRJECDSYnOYPNM9nXOO/xfj68zpxoKc7VmSgP5zqqZmd2a6+3nu575PuM51zjGMx+MxLuRCLuRCLuRCXqAYf9MXcCEXciEXciH/+OTCuFzIhVzIhVzIC5cL43IhF3IhF3IhL1wujMuFXMiFXMiFvHC5MC4XciEXciEX8sLlwrhcyIVcyIVcyAuXC+NyIRdyIRdyIS9cLozLhVzIhVzIhbxwuTAuF3IhF3IhF/LCxXzWP/wP/+E/wGQy4fr161hdXZ343Wg0wng8xuHhIQ4ODgAANpsNFosFq6urSCaT6HQ6qNVq6PV6KBQKaLfbKBQKyOVycLvdWF5ehsViQSqVQqlUwtzcHDY3NzEajXB0dIRGo4Fer4d+v49IJIJXX30VdrsdmUwGtVoNBoMBRuOJrRyPx/J/k8mEarWK3d1dGAwG3Lp1CwsLC9je3sbdu3dhsViwtLQEu92Ora0t7O/vw+/3Y3l5GVarFWxgwPcCgG9+85vPvdB2ux0AMBgMMBwOYTAYYDAYZP0AwGQywWQywe/349q1a7Db7Xjw4AH29/cnPp/XZLPZYLPZMBwO0Wq1MBqNYLVaYTab0e/30e12YTAYYLfbYTKZ0O120ev15JoMBgNsNhvMZjMGgwF6vR5MJhN8Ph8sFgtarRba7TYMBoN8drfbxXA4hM1mg8PhwGg0Qrvdxmg0gtPphN1ux2AwkJ9xb0x/Pa/8xV/8BbrdLv7n//yf+MEPfgCPx4NYLAaXy4WFhQX4fD643W54vV5kMhl873vfQ6VSwdWrV7G4uIhUKoU7d+7AZrPhG9/4BhKJBA4ODrC9vY1+v49Go4HhcIjhcIjxeIxgMIiFhQWMRiPkcjl0Oh2srq5idXUVHo8HCwsLMBqN2N/fR6FQQK/XQ6fTgdFohMPhgNFoxKNHj7C1tQWbzQa/3w+73Y5EIgGv14uDgwM8efIEoVAIf/iHf4hoNIrvf//7eP/999Hr9dBqtWC1WnH58mVEo1HUajWUy2WMRiM8fPjwudcPAP7tv/23shdtNhv29vbwk5/8BIPBAPF4HG63G7lcDplMRs6QyWTC0tKS/D4cDsPhcGBhYQEulws/+MEP8Pbbb8NoNMLr9cJoNKJSqaDVasmeMhqNcLvdsFqtSCQSiMfjSCaTeOutt2A0GvFf/st/wbvvvotXX30Vv/d7v4d6vY6//uu/xuHhIbrdLrrdLkwmk+xtl8sl+sVqtSISieA73/kOgsEgvve97+EnP/kJhsMhBoMBAMjrPB4P/H4/AOCDDz547vXT+uX/3+Usa3Bm4xKPx2EwGNBut7G3t4fBYIBOpwOz2YzFxUX4fD54vV4EAgEAJ4rSYDAglUrh6OgITqcTgUAAg8EAtVoN9XodFosFi4uLMBgMaDabMBgMCAaDiMfj8Hg8cLlc6HQ6aDabKJfLCIfDmJ+fBwA8ePAABoMBHo8HTqcT9XodpVIJJpMJkUgEdrtdNpXJZJKNZrPZJpToeDxGpVIRRbuysgK73Q6n0wmj0SjKkMp61o3Fz6chcTqdiEQiGI/HKBQKaLVaGI/HGA6HaDQa2NragslkQr1eFyM3Go3kwPPaer3ehALv9/sTSpJ/TwNCo8aDMv234/EYzWYTRqMRg8FArpefwf/TgBiNRrhcLphMJozHYwwGAzGg+rqA8x3KH//4xxiNRrBYLHjttddQr9eRyWSQz+dRr9fhdDphNpthsVjQ7/fF0PX7fa
TTaXS7XSQSCVgsFozHY7RaLbRaLfT7fVkzg8GAzc1NxGIxWad2uy3G5+DgANVqFVarFffv34fZbEY0GoXP50OlUkGxWAQA9Pt9GI1GWCwW2YuBQAAOhwPz8/MIBAJoNBoYj8dot9vY3t5GoVBAvV6Xtez3+zCbzfK1vr6OaDQqDsks0mg05N8GgwG9Xg+DwQDdbheFQgG1Wk2cidFoJM5Jp9NBt9uF0+mEzWYDADx+/Bj9fh/ZbBZOpxODwQClUglmsxm3bt3C6uoqnjx5gnfffRfdbhetVgudTgdWqxUGgwHVahXpdBrD4RA7OzswGo04ODjA//2//xeDwUD2vdFonPhus9lw5coVxGIxpNNpbG9vo9fr4e2334bdbsf+/j4GgwGMRiOcTicAiKGx2WyIRCKy959XLozK88mZjUs4HBbFU61WJRJxOByIx+Ow2WxwOp3weDzymtFohP39fWSzWfF8RqMRms0marUaIpEIYrEYOp0OcrkcxuMxotEoYrGYKAoqsXq9jmQyiWQyiVKphLt376Lf7+Pq1avw+/2oVCooFAqw2WyIRqPinVmtVlgsFlGcfE+TyQSXy4Ver4darYbhcIhgMCibz2w2yz2MRiPxTGfdYMPhcOL/drsdsVgM4/EYjUYD7XZblDcVn8FgkOunEQBOI5zhcIh+vz+hwLXHBmDCuBiNRlGajERoCIxGoxw63if/9lnRBq/HarXC4XDAarWi3W6j0+lMGKwXdSA/+eQTGI1GJJNJXL16FTs7O3j48KHsQ4vFgtFohOFwCJfLhcXFRdhsNrRaLRQKBVitVoRCIVgsFgAQZdfv98VA09m5du0a6vU68vm8rFG73Ua73UYmkxEFb7fb8dZbbyEej6Ner6PT6cg1GAwGmM1mBINB2O12BINBOBwOxGIxhEIhHB4eylofHR2hVCqh2WzCbrdjPB6LU8TvCwsLeP3112dWjADEcFgsFoluB4MB+v0+qtXqRDQ9Ho/R7XYBAL1eD71eD+PxGFarFcPhEAcHBygWi+h0OrDb7Wi1WvIcNjc38d3vfhdOpxMffvghut2u7Cmr1QqTyYROp4N8Po/BYACn0wmr1YpsNou9vT2YzWb4/X5xDsfjsRhrp9OJtbU1rK+vYzgc4v79+6hUKkilUuJ8cX8TLaATYbFY4Pf7Ze9fyK9WzmxcbDabeKJms1kUN5XfaDQShU3FNBqN4HK54Ha7Bb7hwSNMU6lUZJOPx2PUajUAJwfX5XKh3W7DbrfD7/fD4XDIpgmHwxgOh3A6nTCZTHA4HAgGg7DZbBL1DIdDdDodNBoNHB8fYzQaIRqNipFzuVziSVFRaoXIgwicwkGzKksqdm3wKpWKGAgqcS00Ns+KROiB00tzOBwATpQVFQGFa6uvn8+B/zeZTLDZbBiPx2IgtLLhd94Hr4se7nA4FE+Yh1zfx3k8bgByf5VKBY1GA6VSSeA+i8UiBtdiscDtdiORSMDtdst6MAKmg9Tr9eB2u7GysoJms4lcLofRaIROp4NSqST7we12Y2FhAV6vVwzMaDRCv9+HyWQSCKlcLosxYhTlcDjg9Xolqut0Otjf38fx8TG63S6uXLki60+HgJEK74cGJpfL4YMPPoDRaMQf/MEfzLSG9XodwImRMZvNKJVKE/tBPyeTyYRwOCxOCaMeno96vS77kvuA+2l/fx8ffPABDg8PYbPZxEmiUu/3+wAAj8cjOmU4HMJkMsHtdosRIMrAyJKQ7tHRkThR165dQ7PZxN7eHhqNhjiDhMG4htyb6XT6XAb6Qs4uhrN2Rb579y4ATMAdfIihUAgOh0MwVuBECY1GIxweHiKfz8PhcIhSz2QyaDab4jUCpzAav3s8HoRCIQnfx+MxIpEIIpGI5BjG4zGcTicsFouE3larFXNzc7Db7cjn8ygWi8jlcvj0008xGAywubkpkQ09m36/L0pS47tGoxF2ux1msxn5fF7C+H/xL/7Fcy+03W6HwWBAOByG3+
9Ho9FANpsVxc/1fNbj0N6kyWRCIBCAzWZDs9lEo9EQj9hsNiObzaJarU68noeaBuFZ4nK54Pf7MRwOUS6X0ev1BD6
2024-07-30 18:24:53 +08:00
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-30 18:45:09 +08:00
"100%|██████████| 200/200 [15:59<00:00, 4.80s/it]\n"
2024-07-29 00:18:35 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"source": [
"# Training hyperparameters\n",
"n_epochs = 200\n",
"batch_size = 128\n",
"display_every = 50  # log losses and show generated samples every N epochs (and after the last one)\n",
"\n",
"# nn.value_and_grad wraps a loss function so it returns (loss, grads w.r.t. the given model's trainable parameters)\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"for epoch in tqdm(range(n_epochs)):\n",
"\n",
"    for idx, real in enumerate(batch_iterate(batch_size, train_images)):\n",
"        # Train the discriminator. The last batch of an epoch can hold fewer than\n",
"        # batch_size images, so generate exactly as many fakes as there are reals.\n",
"        D_loss, D_grads = D_loss_grad(gen, disc, mx.array(real), len(real), z_dim)\n",
"\n",
"        # Apply the gradients to the discriminator's parameters\n",
"        disc_opt.update(disc, D_grads)\n",
"\n",
"        # MLX is lazy: force evaluation of the updated parameters and optimizer state\n",
"        mx.eval(disc.parameters(), disc_opt.state)\n",
"\n",
"        # Train the generator\n",
"        G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
"\n",
"        # Apply the gradients to the generator's parameters\n",
"        gen_opt.update(gen, G_grads)\n",
"\n",
"        # Force evaluation of the updated parameters and optimizer state\n",
"        mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
"    # Periodically (and after the final epoch) report the last batch's losses and show samples\n",
"    if epoch % display_every == 0 or epoch == n_epochs - 1:\n",
"        print(\"Epoch: {}, iteration: {}, Discriminator Loss: {:.4f}, Generator Loss: {:.4f}\".format(epoch, idx, D_loss.item(), G_loss.item()))\n",
"        fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
"        fake = gen(fake_noise)\n",
"        show_images(fake)"
]
2024-07-26 21:07:40 +08:00
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}