mlx-examples/gan/playground.ipynb

675 lines
1.1 MiB
Plaintext
Raw Normal View History

2024-07-26 21:07:40 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Library"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 517,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 518,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
2024-07-30 07:17:12 +08:00
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 519,
2024-07-30 07:17:12 +08:00
"metadata": {},
"outputs": [],
"source": [
"# mx.set_default_device(mx.gpu)"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 520,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim:int,out_dim:int):\n",
" \n",
" return nn.Sequential(\n",
" nn.Linear(in_dim,out_dim),\n",
2024-07-30 07:06:52 +08:00
" nn.BatchNorm(out_dim, 0.8),\n",
" nn.LeakyReLU(0.2)\n",
2024-07-26 21:07:40 +08:00
" )"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 521,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"\n",
2024-07-30 18:21:38 +08:00
" def __init__(self, z_dim:int = 10, im_dim:int = 784, hidden_dim: int = 64):\n",
2024-07-26 21:07:40 +08:00
" super(Generator, self).__init__()\n",
2024-07-30 18:21:38 +08:00
"\n",
2024-07-26 21:07:40 +08:00
" self.gen = nn.Sequential(\n",
" GenBlock(z_dim, hidden_dim),\n",
" GenBlock(hidden_dim, hidden_dim * 2),\n",
" GenBlock(hidden_dim * 2, hidden_dim * 4),\n",
"\n",
2024-07-30 07:56:13 +08:00
" nn.Linear(hidden_dim * 4,im_dim),\n",
2024-07-26 21:07:40 +08:00
" )\n",
" \n",
" def __call__(self, noise):\n",
2024-07-30 07:06:52 +08:00
" x = self.gen(noise)\n",
" return mx.tanh(x)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 522,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-30 18:21:38 +08:00
" (layers.0): Linear(input_dims=100, output_dims=64, bias=True)\n",
" (layers.1): BatchNorm(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-30 18:21:38 +08:00
" (layers.0): Linear(input_dims=64, output_dims=128, bias=True)\n",
" (layers.1): BatchNorm(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-30 18:21:38 +08:00
" (layers.0): Linear(input_dims=128, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-30 18:21:38 +08:00
" (layers.3): Linear(input_dims=256, output_dims=784, bias=True)\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-07-30 18:21:38 +08:00
"execution_count": 522,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gen = Generator(100)\n",
"gen"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 523,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# make 2D noise with shape n_samples x z_dim\n",
"def get_noise(n_samples:list[int], z_dim:int)->list[int]:\n",
" return mx.random.normal(shape=(n_samples, z_dim))"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 524,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-30 18:21:38 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWOElEQVR4nO3ce2zV9f3H8VddsaXcxv2mrGVcSikyAUHaiaaAMiZzEDux27Lhls0lI8HpnMuyq0u2bMsuDHfLInO6i7JBcNqJdhDRtdOWKYJcC2JhtCW2lYLY0nLO7793suyPntcnmb9ffnk+/j7Pc5Ce8vL7zzsvm81mBQCApMv+t/8AAID/OxgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAAhPxcX/jwww/bb97V1WU3DzzwgN1I0vz58+2moqLCbk6fPm03L7/8st1MnTrVbiRp/PjxdlNSUmI3tbW1djNkyBC7kaRBgwbZTX5+zl/tcPz4cbsZPXq03RQWFtqNJHV0dNjN9OnT7SYvL89uWlpa7CblZyRJ8+bNs5t//OMfdrNs2TK72bJli91I0vLly+1m1KhRdvPJT35ywNfwpAAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAABCzhepCgoK7DdPOeA1ePBgu5GkTCZjN48++qjdnDhxwm6+9KUv2c3+/fvtRpJ6e3vtZt++fXaTzWbt5vbbb7cbKe0YY8p3L+UwYMoBwr6+PruRpBtvvNFuvv/979vNpz71KbtJOX75rne9y26ktAOJM2fOtJuUQ5ZLly61G0l6/fXX7Sbldz0XPCkAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkPNBvKqqKvvNL1y4YDfXXHON3UhpR8a+/e1v282uXbvspqmpyW4mTpxoN1LasbA//elPdlNTU2M3mzdvthtJGjJkiN0sXLjQblKOkqV8xysqKuxGkh5//HG7STkw2dbWZjcPPfSQ3XzrW9+yG0m655577GbZsmV2M3nyZLtJOcQopR1WbGlpSfqsgfCkAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIOV9J3bp1q/3m9fX1dpN6HTTlKubJkyft5s9//vM78jmf+9zn7EaSent77Wb58uV2U1BQYDerVq2yG0lqbGy0m87OTrsZNmyY3ZSUlNhNe3u73UjSmTNn7OaLX/yi3Tz99NN2s2HDBrs5ffq03UhSeXm53aRcPE25Zvvggw/ajSStX7/ebrq7u5M+ayA8KQAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAICQl81ms7m88Ac/+IH95q+++qrdLFmyxG4k6ciRI3aTciQrxaOPPmo31dXVSZ81atQou3n++eftpr+/325SDs5J0syZM+3m+PHjdtPV1WU358+ft5tUdXV1drN58+b/wp/kP/X09NhNyqFIKe07/sILL9jN6NGj7ebFF1+0G0kqLCy0m+uvv95uvvCFLwz4Gp4UAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQMj5IN4vf/lL+83Lysrs5t5777UbSbp48aLdVFRU2M2UKVPs5vLLL7eb5uZmu5HSjrrNmTPHbm644Qa7ueeee+xGkqZNm2Y3VVVVdrNr1y67mTBhgt2MGDHCbiTpwIEDdtPZ2Wk3p06dspuxY8fazQc/+EG7kdK+rym/g08//bTdpB5IrKystJuHHnrIbh5//PEBX8OTAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAj5ub7w7Nmz9punNCmH1qS0Q3V9fX12k3KoLu
UYV3Fxsd1IUlFRkd0sWLDAbu6//367KS8vtxtJysvLs5umpia7STngWFdXZzcrV660G0nq6Oiwm5tuuslu3nrrLbtJOYj32muv2Y0ktbW12c0jjzxiN3fccYfdtLS02I0kZTIZu0k5FJkLnhQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAyPkg3uLFi+03X7dund2sWrXKbqS0g1x79uyxm3379tlNTU2N3fT09NiNlHbE65vf/KbdtLe3282QIUPsRpIqKyvtJuUIYW1trd1cdpn//1UpRxUlac2aNXaT8nN66aWX7CblmOA///lPu5GkRYsW2U1FRYXdvPHGG3YzdOhQu5HSvkezZ89O+qyB8KQAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAg5X0ltamqy33zZsmV2s3btWruRpJ/97Gd2M3PmTLtZuXKl3Tz22GN2k81m7UaSjh49ajfjxo2zm5RrrF/5ylfsRkq7BtnY2Gg3KZeABw8ebDc7d+60G0k6e/as3RQVFdnNe97zHrvp6Oiwm5tuusluJKm1tdVu+vr67Gb06NF28+qrr9qNJJ0/f95u+vv7kz5rIDwpAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgJDzQbze3l77zZcuXWo3DzzwgN1I0rp16+xm4sSJdrNx40a7GTlypN3cdtttdiNJ+/fvt5vJkyfbzdatW+0m5bCdJF133XVJnSuTydjN7t277Sb1ENy0adOSOte2bdvs5tOf/rTdHDx40G4kacSIEXZTWlpqN7/73e/spqqqym4kafjw4XazZcuWpM8aCE8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIORls9lsLi+888477TefN2+e3fztb3+zG0nq6uqym5SDeIMHD7ab5557zm6qq6vtRpIWL15sN4cPH7abxsZGu5k0aZLdSNK5c+fekebqq6+2m7y8PLtpaWmxG0lqa2uzm5KSErtJOVSXclTx9OnTdiNJ48ePt5umpia7KS8vt5tLly7ZjSQVFBTYTUdHh93kcsiSJwUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQ8nN9Ycqhp4aGBruZPXu23UjSlClT7ObXv/613VRVVdlNUVGR3Rw5csRuJCk/P+cfaUg5rDV27Fi7GTNmjN1IaQfazp49azeLFi2ym/7+frtpbW21G0nK8Xblvxk+fLjdpBx9TDlu96Mf/chuJGnnzp12U1FRYTfHjh2zm0wmYzeS9MYbb9jNihUrkj5rIDwpAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAABCXjbH04tPPPGE/ebbt2+3m5Trm5J05syZd+SzUi59plxw7erqshsp7b9p0KBBdvOTn/zEbsrKyuxGkk6ePGk3w4YNs5vJkyfbzYIFC+zmwIEDdiNJLS0tdnPttdfaTXd3t9288MILdjNy5Ei7kaSenh67KS0ttZslS5bYTV1dnd1I0o4dO+zma1/7mt0sX758wNfwpAAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAABCzgfxqqqq7DdPOQT397//3W4kqbm52W7uvfdeu7n11lvt5qWXXrKbV155xW4kaerUqXaTyWTekc/56U9/ajeSNGvWLLuZMWOG3Tz55JN2s27dOrupr6+3G0kaP3683bz++ut2c+LECbu5++677Wbjxo12I0nV1dV2c/HiRbvZtGmT3dx22212I6X9vvf399vNtm3bBnwNTwoAgMAoAAACowAACIwCAC
AwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAg5Of6wgULFthvPmrUKLvZu3ev3UjSr371K7s5cOCA3aQ
2024-07-30 00:44:16 +08:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"img = get_noise(28,28)\n",
"plt.imshow(img, cmap='gray')\n",
"plt.axis('off')\n",
"plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 525,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim:int,out_dim:int):\n",
" return nn.Sequential(\n",
" nn.Linear(in_dim,out_dim),\n",
2024-07-30 00:44:16 +08:00
" nn.LeakyReLU(negative_slope=0.2),\n",
2024-07-30 18:21:38 +08:00
" nn.Dropout(0.3),\n",
2024-07-26 21:07:40 +08:00
" )"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 526,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"\n",
2024-07-30 18:21:38 +08:00
" def __init__(self,im_dim:int = 784, hidden_dim:int = 64):\n",
2024-07-26 21:07:40 +08:00
" super(Discriminator, self).__init__()\n",
"\n",
" self.disc = nn.Sequential(\n",
2024-07-30 07:37:09 +08:00
" DisBlock(im_dim, hidden_dim * 4),\n",
" DisBlock(hidden_dim * 4, hidden_dim * 2),\n",
2024-07-26 21:07:40 +08:00
" DisBlock(hidden_dim * 2, hidden_dim),\n",
2024-07-30 07:37:09 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" nn.Linear(hidden_dim,1),\n",
2024-07-30 18:21:38 +08:00
" nn.Sigmoid()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" \n",
2024-07-30 18:21:38 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" def __call__(self, noise):\n",
2024-07-30 18:21:38 +08:00
" return self.disc(noise)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 527,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-30 18:21:38 +08:00
" (layers.0): Linear(input_dims=784, output_dims=256, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-30 18:21:38 +08:00
" (layers.0): Linear(input_dims=256, output_dims=128, bias=True)\n",
2024-07-30 07:37:09 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-30 07:37:09 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-30 18:21:38 +08:00
" (layers.0): Linear(input_dims=128, output_dims=64, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-30 18:21:38 +08:00
" (layers.3): Linear(input_dims=64, output_dims=1, bias=True)\n",
" (layers.4): Sigmoid()\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-07-30 18:21:38 +08:00
"execution_count": 527,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"disc = Discriminator()\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
2024-07-30 00:44:16 +08:00
"cell_type": "markdown",
2024-07-26 21:07:40 +08:00
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### Losses"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"#### Discriminator Loss"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 528,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
2024-07-30 00:44:16 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" noise = mx.array(get_noise(num_images, z_dim))\n",
" fake_images = gen(noise)\n",
2024-07-27 06:09:51 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" fake_disc = disc(fake_images)\n",
" \n",
2024-07-27 05:19:08 +08:00
" fake_labels = mx.zeros((fake_images.shape[0],1))\n",
2024-07-30 18:21:38 +08:00
" \n",
" fake_loss = mx.mean(nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True))\n",
2024-07-30 00:44:16 +08:00
" \n",
2024-07-30 18:21:38 +08:00
" real_disc = mx.array(disc(real))\n",
2024-07-27 05:19:08 +08:00
" real_labels = mx.ones((real.shape[0],1))\n",
2024-07-30 18:21:38 +08:00
" \n",
" real_loss = mx.mean(nn.losses.binary_cross_entropy(real_disc,real_labels,with_logits=True))\n",
" \n",
" disc_loss = (fake_loss + real_loss)\n",
2024-07-26 21:07:40 +08:00
"\n",
" return disc_loss"
]
},
2024-07-30 00:44:16 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generator Loss"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 529,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def gen_loss(gen, disc, num_images, z_dim):\n",
2024-07-26 21:07:40 +08:00
"\n",
" noise = mx.array(get_noise(num_images, z_dim))\n",
2024-07-30 07:06:52 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" fake_images = gen(noise)\n",
2024-07-30 18:21:38 +08:00
" fake_disc = mx.array(disc(fake_images))\n",
2024-07-26 21:07:40 +08:00
"\n",
2024-07-27 05:19:08 +08:00
" fake_labels = mx.ones((fake_images.shape[0],1))\n",
2024-07-30 00:44:16 +08:00
" \n",
2024-07-30 18:21:38 +08:00
" gen_loss = nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True)\n",
" \n",
" return mx.mean(gen_loss)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 530,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"# Get only the training images\n",
2024-07-30 18:21:38 +08:00
"train_images,train_labels,*_ = map(np.array, mnist.mnist())"
2024-07-30 00:44:16 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 531,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# Normalize the images to fall between -1,1\n",
"train_images = train_images * 2.0 - 1.0"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 532,
2024-07-29 06:24:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-07-30 18:21:38 +08:00
"<matplotlib.image.AxesImage at 0x15ad80190>"
2024-07-29 06:24:50 +08:00
]
},
2024-07-30 18:21:38 +08:00
"execution_count": 532,
2024-07-29 06:24:50 +08:00
"metadata": {},
"output_type": "execute_result"
2024-07-30 00:44:16 +08:00
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaI0lEQVR4nO3df2jU9x3H8dfVH1d1lytBk7vUmGVF202dpWrVYP3R1cxApf4oWMtGZEPa+YOJ/cGsDNNBjdgpRdI6V0amW239Y9a6KdUMTXRkijpdRYtYjDOdCcFM72LUSMxnf4hHz1j1e975vkueD/iCufu+vY/ffuvTby75xueccwIAwMBD1gsAAHRfRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjpab2AW3V0dOjcuXMKBALy+XzWywEAeOScU0tLi/Ly8vTQQ3e+1km7CJ07d075+fnWywAA3Kf6+noNHDjwjvuk3afjAoGA9RIAAElwL3+fpyxCH3zwgQoLC/Xwww9r5MiR2rdv3z3N8Sk4AOga7uXv85REaPPmzVq8eLGWLVumI0eO6JlnnlFJSYnOnj2bipcDAGQoXyruoj1mzBg99dRTWrduXeyx73//+5o+fbrKy8vvOBuNRhUMBpO9JADAAxaJRJSVlXXHfZJ+JXTt2jUdPnxYxcXFcY8XFxertra20/5tbW2KRqNxGwCge0h6hM6fP6/r168rNzc37vHc3Fw1NjZ22r+8vFzBYDC28ZVxANB9pOwLE259Q8o5d9s3qZYuXapIJBLb6uvrU7UkAECaSfr3CfXv3189evTodNXT1NTU6epIkvx+v/x+f7KXAQDIAEm/Eurdu7dGjhypqqqquMerqqpUVFSU7JcDAGSwlNwxYcmSJfrpT3+qUaNGady4cfr973+vs2fP6tVXX03FywEAMlRKIjR79mw1NzfrN7/5jRoaGjRs2DDt2LFDBQUFqXg5AECGSsn3Cd0Pvk8IALoGk+8TAgDgXhEhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmelovAEgnPXr08DwTDAZTsJLkWLhwYUJzffv29Tzz+OOPe55ZsGCB55nf/va3nmfmzJnjeUaSrl696nlm5cqVnmfefvttzzNdBVdCAAAzRAgAYCbpESorK5PP54vbQqFQsl8GANAFpOQ9oaFDh+rvf/977ONEPs8OAOj6UhKhnj17cvUDALirlLwndOrUKeXl5amwsFAvvfSSTp8+/a37trW1KRqNxm0AgO4h6REaM2aMNm7cqJ07d+rDDz9UY2OjioqK1NzcfNv9y8vLFQwGY1t+fn6ylwQASFNJj1BJSYlmzZql4cOH67nnntP27dslSRs2bLjt/kuXLlUkEolt9fX1yV4SACBNpfybVfv166fhw4fr1KlTt33e7/fL7/enehkAgDSU8u8Tamtr05dffqlwOJzqlwIAZJikR+j1119XTU2N6urqdODAAb344ouKRqMqLS1N9ksBADJc0j8d9/XXX2vOnDk6f/68BgwYoLFjx2r//v0qKChI9ksBADJc0iP0ySefJPu3RJoaNGiQ55nevXt7nikqKvI8M378eM8zkvTII494npk1a1ZCr9XVfP31155n1q5d63lmxowZnmdaWlo8z0jSv//9b88zNTU1Cb1Wd8W94wAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQ
DMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAMz7nnLNexDdFo1EFg0HrZXQrTz75ZEJzu3fv9jzDf9vM0NHR4XnmZz/7meeZS5cueZ5JRENDQ0JzFy5c8Dxz8uTJhF6rK4pEIsrKyrrjPlwJAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwExP6wXA3tmzZxOaa25u9jzDXbRvOHDggOeZixcvep6ZPHmy5xlJunbtmueZP/3pTwm9Fro3roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT63//+l9DcG2+84Xnm+eef9zxz5MgRzzNr1671PJOoo0ePep6ZMmWK55nW1lbPM0OHDvU8I0m//OUvE5oDvOJKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAw43POOetFfFM0GlUwGLReBlIkKyvL80xLS4vnmfXr13uekaSf//znnmd+8pOfeJ75+OOPPc8AmSYSidz1/3muhAAAZogQAMCM5wjt3btX06ZNU15ennw+n7Zu3Rr3vHNOZWVlysvLU58+fTRp0iQdP348WesFAHQhniPU2tqqESNGqKKi4rbPr1q1SmvWrFFFRYUOHjyoUCikKVOmJPR5fQBA1+b5J6uWlJSopKTkts855/Tee+9p2bJlmjlzpiRpw4YNys3N1aZNm/TKK6/c32oBAF1KUt8TqqurU2Njo4qLi2OP+f1+TZw4UbW1tbedaWtrUzQajdsAAN1DUiPU2NgoScrNzY17PDc3N/bcrcrLyxUMBmNbfn5+MpcEAEhjKfnqOJ/PF/exc67TYzctXbpUkUgkttXX16diSQCANOT5PaE7CYVCkm5cEYXD4djjTU1Nna6ObvL7/fL7/clcBgAgQyT1SqiwsFChUEhVVVWxx65du6aamhoVFRUl86UAAF2A5yuhS5cu6auvvop9XFdXp6NHjyo7O1uDBg3S4sWLtWLFCg0ePFiDBw/WihUr1LdvX7388stJXTgAIPN5jtChQ4c0efLk2MdLliyRJJWWluqPf/yj3nzzTV25ckXz58/XhQsXNGbMGO3atUuBQCB5qwYAdAncwBRd0rvvvpvQ3M1/VHlRU1Pjeea5557zPNPR0eF5BrDEDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNrqkfv36JTT317/+1fPMxIkTPc+UlJR4ntm1a5fnGcASd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzAFvuGxxx7zPPOvf/3L88zFixc9z+zZs8fzzKFDhzzPSNL777/veSbN/ipBGuAGpgCAtEaEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpsB9mjFjhueZyspKzzOBQMDzTKLeeustzzMbN270PNPQ0OB5BpmDG5gCANIaEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gCBoYNG+Z5Zs2aNZ5nfvSjH3meSdT69es9z7zzzjueZ/773/96noENbmAKAEhrRAgAYMZzhPbu3atp06YpLy9PPp9PW7dujXt+7ty58vl8cdvYsWOTtV4AQBfiOUKtra0aMWKEKioqvnWfqVOnqqGhIbbt2LHjvhYJAOiaenodKCkpUUlJyR338fv9CoVCCS8KANA9pOQ9oerqauXk5GjIkCGaN2+empqavnXftrY2RaPRuA0A0D0kPUIlJSX66KOPtHv3bq1evVoHDx7Us88+q7
a2ttvuX15ermAwGNvy8/OTvSQAQJry/Om4u5k9e3bs18OGDdOoUaNUUFCg7du3a+bMmZ32X7p0qZYsWRL7OBqNEiI
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
2024-07-29 06:24:50 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"plt.imshow(train_images[0].reshape(28,28),cmap='gray')"
2024-07-29 06:24:50 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 533,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"def batch_iterate(batch_size: int, ipt: list[int])-> list[int]:\n",
2024-07-29 06:24:50 +08:00
" perm = np.random.permutation(len(ipt))\n",
" for s in range(0, len(ipt), batch_size):\n",
2024-07-26 21:07:40 +08:00
" ids = perm[s : s + batch_size]\n",
2024-07-30 00:44:16 +08:00
" yield ipt[ids]"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 534,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"def show_images(imgs:list[int],num_imgs:int = 25):\n",
" if (imgs.shape[0] > 0): \n",
" fig,axes = plt.subplots(5, 5, figsize=(5, 5))\n",
" \n",
" for i, ax in enumerate(axes.flat):\n",
" img = mx.array(imgs[i]).reshape(28,28)\n",
" ax.imshow(img,cmap='gray')\n",
" ax.axis('off')\n",
" plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### show first batch of images"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 535,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-30 18:21:38 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADWZUlEQVR4nOz915PcV3rfj7865xynpycnDgaJIAiGXZKbg2RLWrnkKgW7LN/4wv+M73xhl8t2OciutWRJ1mq/u8vd5YI5gMjA5Ng559z9u8DvHPYAAxIgJvQ0Pq+qKe5ienr6c+ac85zzhPej6vV6PRQUFBQUFA4R9Ul/AAUFBQWF4UMxLgoKCgoKh45iXBQUFBQUDh3FuCgoKCgoHDqKcVFQUFBQOHQU46KgoKCgcOgoxkVBQUFB4dBRjIuCgoKCwqGjGBcFBQUFhUNH+6QvVKlUR/k5ThVfR9RAGb8vUMbv2fi6ohrKGH6BMgefjScZP+XmoqCgoKBw6CjGRUFBQUHh0FGMi4KCgoLCoaMYFwUFBQWFQ+eJA/oKCgoPUKlUaLVa1Go1NpsNk8lEr9ej2+3S6XQolUo0m0263S7dbvekP66CwomgGBcFhafEaDQSCARwOp388R//MS+99BLNZpN6vU4ikeCv//qv2d7eJp/PUygUTvrjKiicCIpxUVB4SrRaLVarFa/Xy5UrV/jhD39IvV6nUqmwubnJhx9+SDabpVqtnvRHVVA4MZSYi4LCM/Bwvr9KpZL1EEpdhMLzjGJcFBSegV6vJ7/ggUFRq9VoNBrFuCg81yhusSFHpVJhNBrRarX4fD68Xu+Bm16lUqFQKNDpdGi1WnQ6HWq1GrVaDY1Gg8lkQq1WYzAY0Gg08uc6nQ7lcpl2u02n03kuAtidTod6vU65XKZYLJLP5wHkOI2Pj9NsNun1epRKJdrtNq1W62tX1isonEYU4zLk6HQ6QqEQDoeDP/zDP+QP/uAPUKsfXFj7XTj37t3j/fffp1wuk06nqdVqrK6usrGxgcViYXx8HIvFwsjICHa7HUBunrdu3aJQKFCpVKjVaif2rMdFo9EgGo1SqVS4f/8+IyMj+P1+JiYmCIVC/PN//s/J5/P8zd/8DdVqlUqlQiqVot1un/RHV1A4NhTjMsSIlFmbzYbL5WJiYoKzZ88eaFw6nQ67u7uUSiU0Gg2VSoV4PI5Op8NkMuH1erFarYyMjOB0OoEHxiWfz7O7u0u73abZbA6lcREuLvEl3F5qtZp2u02tVqPVaqFWqzEajYyMjOBwOPD5fFgsFtrttuIiU/hS1Go1er1+X7xOrKnTimJchhS9Xo/ZbMbj8fDWW28xPT3N/Pz8vtf0u2lCoRDf+ta3qFarJBIJKpUKLpcLm83GxMQE3/72t3E4HLhcLkwmEwDdbpdSqcT4+DipVIqrV6/y6aefHutzHjUGg4GxsTGsVitOp1OOwdzcHHa7nfPnzxMOh7FardJdaLPZ0Ov1BAIBxsfHSSQSJBIJWq3WCT+NwqAyMjLCD37wA9xuN2q1GrVaze3bt3n77bep1+sn/fG+FopxGVK0Wi0WiwWPx8OFCxdYWloiFAqhUqn2BZ/F//Z4PLjdbprNJolEgnK5TKPRoNFoMD8/z3e/+13cbreM38AXbjGbzUYqlWJzc3PojItWqyUYDOLz+RgdHWVkZITR0VHeeOMNHA4HBoMBnU4nNwR4UAej0+lwOp14vV4ajYb8noLCQbhcLr797W8zPj6ORqNBq9ViMBj43e9+pxiXfsQVT6fTMTMzw+joKOVymUwmQ71eJxaLUalUjuJXP/fodDo0Gg2hUIhz584RCoUIh8O43W5543gcwuVjNpvRaDRMT08DD05VZrNZVqX3v77X68mAdafTOdJnOwk0Gg02mw2n00k4HGZ+fh63243FYkGv1wM8ciMxGo2oVCr8fj+Li4uo1Wp0Ot1JfPxjRbgLfT4fNpsNh8OB3+9Hq9VKAyxoNp
tUq1UajQZbW1tkMhlardap3UifFeG+djqd0v1qMpnkeu52u6cuIeRIjItWq8XhcGCz2fhn/+yf8f3vf5+dnR0+++wzkskkb7/9tmJcjgAxIU0mExcuXOAv/uIv8Hq9zM/P43A4nuj0rNVqcblc9Ho9PB4Ply9fRqPRYDAYUKvVj8QOer0e9XqdWq02lAFrrVaL1+tlbGyMF198kddee01ulgC1Wo1ms0m73abRaKDVauX3Z2ZmpDzMP/zDP5zwkxwtKpUKnU6HwWBgaWmJ+fl55ufn+eY3v4nZbMZqtWIwGORri8Ui29vbZLNZfvrTn/LJJ59QLBalbM7zhsFgwO/3MzIyImWD7HY7RqORer1OvV5XjAvsP+15PB4CgQDNZpPR0VG0Wi1ut5tisShPvN1ul3a7vW/wVCqVDKQ+fGI+aLMTi7vb7cr3fJ60ncR4WSwW7HY7brcbr9eLx+PZ58oSCC2sgwr+xFiLcX745+DBplqtVsnn86TTaZlhdtoRc8poNGKz2faNo81mw2AwSIPa6XTI5XJUq1U574xGI3a7Hb1eL9/DarViNBrR6/W02+2hm5MqlQqDwYDb7cZsNhMKhaQL0e/3YzKZsNls6HQ6Oc/EpqnT6QgEAgSDQbRaLZVKZd8NWKS4DyvCnarT6eSX2MfEPDltRkVwJMbFZrNx4cIFRkZGmJ2dxev1yjTWbDaLVqtlY2ODeDxOJBKhUqkQi8X2ZUYYjUbp0/b5fJjNZvk9t9vN0tLSPldNNBplfX2dUqnE7u4ulUqFUqlEuVw+ikccKET9idFo5MyZM0xNTXHp0iWmp6exWCwYjcZ9r+90OjQajX2xF61WK908X0ar1aLdbnP37l0+/PBD0uk0165dI5PJsL29fSTPd1yoVCo5XmfPnuUHP/gBHo+HpaUl6QoT82p7e5tSqcSdO3eIRqNyA/V6vfz4xz/GbDZjNBpl3GViYgKtVksikaBUKp30ox4aIj4QDof5p//0nxIKhTh//jwTExNYLBbcbjcajUYmO/R6PWm8w+GwHK+lpSVu377Nb3/7W5rNJmq1mm63SyqVIpPJnPBTHh1ms1keYoQbLJfLUSgUyOfz8mb8tAeSg7wUx32oORLjotPp8Pl8BAIBHA6HdNW43W7cbjezs7NywtVqNbRaLel0ep9bRa/XyxNfIBDAarXK742MjHDhwgXsdjtarRaNRsPGxgbtdpt8Pi8Xb6PRkCel02r9nwRx8jEYDHi9XsbHxwkGg9jt9gPjLL1eT8ZI+sdHq9V+acqs+Ll2u00ymeTevXukUinu378vT/CnGZVKhV6vx2QyEQqFeOmll3C73YyPj2O1WqnX6zQaDYrFIvF4nFwux/LyMhsbG9jtdnw+H41Gg1qtJsfTaDRiNptxOByUy2Wy2exJP+ahotFoZPLCmTNnmJiY4IUXXmB0dHTf68RNWaBWq+X6np6exm63U6lUuHXr1r51W6lUyOVy+1QQhgmdTofFYpGGRa1W02q1qFQqNBqNr3Vz6/f6CB5WkjgOjsS4VKtVlpeXyefzXLp0SW5iItC/tLTEyMgICwsLpNNpqtUqyWRyX2BUuBT0er107QhsNhvj4+OyWlycOL1eL7VajZdeekkWuK2urpLP59nc3KTRaBzF4544IiXW4/HwjW98gzNnzhAIBB4JInc6HdrtNtFolPfee2/fCVrUZXxZXKbfDXTnzh2uX79OqVQil8ud6piLRqPBaDRiMBiYnZ0lFApx5swZmWKsUqmo1+vcvn2bu3fvkslkuH//PqVSib29PbLZLFarVRaSrq2t0ev1ZNzR7/fzxhtvyMLLYTAw4nB4/vx5Ll++zOjoKEtLS3i9Xmw2GwD1ep1SqUStVmN3d3dfBqLP5+Oll17CZDLhdDrR6/Xy/7daLTQaDZ1Oh/v377OysiIPMcMW8LfZbASDQbxer6xzSSQSrK2tEYvFnsqwiAOm2+3mzTffJBgMSjfvzs4Ov/nNbygWizQajW
NxNR6JcREbeywW44c//CHtdlsWnZlMJs6dOwc88KeKr1qttu+BRVBULHyNRrPP6j68CY6Pj3PhwgW63S7NZpNms8n
2024-07-26 21:07:40 +08:00
"text/plain": [
2024-07-29 06:24:50 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-26 21:07:40 +08:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
2024-07-30 00:44:16 +08:00
"X = batch_iterate(25, train_images)\n",
"for x in X: \n",
" show_images(x)\n",
2024-07-26 21:07:40 +08:00
" break"
]
},
2024-07-30 07:06:52 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training Cycle"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 536,
2024-07-27 06:09:51 +08:00
"metadata": {},
2024-07-29 06:24:50 +08:00
"outputs": [],
2024-07-27 06:09:51 +08:00
"source": [
2024-07-30 18:21:38 +08:00
"lr = 2e-6\n",
"z_dim = 128\n",
2024-07-30 07:06:52 +08:00
"\n",
2024-07-27 06:09:51 +08:00
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
2024-07-30 07:44:41 +08:00
"gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999]) #,betas=[0.5, 0.9]\n",
2024-07-27 06:09:51 +08:00
"\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
2024-07-30 07:44:41 +08:00
"disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])"
2024-07-29 06:30:08 +08:00
]
},
2024-07-27 06:09:51 +08:00
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 537,
2024-07-27 06:09:51 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-30 18:21:38 +08:00
" 0%| | 0/50 [00:00<?, ?it/s]"
2024-07-30 07:56:13 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-07-30 18:21:38 +08:00
"Epoch: 0, iteration: 468, Discriminator Loss:array(1.33504, dtype=float32), Generator Loss: array(0.460247, dtype=float32)\n"
2024-07-30 07:56:13 +08:00
]
},
{
"data": {
2024-07-30 18:21:38 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9yW+kWXodDp+Y53lmMEgG55wzqzJralV1Vw8otWzZsgT3ygJsGPDCGwOGF174f/DCS3tjaGFJsABbgoyfWi2p1V2lqu7MyjmT8xyMeZ7HN74FdR7eYLd/YkbWhw/4wAsksopkBiPue+8znOc859GNx+MxrtbVulpX62pdrW9w6f9//Qau1tW6Wlfrav3/37pyLlfral2tq3W1vvF15Vyu1tW6Wlfran3j68q5XK2rdbWu1tX6xteVc7laV+tqXa2r9Y2vK+dyta7W1bpaV+sbX1fO5Wpdrat1ta7WN76unMvVulpX62pdrW98XTmXq3W1rtbVulrf+DJe9gf/43/8jxiNRnj69Cm2trYQj8dx9+5d2Gw2GI1G6PV62O12OBwO9Pt9FAoFDIdDWK1WGI1GbG5u4ic/+QmMRiO+9a1vIR6Pw263w+l0wm63IxqNwmg0YjQaYTQaodfrodVqYTwew2AwAACazab8yefz0DQNy8vLCIfD0DQNo9EIer0eVqsVBoMBbrcbDocDmUwGv/jFLzAajfD+++9jfn4enU4HrVYLBoMBXq8XBoMB+/v7OD4+Rr1eRyqVQr/fh8FggE6nw0cffYTPPvsMJpMJ9+/ff+ON/t3f/V0AgMvlgs1mg9lsht1uBwDU63UMBgOsra3h1q1bqFarePToEZrNJlZXVzE7O4tHjx7hD//wD6FpGt555x1EIhFYLBbZf7fbDaPRCKvVCpPJhEqlgkwmA7vdjvfeew/BYBCvX7/GxsYG8vk8njx5Ak3T8PHHH2N1dRXNZhPVahWBQAD/+B//Y8TjcaRSKaTTaWSzWTx58gSj0Qg3btxALBZDKpXC7u4u+v0+2u02RqMRut0uer0ezGYznE4nxuMxKpUKut0url+/jgcPHsBoNOL3f//333j//v2///fQNA2bm5s4PDxELBbD7du35VkbDAZ4PB74fD5UKhU8fvwY7XYbN2/eRDKZxKNHj/BHf/RHMBqN+M3f/E0kk0l4vV74fD4MBgNUq1UMh0P0+32MRiP5b4PBAJfLBbPZjFarhXa7jVarhUKhgPF4jMXFRYRCIfR6PfR6Pbjdbrz//vvw+/342c9+hq+++goGgwEWiwV2u13OH/d7MBig1WrJnSmVShgOh+h2uxgOhygWi2i320gmk7h27Rr0ej3+w3/4D2+8fwDwb//tvwUAVKtVtFot6PV6mM1mGAwGBAIB2Gw2BAIBRCIR1Go1vHjxAt1uFzdu3MD8/Dw2NjbkDn/22WdYWFiA2WyGxWLBeDyWfTs9PUW5XEYgEMDs7CwMBgPG4zE0TUMmk0E2m0Wr1UI+n8d4PEYikUAgEEC/30e324XNZkMymYTL5UK9Xkej0UC1WsX+/j50Oh3u3LmDeDyO4XCIwWCAfr+ParWKXq+HXC6HYrEIq9UKj8eDbreLn/zkJzg4OMDv/M7v4F/9q38Fi8WC73znO2+8f7/9278NvV6PeDyOSCQCg8EAs9kMABiNRhiPx3IOAECn0wGA2JCNjQ385V/+JYxGI7797W8jkUjIvpjNZng8HlgsFkSjUXi9XjSbTZTLZQCA0+mEwWDA48eP8fjxY/T7fbRaLZhMJnz88cdYWVlBJpPB/v4+/H4/fvSjH2Fubg6np6fIZrMoFot4/fo1RqMR1tbWEA6HUa1WUSgU0Gg0sLW1hWaziXA4DL/fj263i2q1in6/j2aziV6vh5WVFdy4cQMGg+FSZ/DSzqXVakGn0+HWrVu4d+8eAECv10PTNPR6PQwGAxiNRjidTphMJng8HmiaBpfLBavVitFohHq9Dp
1Oh/n5efh8Pnnjmqah2WzCZDKhUCigVquh2+2i2WxCp9PBbDZDr9ej0+mg3W7DZDIhFArBZDLBZrNB0zRYrVY4HA7odDqMx2OMx2NkMhlUq1V0Oh0AgMViEUfV7XZRLBZhMBgwGo1gNpsRDAYRDofRaDQQj8cxGo1gsVhgMpng8/mQTqeh1+unci7Hx8cwGAxYXFxEIBCAx+NBLBaDpmk4PDxErVaTSwlAnB4/t91ux40bNzAejxEIBGC1WjEcDlGtVmE0GtHtdmEymcR59ft96PV6tFot/M3f/A0AyGU0mUy4du0aNE1Du93G69evYbVa4XK5oNfrkU6n0el0kM/nUSqV0O/3MT8/j/F4jH6/j1QqhXa7DY/HA7PZLI6uUCigXC5D0zQMh0MAQDQahU6ng9VqxenpKfT66ZLl4XCI8XiMWCwmv7ff70On0yESicDhcMBqtcJsNsNqtcLr9cJisaDf7yOXy8HhcOD73/8+AMBms4nz7XQ6MJlM8Hq9MJlMGAwG0DQNer1ezkaxWISmaSgWiyiXyzAajRIgNBoNDIdDeL1ezM7Owmw2o1KpoF6vi8Mym83w+Xyw2WzodDrI5XKoVqsoFovo9/toNBoYjUawWq2IRqPiXMbjMeLxOADI5512/wCg0+lAp9PB6/UiGo1Cr9dLYEgDOBqN5H3H43GMx2P4fD65Hzdu3JA7n0qloGkaNE2DyWSCw+EAAJTLZXGcfEbdbleCxl6vB71ej1gsBr1eD51Oh2q1CqfTifn5eTnzvO+dTgeapmFmZkbud6VSkYBSp9PB5XLB7XbD7XZjfn5eftdgMMAPfvADMZxbW1vQ6/VTORe/3y/BKwBYrVYEg0F5v4PBAM1mE4PBAACgaRp0Oh3sdjvsdjuSySTef/996PV6zM/Pw+/3I5VK4fj4GFarFeFwWILF8XiMRqOBQqEgTns4HOLk5ASdTgcOhwOLi4uwWCxot9t4+fIlbDYblpeX4XK50Gw2cXp6ikKhgEqlgtFoJM9Tp9OhVquh0Wig3W7DaDTi3XffhU6nQ6fTQbfbhcViERteq9XQ6/VgtVpRKpXEaf5D69LOpdvtwmAw4Nq1a1heXkapVMLOzg46nQ4ajQa63S4GgwH0ej30er0Yeq/XK1EsLww3sVQqoVarAQB6vR6GwyGy2Syy2eyEc7HZbDCZTOh2u2i32wgEAkgkErDb7fIgLBYL/H4/AMiDLpVK2NragsFggNPpFCel0+nQ6/VQrVblYlmtVszNzWFmZgatVgt+vx+aponRojHQNO2NDiRXNpuF0WiU9+3xeDAzMwNN01AoFNBsNpHL5VCpVOB0OrGysgKn0ykX02azYWlpSYyQwWAQ56zX69HtdmE0GuXBDwYD+bevXr1CuVxGPB5HLBaDxWLBwsICxuMxjo6OcHp6img0Cr/fD51OJ++nWq2iXq/DYrEgFosBAE5PT1EsFmE2m+FwOOD1enH79m14vV4cHh6Kwa5WqwAAr9crhzKXy2FaKTs6q0AggFgshn6/j1qtJufD5/NBp9NJ9uZ2u8VZlMtl2Gw2vP/++xJZNxoNnJ6e4vj4GB6PB9euXYPdbsdoNJowlpqmSYDCLM7tdmNpaQlGoxGtVgutVgtutxvhcBg6nQ71eh29Xg+NRgOapkkWbbVa0ev1UC6XUSqVkM1m0e/3Ua/XoWkaFhYWEAwGxbnw/litVlQqFRSLRdmHaRZfMxQKIRKJyNkfj8dyZzRNQ6PRmAjgmBV7vV4sLS1JttpqtdDtdtHpdGCz2cRhMdvo9XriGJrNpjhaq9UKq9WKUCgEvV6PUqmEer0Or9eLSCSC0WiEdDqNdruNfr+PXq8Hi8UihpzvkUGkyWRCIBCAxWKB2WyG2WxGs9lEOp2GpmlIJpOwWq3IZDI4PDyc+gx6PB4AZ0GqTqcTm8PPTGfGoIS/x2azwev1ijMEzoIum82Gg4MDpFIpWK1W6HQ6OBwOCdAZoPT7fUEA+LfX68XKyg
rMZjM2NzeRTqexsrKCa9euwWq1otPpyL9Tnydwlig0m0202230ej04HA6sra3B6XTK+zGbzXC5XAAAh8Mh56NWq11
2024-07-30 07:56:13 +08:00
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-30 18:24:53 +08:00
" 20%|██ | 10/50 [00:48<03:09, 4.73s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 10, iteration: 468, Discriminator Loss:array(1.31833, dtype=float32), Generator Loss: array(0.469772, dtype=float32)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9SY/kWXbmBz82z6Ob+TzGPGRWVjLJItUNstSUKHQvBKEFCIK+Qq/0GbTWutda9UobUYCGhsBuNiiSqqqsrIzMjIzBPXx2NzO3eZ7tXXj/jl9zVqs9LPLFC7zwCwQyM9Ld7P+/99wzPOc553hms9lM9+t+3a/7db/u10+4vP+/foD7db/u1/26X///t+6Ny/26X/frft2vn3zdG5f7db/u1/26Xz/5ujcu9+t+3a/7db9+8nVvXO7X/bpf9+t+/eTr3rjcr/t1v+7X/frJ171xuV/3637dr/v1k69743K/7tf9ul/36ydf98blft2v+3W/7tdPvvx3/cF/8S/+hSTp/Pxc5XJZGxsb+vLLLzWZTPS3f/u3Oj09VTQaVSwWk9/vVyQSUSAQ0Pr6upaWlpRMJrWysqLxeKyjoyM1m01Np1NJUrfb1cXFhfr9vkajkcbjsX3vbDbTeDzWbDZTJpNRJpNRLpfT559/Lr/fr2+++UYnJydqtVqq1WoaDoeqVCoaDodKJBJKJBJKpVLa2dlRKBRSv9/XeDzWxcWFDg8PlUql9J/8J/+JlpaW5PP55PV6lUgktLq6Kr/fr8FgoPF4rHK5rIuLC0nSv/yX//KjN/q//+//e3m9Xq2uriqbzapQKOjVq1eazWba3d1VKpVStVpVuVwWTRO8Xq+SyaSi0agqlYqOjo4UCAT0+eefK5fLaTAYqNfrKRAIKBaLaTqd6tWrVzo8PNSTJ0/0x3/8xxoOh/r6669VLpe1tbWlzc1NeTwe2/twOKxAIKDJZKLBYKBIJKLd3V37zlqtpnK5rDdv3mgwGCgajSoUCsnj8UiSxuOxGo2GRqORtre3tbq6qul0qtFopMlkolarpcFgoE6no3q9rtlspn/9r//1R+/ff/1f/9eSpH6/r36/r3A4rEQiIZ/Pp1gspkAgoNFopNFopGazqffv36vf72t1dVVLS0uKRqNaWlqSx+NRt9vVeDxWJpNRNptVr9dToVDQeDxWOByW3+9Xu91WqVTSYDBQuVxWv9/X1taWtra2FIvFtLa2Jo/Ho+PjY5XLZfl8Pvn9fo3HY9Xrdfu9er2ucDisdDqtcDis5eVlJRIJjcdjjcdjBYNBbWxsKBQKqVqtqtlsKhgMKp1OS5Lq9br6/b7i8bji8bi8Xq/+x//xf/zo/ZOk/+6/++/k8XiUSCQUjUZNzvv9vv7qr/5Kx8fH2tra0u7urhKJhLa2thQIBHR2dqarqyv7nGg0qj/6oz/S8vKyDg4O9O7dO/X7fdXrdU0mE5MpVigUss/s9Xrq9XqKRCJaXl7WbDbT999/r/Pzc62urmpvb0+hUEiJREJer1d/93d/p9/85jfyeDx2P0OhkPx+v+LxuNLptGazmfr9viaTiTqdjrrdrvL5vD777DN5vV69e/dO1WpV0+nU5P5//p//54/ev7/4i7+QJCUSCcViMcXjcZP3i4sLtVotDYdDjUYjeb1e04EbGxvK5XKqVCr68OGDIpGIfvnLX2pzc1M//vijXr9+reFwOPcO3LV0Oi2Px6PhcKjZbKalpSVls1lFo1Gtrq5Kkr799ludnZ0pHo9raWlJo9FIFxcX6vV6dieGw6FqtZqm06nC4bCCwaCWlpa0urqq4XCoQqFg95/939jYkMfjUblcVqfTkdfrtXv/r/7Vv/qP7tedjctwONR0OtVsNpPX61W329WHDx/k8XgUjUa1tbVlhx8IBJRMJhUIBNRqtVStVhWNRnV5eanRaKQPHz6o0WgonU4rnU5rPB7L6/XK7/er2+2q0+koGo0qmUzK7/crFA
rZZ3s8HjWbTb169cr+PRQKqdVqqd1uy+v1amNjwy76eDy2zwgEAnaA8XhcOzs7CgaDarfbGo/H6vV6GgwGymaz9nvNZlODwcCU2qLL6/VqNpvp6OhI7969s3f2er2aTCZzxhABcwVheXlZ2WxWPp9PqVRKgUBAiURC4XBYPp9PoVBI0+lU4/FYuVxOfr9fx8fHGo1GCgQCymQy8nq9arVaCgQCikajdo6j0UihUEixWEwej0cXFxfyeDwaDAYaDAaaTqdaWlrSdDo14zKZTDQejzUajcyY9Ho9nZ+fS7p2ClwjE41Glc1mTTg/dmWzWU2nU/V6PYVCIXU6Hb17905+v197e3tKp9OKx+NKJBJqNBqq1+vqdrvK5XJaWlqyfcZgIx+lUkmTyUR+v19er1ftdluj0UjRaFTPnj3TcDjU5eWlOp2OOUaxWEySzChHIhF1Oh2Vy2VzTtLptCKRiNLptDlbfr9fwWBw7lyn06mKxaJms5l8Pp+CwaDJ+3A41OHhoQqFgpaWlrS8vCyfz7ewDE4mE3k8HlM45XJZhUJB0+nU9q/b7erNmzdKJpOaTqeKRCKqVCp2z+LxuKbTqb799lt7n3w+r16vZ+eN3MRiMaXTaQUCAQ2HQ7VaLftuSSqXy5pMJqrX62q1WppOp+p2u7ZfPp9Pk8lEz58/n/t97stsNlMwGJTf79fS0pJms5nevHmjQqFgss4/G42GIpGI4vH4wvsXiUTk8XjMuLXbbX3//ffyer1Kp9NaWVnR1dWV3bF0Oq1gMKhKpaJKpSKv16tMJqNgMKhisah2u62DgwOdnJxIunb0JJkx4DNns5k5OOwVxtfn82kwGEi6dtIHg4Ht6Wg0ss/A+cFIezwehcNhRaNRTSYTJZNJ05fT6VR+v990JU5/OBw2HXGX9VHGZTabaTabmUd/fn4uv9+vZDJpD4ei4lIdHR2pWCwqGAwqFotpMBhof39fzWbTvGiPx2NGiZcJh8OKRCLm9aFQ2u22Op2OXYpUKqVIJCLp2qtF2BOJhHnePp9PgUBAfv/166IQYrGYZrOZer2e2u22arWams2mWq2W4vG4/H6/KpWKOp2OAoHAJ11s1uXlpS4vL83zCAaDkq4jgFQqpUwmY9HXaDQyo5pKpUw593o9+/nl5WUTsul0Ko/Ho1wup2KxqIODA1MQqVRKHo9nznB7vV7V63W1222l02mlUinNZjOL/NzF72NcuAD8IYpqNBpzAoyRSaVS2tjYkNe7GBKbSCRMAft8PjUaDZ2dncnr9SqfzysejysSidieotTS6bSSyaRFZnxWJBJRqVQy+YhEIvJ6vXaBcT6m06l8Pp+azaZOT091cXGhZDKpWCxmvxMKhVSv11UulxUKhbS8vKxkMmn7LmlOziUpGAwqmUyq3++rWCyq1+spl8spm82aEphMJjo7O9PBwYHW1tY0mUw+SQaRD5yQZrOp8/Nz+Xw+PXnyRNlsVm/evNHJyYkSiYTi8bhisZharZZ6vZ7J2WQy0bt379TtdvX06VO9ePFCoVDIopJer6dut2vRkdfrNd2APGCIxuOxWq2Wut2uer2eKpWK3Ve8/r29PbXbbUM3Wq2WoRuBQEDhcNiQh7dv35oz2+l05PP51G631W63zXAv6uDwu+iSRqOhw8ND+f1+ffXVV8pms2q1Wur3+yYXwWBQpVJJjUZDmUxG29vb8vl8qlarKpVKOj8/1+Xlpb2D3++3ffH7/eZUNRoN01PscSKRsIhdkt1BHL3xeKzl5WUtLy8rGo2a00n05p5FOp3WZDIx/cp9weDgmIXD4Z/euJTLZUkyb6TX66lardoXer1eUzKTycSU8Xg8NmGJRCIKBoNaX19XKpVSMpk0hREOh81iEsWEw2HzyPGG+v2+fD6fKVq+x710/H4gEFAwGJz7J97HZDIx5R0IBMzI9Hq9uc8CosKwLrqIeojq8LJDoZAZSJ/Pp263ax6m1+tVs9lUt9vVysqKhcHNZlO9Xk
8XFxe6urqSz+ezSASBkGTGFK+ZffJ4PGq32+YpAhUiNCgR/ns0Gqnb7dr7E8Hy33hFKFz2ligBCOtjBPP2mkwmtn9
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 40%|████ | 20/50 [01:35<02:22, 4.74s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 20, iteration: 468, Discriminator Loss:array(1.25664, dtype=float32), Generator Loss: array(0.526989, dtype=float32)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9R4+sW5behz/hvY/IjPR5Mo+7pm7V7apqC0jdYAMCBH0BzfgB+Ak004RfgQQ4IUBOOdFAoCABFEUSbaqrbtV1x6c3kRne+wgNsn8rd2Td7s4TeYU/8Edu4OCakxnxvnuvvcyznrWWZz6fz/W4HtfjelyP63H9iMv7/+sHeFyP63E9rsf1/3/r0bg8rsf1uB7X4/rR16NxeVyP63E9rsf1o69H4/K4HtfjelyP60dfj8blcT2ux/W4HtePvh6Ny+N6XI/rcT2uH309GpfH9bge1+N6XD/6ejQuj+txPa7H9bh+9PVoXB7X43pcj+tx/ejLf98f/Of//J9LkkajkcbjsTqdjq6vr+XxeLS2tqZ4PK5SqaSzszP7HZ/Pp1gspmAwqNFopE6nI6/Xq3w+r2g0qu3tbe3s7GgwGOji4kKDwUCDwUCTyUR+v1/hcFiDwUAHBwdqt9va39/XkydPNBwOValUNJ1OlUwmFYlE1Gq1VC6XFQqFtL29rXg8bp/XbrftuV6+fKnV1VV1u121221Np1ONx2PNZjNNJhNNJhPF43GtrKxoOp3q9evXKpfLisfjSqVS8ng8+o//8T9+9Eb/D//D/6DZbKZKpaJWq6V8Pq/nz5/L6/WqXC6r1+tpOp1qMpkoFAopnU4rGAwqHA7L7/drOp1qMBhoOBzq/PxcvV5PyWRSyWRSw+FQ1WpVk8lE4XBYoVBIPp9PgUDAvt9txMDnezweXVxcqNVqKRqNKplMajqdqt1uazQaaTabaTabKRQKKZlMKhAIKJFIKBgM6uLiQu/fv1ckEtGnn35qfx8IBDQcDtVqtSRJ4XBYgUBA/X5fvV5Ps9lM//v//r9/9P79r//r/6r5fK7JZGLPeH19bf89n8/VbDZVq9U0n8/l8Xjk8/m0tramXC6nUqmkN2/eKBwO65/9s3+mzc1NVatVlctlTadTDYdD2+PRaKR8Pq+trS15vV51Oh2NRiOFQiEFg0E1Gg29fv1ao9FIT548UaFQkCR5vV6l02n9d//df6dcLqf/+B//o/7Tf/pPikQiWl1dVTAYVCgUkt/v12g0Uq/X02Aw0OnpqQaDgeLxuCKRiGKxmAqFgvx+v92H2Wym6XQqSfp3/+7fffT+IYOSNBwONR6PFY1Glc1mNZ1OdX5+rk6no8lkotFopEgkovX1dYVCIbXbbXW7XcViMWWzWc1mM5PZTCajdDotv9+vUCgkj8dje4jc9ft9vXnzRs1mU/l8Xrlczvba4/HoyZMnyufzury81Lt37xQIBPTkyRMlEgl7b5/PJ7/fL5/Pp1wup1gspsPDQ3377bfy+/3a2NhQMBjU1dWVKpWKVlZW9Nlnn8nv96tcLqvb7SqRSNgd/lf/6l999P79L//L/6LZbKajoyOdn58rFAopkUjI4/HI7/fL4/GoVCrp/PxcgUBA+XxegUBAXu+ND99sNnV5eSm/36+9vT0lk0mVSiVdXl4u7Hc8Hrd7HIvF1O129etf/1qVSkVer1der1fRaNQ+PxgMyufz2d9FIhG9fPlS8XhcX3/9tb777ju7w36/X9FoVIFAQOPx2OSr3W5rNptpfX1dKysrajabOjk50XQ6VTabVSQSUSgUUjgclsfj0b/+1//6n9yvexuXarUq6eYCeTweSTeKYzqd6uLiQrPZTOl0Wj/96U81m800Ho81n88VDAbNuPR6PUlSIBCQz+fTaDRSpVLRfD5XLBZTNBrVYDDQeDy2Fx+NRvZCk8nElOhwOLTnCQQCSqfTSqVS9v/6/b4ZDY/Ho2AwqNlspmq1qm63q3A4rFgsps
FgoMvLS/X7feVyOWWzWQ2HQ717907T6VSz2UyZTEb5fF7FYtHe/WMX757JZFQoFOyAUeAoSK/Xa8puNBqp2WxqNpvZ50wmE3U6HfX7ffn9fvn9fs1mMxOw2Wymfr+vdDqt1dVVzedzVSoVjUYjRaNRRaNRSdJgMNB8PlcgEFAymbRn5Hn4706no3w+r5WVFYXDYVN00WhUa2tr8vv9ms/ndm5cpHA4LK/Xa8YRw7Ds+t3vfidJ9nw4A5yt3+9XNptVNpvVfD7XdDqVx+NRLpdTIpGwSxgKhbS6uqp0Oq3JZGJyhjylUikz6nzHcDjUcDhUr9ez79zd3ZXX61UqlbKLF4vFFAgEdHJyoouLC7XbbSWTSc3nc11dXcnr9do5xWIxpVIphUIh9ft9DYdDRaNRRSIRjcdjnZ6e2vn4fD4lk0lls9ml90+6UW7z+dzeZzAYSJKm06mur6/VarWUSqWUzWaVyWT06aefKhwO6/vvv1ej0VCz2VSn05HH41E8Hl+Q4Xa7bcoII4zB8fl8KhaL5jAht7wb8uHz+RSPx81AzedzhUIhhUIhjcdjdbtdSTdyGQgEVK1WNZ/PNR6PbX9xsMbjsU5OTuwZkAHOb5nF/RiNRva9zWbTHDmv1yufz2dKP5fLye/3mzF3nYZEIqFoNGrvOpvNVK/XFY1GlUqllEql1Ov1dH5+rtFopHg8Lr//Vl2nUiltb2/bnZxMJnbHPB6PTk9P5fV6VavV5PP55PF4THePRiOT/dFoJI/Ho1QqZbrn+vpazWbTdDPnEwwG7Znvs+5tXBqNhiTZRZJkRgPv98svv9SzZ880m83U7XY1nU7N2xiNRraJo9HIIoZ6vW6eNIc0Ho/VbrfVaDRMWFGijUbDlAub5ff7FY/HFYvFzAAhnChKv9+vyWSier0uSVpdXVUul9NkMrGLk81mlUgkNBgMdHJyoslkomKxqEQioXw+r42NDfl8vvtu2cIajUbyer1mwPAyMMIoXpQzyq7dbqvf7ysQCCgUCmkymajf76vf75snwft5vV4zyHzXdDpVo9FQv99XMBhUMpnUaDRSo9HQZDKxz+33++p2uxbh8FmNRkPJZFKZTEaxWEy1Wk2DwUDhcFi5XM4EDUGVZN6Xz+cz5TKZTBQMBrVsK7t3797Zv6N04/G4RUso7EQiYRd/Pp+bEoxEIhZ14fkS2fIO/HwqlTKDj3FBuXW7XSWTST19+tT2HmWbyWQ0nU51dXWlwWBg3n6v11O9Xrc98Hq9Wl9f19ramu09xj8cDqtarery8lLT6VTpdFrhcNjOYFnFKEmdTsccARQM96JWq6nRaCgejyudTiufz2t3d1fRaFQnJyd2vshiOp02Az2ZTNTr9XR6eqrhcKh8Pq9EImGOhc/nM6MP8sFdx7jMZjP5fD6Fw+GFn+P3x+OxGfdWq2VGEqesVqtpNpuZ1z+dTnV5ealgMKitrS1FIhFJsuhvmdXv9+2u8oxEyaFQSF6vV36/X5lMRsFgUKlUSj6fzxASHC6iB/TocDjUbDZTp9MxvRaPx9XtdnV1daX5fK5oNGqR3Hw+VyaT0c7OjkKhkDqdjjmjnO319bU5TUTx6EuenzMNBoOKx+MKBoNqNptqNptqtVqmKzFsqVTKHP37rHsbF3eDe72efD6fgsGgPB6PVlZWzOMDrsIrisViBo0Eg0FNJhOVy2X1+315PB55PB7zKLCQrnLweDwmGK4S4BDc78IqX19fq9vtWtSEEiG6wkB2u131+31JNxBer9dTtVpVp9Mx74rwtN/v6+Li4mO3yxaHzCXiQgIX9no9+f1+21OeqVAomHHu9XryeDxKp9OmSPGMx+OxXTR3L1CQhL4uBOjxeBQOhxUMBjUej83ApVIp80j9fr9Bb9KNgkJJERW6kRXvOhwO5fV6lUwmFQqF1Ov1DCpbZqXTac3nc3U6HXW7XfPEXLhkNptZmI+XORwODY7FcANNdjod8+bz+by8Xq8KhYJSqZQ6nY
7q9bo5LxiecDiscDhsCgHHZzgcql6vazQa6erqyuQbaLff72symSgajSoWiykUCi3AfK48DIdDk/mVlRXzcvv9vr3
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 60%|██████ | 30/50 [02:23<01:34, 4.74s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 30, iteration: 468, Discriminator Loss:array(1.16943, dtype=float32), Generator Loss: array(0.59993, dtype=float32)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz96W6sWXbmhz8xz3MEGcGZZz5ZJzMrq5TK7mpLBgS1+oN9CTbQsOGL8I34MhqwAdtfGm2pVZJcpapK5XRmHs5TzAMZ8+AP7N/iDlbK4gmm8Qf+4AYOMvMkGfG+e6+9hmc9ay3PbDab6X7dr/t1v+7X/foJl/f/1w9wv+7X/bpf9+v//9a9cblf9+t+3a/79ZOve+Nyv+7X/bpf9+snX/fG5X7dr/t1v+7XT77ujcv9ul/3637dr5983RuX+3W/7tf9ul8/+bo3Lvfrft2v+3W/fvJ1b1zu1/26X/frfv3k69643K/7db/u1/36yZf/tj/43/13/52m06kODw9VqVQkSbPZTIFAQCsrK4rH4zo9PdXJyYmCwaBWVlYUDAZVr9d1cXGhbDarzc1NDYdDvXr1SrVaTaurq9rc3NRsNtNwONR0OlW329VwOFSz2dTZ2ZmCwaCePn2qZDKpfr+v4XCofr+vRqOh8XisYDAov9+vZDKpXC6n4XCoo6MjdbtdRSIRRSIRxWIxraysyOfzqd1uq9/vy+fzye/3y+PxKBgMyuv16uLiQr1ez945FArpwYMHymaz2tvb05s3bzQej/Xdd9999EZ/+eWXkiSfzyefz6d+v692uy1JSqVSCoVCarVaqtfrCgQCSiaTCgaDisfjikQi9jmDwUDv379Xs9lUKpVSOp1WPB63/T4/P1ez2dRsNtNkMpEk29vxeKzpdCqv16twOCy/3694PK5wOKxqtaqjoyNNp1P5fD55vV5lMhmlUil5PB77fs4nGAzavk2nU0nSaDTSYDBQOp3W+vr63H5Pp1P7ub/7u7/76P376quv5PV69ejRI62trSkUCikej2symejo6Ejtdls+n8/kIRaLyev1qtPpqNvtqlaraWdnR+l0Wv/+3/97ffLJJ/pP/+k/6T/+x/8ov9+vXC6ncDis1dVVZbNZ5XI5ra+vq9/v6+uvv1atVtOzZ8/09OlTNRoN/fDDD7q8vFSz2VS321U6ndbS0pKd0Xg8Fs0vGo2GdnZ2NJlMVCwWlUgkFI/HlU6nNR6PVa/X7XfG47EGg4FarZYGg4GOj4/V6XRUKpW0ubkpr9er//Af/sNH758kPXr0SB6PR2trayoUCsrn83r69Klms5nev3+ver2u2WxmMhIKheT1euX3++X1elWtVrWzs6N4PK7/9r/9b7W9va1/+qd/0u9+9zt5PB6FQiH5/X5tbm4ql8spnU6rUChoOBzq/fv36nQ6Wl5e1vLysnq9nk5PT9Xr9XR0dKRGo6FMJqNisSiPx6PJZKLpdKrhcKjhcKjLy0udn59LkjY3N5XJZOwOj0YjtVotk3N+D31xfHysdrutpaUlra6uyuv16n/73/63j96//+F/+B8kSfV6Xe12257L5/Mpn88rEomo1Wqp1WopHA5reXlZoVBIjUZDFxcXyuVy2tra0mAw0G9+8xudnZ0pl8upUCjYu7jyc3FxoWazqXA4rM8++0yZTEYnJyc6OTmRx+OR1+uVx+Ox+xoIBBSNRjUcDnVycqJer6dgMKhQKKRIJKKlpSX5fD5dXFxoOBzK7/fbHQ4Gg/L5fIpGowqHw+p2u6pWq/J4PFpeXlYsFtPh4aHJ8atXr/7F/bq1cRkMBvJ4PNra2tLjx4/Vbrd1fHwsSSZUmUxGfr9fPp/PBHNpaUlLS0sKBAKaTqeazWaKx+OazWYKBoPq9Xry+XyKxWKSZAKBgZrNZqpWq6pWq5pMJprNZraxkUhEmUxGiURCo9FIjUZD0+lUkUhEoVDIFGq329XZ2Zn8fr9dep/Pp8lkIq
/XK6/XK5/Pp0gkomg0qsFgoIuLC43HY/X7fV1eXsrj8SidTpuC/NgVDofn/tvr9Zry4X1KpZK2trbsv6UrhT2ZTOyZJpOJUqmU4vG4GarRaKRKpSKv16tutztnQKbTqS4uLtTv981I+P1+O4tIJKJ0Oq3RaKRms6npdKpAICCPxyO/36/BYKDpdKrBYGCKZzabmdJBIfKdwWDQLjR7iwFadO8k2Rkg8Bje2WymdrttCrnRaCgQCCiTySgYDJqBlK5kOBKJqNPp6PDwUIPBQIlEwox5IBDQbDZTt9tVv9/X2dmZJCkQCKhUKmkwGOjt27e2J5JUrVZ1fn6uUqmkaDSq2Wxmjs/q6qpWVlaUy+UUi8U0mUwUi8XMsAcCAbtXHo9H4XDY/un1eu2d2LtarWZyscji/IfDobrdrhqNhnZ3d+X1etXv9+X3+1WpVMypy2azpiSz2aztdTAY1MXFhc7Pz3V5eSm/3y+/369UKmX3/OLiQt1uV5VKxRRgJpORdGVs+/2++v2+er2eqtWqzs7O5PF4VCqVNJvNzFisrKxoeXlZ/X5fpVJJ0+lUiURC4XBYs9lMs9nsj/aJ+80+R6NRjUYjTadTtVqthfePM0+lUkqlUrq8vFS1WpXX61U2m1UymTQd4jo4gUBAPp9PvV5P+/v7mk6nikajWllZUTQaNV3V7/c1mUzU7/fNcd7c3JQkMyr9fl+j0WjOQVxeXlYymVS321Wr1dJ4PFYoFJLH49F4PFav1zNHE53h9Xo1mUzU7Xbtfvr9fvs77jS6PBaLKZVKKZ/P22f9S+vWxoUX2tzc1Orqqk5PT3V5eanxeGybl0wmlUgkTBlKUiKRMGva6/U0m81M0aK8wuGwXSiPx6PRaKRQKGRKb39/XxcXF2YEgsGgCXImk1E2m1WtVlO1WpUkxeNx+f1+dbtddbtdzWYzNZtNSVeGEMOCtzCZTOTxeMxqdzoddTod8yJ7vZ48Ho8SiYQWbcXGO+ORSZozlHjPKysrkmQCVq1WdXFxocFgoEajIY/Ho3w+r3A4rMFgYD9Xr9clyRSV3+9XOBw2zw8j7vP57DkQUAQzGo1qOp2aYM5mM43HY41GI7XbbU2nU/NweA8uBUYdg1Mul+X1epVMJu1z77IwjpxjMBg0g4MBrlQq2tnZUTAY1NLSkqLRqGKxmOLxuCRpPB7L7/ebFzwYDBSPx83jw7jgUNRqNYXDYT1//lzpdFr1et28Z7zMVquls7MzBQIBLS8vazKZqF6vazgcanNzU6VSSf1+34wLF3s8HtseIvd4mcPhUB6PR/1+X/F43DxyIt27LL6bM+OZAoGARe+Hh4cKhUKSpFgsplKppGQyqfF4rEwmI6/Xq16vp1qtZigAxt410NydYDCora0tJRIJ9ft9MxyDwcBQinK5rHQ6LenqXlxeXtq+PXjwQIPBQLlcbi4iRDb5bwyL6/zwbDhanU5n4b1DT8TjccXjcbVaLVPO6XRa6XRawWDQnLNAICBJ9hy9Xk+tVkter1exWEzJZNIML1E/ZzMajRSNRs2pefPmjUUxoAbSleOTz+dVKpV0dnamRqNhjjvnyT6PRiPT0+Fw2BxD9CAIEvvJPgaDQUUiEcXjcaVSqVvf5Vsbl+FwKEkWQTQaDQ0GA7O0rhfu/l0ikTBv9vLyUpPJRKFQyDwcXpAXj0ajZh3ZcD6bS8DmhkIhjUYjXVxcmHfuhvV4LdPp1ITw8vLSDpxDx3DyfVxu/g4vy32Wj108G9HXcDjUeDw272I0GplCm81m5m30+337DLxW9mY0GpmHwfnwXoTVrtJnBQIBJRIJ+f1+DYdDM2CTycQ+2/0s6cooI2ycMVESFx5BvenZ8N936ZG6vr4uj8djUVssFjNv+uzsTJ1OR/F4XMViUcFgUIVCQcFgUN1uV3t7e7bv4/FYx8fHqlQqJgsYZEnq9XoajUYaDoe2HyhJjPlkMrGfI6LFiPp8PouUksmkwSeuY4XhB0pCHvgOjA
6fjUOEY7DoKhaLpgiJoFDoQCU4daFQyDzxZrNpdwPZqFQqarVaury8NJlAFpBtnA/gbu4T3vT5+bn6/b6i0ag2NjZ
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 80%|████████ | 40/50 [03:11<00:47, 4.74s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 40, iteration: 468, Discriminator Loss:array(1.11762, dtype=float32), Generator Loss: array(0.624788, dtype=float32)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9yW7kaZbmBz82z0YbaJwnd7p7eHhEZERkZXZmqkuFrpYgaSNIkKAL0C1oJegKtNJS0FJ7AQK0aAkCWupSZSWqKis7I2Pw8HniTBpp80QbtWD/Do9ZRlW6G0P4gA98AYLupPE/vMMZnvOccwKTyWSi23E7bsftuB2340ccwf9fP8DtuB2343bcjv//G7fK5XbcjttxO27Hjz5ulcvtuB2343bcjh993CqX23E7bsftuB0/+rhVLrfjdtyO23E7fvRxq1xux+24Hbfjdvzo41a53I7bcTtux+340cetcrkdt+N23I7b8aOPW+VyO27H7bgdt+NHH+H3/eCnn36q0WikarWqer2uQCCgUCikUCikTCajeDyuu3fv6uHDh6rX6/q3//bfqtPpaG1tTaVSScfHx/r++++VSCT0H/6H/6E2Nzf19ddf66uvvtLl5aVarZYkaXV1Vfl8XvV6XRcXFwoEAsrlcorH40qn00qn0+r3+2o2m7q8vNTFxYWazaYWFhZUKpU0Ho/VaDQ0GAwUi8UUjUYlSaFQSJI0mUw0Ho81Ho81GAw0Go3Ubrc1Go2UTqeVTCZVKBR07949jUYj/f73v9fR0ZFSqZQymYwCgYDevXv3wRP93/w3/42Gw6H+9m//Vo8fP1Y6ndby8rJSqZS2traUy+V0584dPXjwQM1mU0+fPlWj0VC5XFa9Xrc5SiaT+tWvfqWVlRV9//33+u6773R5ealms6nRaCRJCgaDGg6HGgwGKhQK+k//0/9U29vbevr0qZ4+fap0Oq3NzU1J0u9//3u9e/dO29vb+uSTTzSZTHR6eqper6dsNqtMJqPBYKB2u63BYKDT01O1Wi1NJhO7XywWUyAQsPktFAra3d1VIBDQwcGBGo2GwuGwwuGr7fav//W//uD5++/+u/9Ow+FQf/3Xf62vvvpKiURC+Xxe4XBYyWRS4XBYa2tr2tzcVL/fV7lcVq/X09nZmWq1mt0/Ho/r/v37WlhY0MHBgfb39zUYDNTpdDQcDtXtdtXv9zWZTDSZTJTNZvWrX/1KS0tLevHihV69eqWlpSX94he/UCQS0W9+8xu9ePFCOzs7+uyzzzQYDPT8+XM1m00tLi6qUChoMpnYehweHqper0uSAoGAgsGg4vG4gsGgxuOxzWcmk5EktVotDQYDRSIRxeNxSdJXX331wfMnSalUSpI0HA41Ho8VDAZtTUajkZ2L0WikUCikaDRqe2k0GimZTCqfz2s8Huv8/Fy9Xk/JZFKpVEqhUEixWEySVK1W7TyzL2ZHIBCY+l0gEJAvFsK/o9GoIpGIgsGgneFut6vBYGDPG4lEtLCwoEgkolarpXa7besXCASUSCQUDoc1Go00Go00mUzUbrc/eP7+h//hf9BoNNL333+vt2/fKpFIqFQqKZFIaGtrS9lsVvfu3dNHH32karWqv/u7v1OtVlOr1VK/39dwONTl5aUymYz+4i/+Quvr63r8+LG++eYbNZtNHRwcaDAYKJ/PK51O233j8bju3LmjTCajy8tLXV5eKhgM2rnr9XoaDAZKJBJKp9M2r+PxWNVq1Z7h+PhY/X7f1nphYUHFYlGXl5d69+6der2eVldXVSqVlM1mtb6+rvF4rNevX6tSqZi8l6T/6X/6n/7kfL23cqnVaiZQ4vG4wuGwEomEJKnX69nhvLi4sI2RTCbt98PhUIFAQJeXl/r222/19u1bXVxcaDgcKhgM2qSw8JPJRIlEQrFYTFtbW1pYWFC/39dgMLBDMZlMFI/HNR6PFYvFFAwG7Voc5na7rVgspmQyqWAwaIuMUB
oMBhoOhxoOh0qn0yoUCgoGgzo6OtJoNDIllclkVCwWFQzO7+wFg0GVSiXdvXtXiURCi4uLCgQCarVaarVaKpfL+uabbxQKhRSJRBQIBBQOh22jsUm+++47vX37Vp1OR8lk0t6Pw4aQkmQb8vz8XJFIRFtbW4pGo0okEppMJsrn8yYker2exuOxOp2Oer2e+v2+arXa1EZeWlrS4uKi2u22Go2GAoGAMpmMwuGwCYxIJKJ6va5gMKhMJqNMJqNkMjl1YD50IOgKhYLW1tYUjUZNqcTjcUUiEVWrVZ2cnJigwwBKJpMajUa2D/f29pRIJNRqtTQajRSNRpXNZiVJlUpFrVZL2WxWKysrisViCoVCajQaKhQKymQyikQiajabmkwmSqfT2tnZUTQa1cHBgS4vL1WtVtXtdjUcDlWv16fWEEXS6XTUbDYVi8W0sLCgeDxue5E1CAaDWllZUSqVUq/XU7fbnXv+JJkika6UCUIoEAjY+eHcSJoyIAKBgIbDoSmNUChkyq7T6SgUCtlnOevs48lkon6/b/vSKxLuHQ6H1e/3dXl5aUpBkik7ziKKejwemwIZjUa6vLy0Z+fayAO+OOfzVrw6OjrSeDxWJpPRw4cPFY1GTWE3Gg3V63WVy2V9/fXXU/sNRT4ej+39nj17ptPTU9XrdTuLxWJR/X5f6XRasVjM3plzxFntdrsmGzCm6/X6lBxkjv2ccE5CoZBds1Ao2D2Gw6HW19e1vLxsxvdwOFQ8Hlcul1MymTSj533GeyuXZrMpSYpEIopGo4rFYspmsxqNRmo2m2o2m6rX69rf31cqldLm5qYSicTUogYCAQ0GA718+VLS1QYNh8MmAILBoCaTiXq9nqQrCy6VSmltbU2FQkEXFxc6Pz83K2YymSgWi2k8HpvwkWSbnkMeDAYVjUYVCoXsORCwWK2BQEDpdFoLCwvqdrs6Pz83LR+JRJRMJpXL5f5RS+xPDZ4tl8uZgl5YWNBwONTx8bHNYa1WUzab1aNHj8wiTKfT6vV6dojevHlj18rn8zaXkkwBs06RSES9Xs+E3PLyss3fcDhUNptVv9834TYYDHR5eWkHg2ddXFxUJBJRLpdTLBZTpVJRv99XOBzWwsKCzT9r3Gq1TJkmEgnlcjkVi8W554/3y2azWl5etsOFkRMOh/Xu3Tu9evXKLMl4PG5rjXWHcREKhUwAhsNhFYtFhcNhEwYrKyv69NNPJUnlclntdluLi4taXFxUp9PRycmJBoOBksmkEomEms2mTk9Pzavu9/u2r1Kp1JTnFIvFptYpnU4rk8mo1+vZV61WUzQaVbFY1NLSkqrVqsrl8pThMO8cIpglmSBnryD8ENr8HuHIO0UikSmFwO8DgYApGZT8ZDIxT8MrNK7NvIzHY3W7Xft5IBCYMpi4LnPAc47HY/X7fdvT/pnZk17YzjvK5bKCwaCWl5fNCA2FQhoMBtrf31ej0VCz2VSj0VAikdDm5qbi8bgJc+ZhOBzq3bt3Oj09VSwWM+NwYWFBg8FA0WjU5hbjO5VKaWFhQeFw2GQJc81ei0ajtmZ8ZjQa2ZxEIhFFIhE7Lxh+KJ7JZKLNzU2trKyo0Wjo8PBQ/X7fZGculzOD+H3GeysXbw1gVbMReVhcNl56MBio0WjYoheLRY1GI3U6HQ0GAztw4XBYqVTKPAvvssfjcfX7fbVaLRMQHsYIBoN2f0kmKLBGG42GKUJ+f3l5qXQ6rVwup263q1qtZoe90+mYFTSZTGwxEZrzCsfXr19LklnKvGsgEFA+n1c2m1WlUjFhiSDggDBPgUBgynNkE3kIIJVKKRKJGExRqVR0fn5uQoz3YT6y2axZl3hqwCJYOtFo1NacNS6VShqNRup2u2q32yY4A4GAHYRcLmdr22q15p6/J0+eSJIpSe4bCoVs/TmgHE4OZTAYVL1eV61Wk3RlfDCHGDf9ft+UbKFQUCQSUa1W03g81sXFhXq9ngkpYCPmHYHB/KG00um0Eo
mE4vG4efGNRsNgHea5Vqup1+spFospHo8rGo0qnU4rEokon88rk8kYDHyTOrMIVm/1ApEg/L1H44f3JJgDBCYD4c9
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 50/50 [03:58<00:00, 4.78s/it]\n"
2024-07-29 00:18:35 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"# Set your parameters\n",
2024-07-30 18:21:38 +08:00
"n_epochs = 50\n",
2024-07-30 00:44:16 +08:00
"display_step = 5000\n",
2024-07-28 06:10:19 +08:00
"cur_step = 0\n",
2024-07-30 00:44:16 +08:00
"\n",
2024-07-30 07:44:41 +08:00
"batch_size = 128\n",
2024-07-26 21:07:40 +08:00
"\n",
2024-07-28 06:10:19 +08:00
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"\n",
2024-07-30 00:44:16 +08:00
"for epoch in tqdm(range(n_epochs)):\n",
2024-07-28 06:10:19 +08:00
"\n",
2024-07-30 18:21:38 +08:00
" for idx,real in enumerate(batch_iterate(batch_size, train_images)):\n",
2024-07-28 06:10:19 +08:00
" \n",
2024-07-30 00:44:16 +08:00
" # Train the discriminator: loss and gradients w.r.t. the discriminator's parameters\n",
2024-07-30 07:06:52 +08:00
" D_loss,D_grads = D_loss_grad(gen, disc,mx.array(real), batch_size, z_dim)\n",
2024-07-28 06:10:19 +08:00
"\n",
" # Apply the gradients to the discriminator's parameters\n",
" disc_opt.update(disc, D_grads)\n",
" \n",
" # Force evaluation of the lazily computed parameter and optimizer-state updates\n",
" mx.eval(disc.parameters(), disc_opt.state)\n",
2024-07-26 21:07:40 +08:00
"\n",
2024-07-30 00:44:16 +08:00
" # Train the generator: loss and gradients w.r.t. the generator's parameters\n",
" G_loss,G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
2024-07-28 06:10:19 +08:00
" \n",
" # Apply the gradients to the generator's parameters\n",
" gen_opt.update(gen, G_grads)\n",
" \n",
" # Force evaluation of the lazily computed parameter and optimizer-state updates\n",
2024-07-30 07:37:09 +08:00
" mx.eval(gen.parameters(), gen_opt.state) \n",
2024-07-30 07:44:41 +08:00
" \n",
2024-07-30 18:21:38 +08:00
" # if (cur_step + 1) % display_step == 0:\n",
" # print(f\"Step {epoch}: Generator loss: {G_loss}, discriminator loss: {D_loss}\")\n",
" # fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
" # fake = gen(fake_noise)\n",
" # show_images(fake)\n",
" # show_images(real)\n",
" # cur_step += 1\n",
" \n",
" if epoch%10==0:\n",
" print(\"Epoch: {}, iteration: {}, Discriminator Loss:{}, Generator Loss: {}\".format(epoch,idx,D_loss,G_loss))\n",
2024-07-30 07:44:41 +08:00
" fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
" fake = gen(fake_noise)\n",
" show_images(fake)\n",
2024-07-30 18:21:38 +08:00
" # show_images(real)\n",
" \n",
" # print('Losses D={0} G={1}'.format(D_loss,G_loss))"
2024-07-27 05:19:08 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}