2024-07-26 21:07:40 +08:00
|
|
|
{
|
|
|
|
"cells": [
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"# Import Libraries"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 1,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import mnist"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 2,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import mlx.core as mx\n",
|
|
|
|
"import mlx.nn as nn\n",
|
|
|
|
"import mlx.optimizers as optim\n",
|
|
|
|
"\n",
|
|
|
|
"from tqdm import tqdm\n",
|
|
|
|
"import numpy as np\n",
|
|
|
|
"import matplotlib.pyplot as plt"
|
|
|
|
]
|
|
|
|
},
|
2024-07-30 07:17:12 +08:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 3,
|
2024-07-30 07:17:12 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"# mx.set_default_device(mx.gpu)"
|
|
|
|
]
|
|
|
|
},
|
2024-07-26 21:07:40 +08:00
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"# GAN Architecture"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"## Generator 👨🏻‍🎨"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 4,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def GenBlock(in_dim: int, out_dim: int, eps: float = 0.8, negative_slope: float = 0.2):\n",
"    \"\"\"Generator building block: Linear -> BatchNorm -> LeakyReLU.\n",
"\n",
"    Args:\n",
"        in_dim: input feature size.\n",
"        out_dim: output feature size (also the BatchNorm width).\n",
"        eps: BatchNorm epsilon. The default 0.8 reproduces the previous\n",
"            positional call `nn.BatchNorm(out_dim, 0.8)`, which binds 0.8 to\n",
"            `eps` (the model repr shows `eps=0.8, momentum=0.1`).\n",
"            NOTE(review): 0.8 is far above the usual 1e-5 and looks like it\n",
"            was meant as a momentum value; confirm before keeping it.\n",
"        negative_slope: LeakyReLU slope for negative inputs.\n",
"\n",
"    Returns:\n",
"        nn.Sequential of the three layers.\n",
"    \"\"\"\n",
"    return nn.Sequential(\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.BatchNorm(out_dim, eps=eps),\n",
"        nn.LeakyReLU(negative_slope=negative_slope),\n",
"    )"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 5,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"class Generator(nn.Module):\n",
"    \"\"\"MLP generator that maps noise vectors to flattened images.\n",
"\n",
"    Args:\n",
"        z_dim: length of each input noise vector.\n",
"        im_dim: length of each flattened output image (default 784).\n",
"        hidden_dim: width of the first hidden block; the next two blocks\n",
"            are 2x and 4x as wide.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, z_dim: int = 32, im_dim: int = 784, hidden_dim: int = 128):\n",
"        super().__init__()\n",
"\n",
"        widths = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4]\n",
"        # One GenBlock per consecutive pair of widths, then a linear projection to image size.\n",
"        blocks = [GenBlock(w_in, w_out) for w_in, w_out in zip(widths, widths[1:])]\n",
"        self.gen = nn.Sequential(*blocks, nn.Linear(widths[-1], im_dim))\n",
"\n",
"    def __call__(self, noise):\n",
"        # tanh bounds every output value to [-1, 1].\n",
"        return mx.tanh(self.gen(noise))"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 6,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
|
|
|
"Generator(\n",
|
|
|
|
" (gen): Sequential(\n",
|
|
|
|
" (layers.0): Sequential(\n",
|
2024-07-30 18:45:09 +08:00
|
|
|
" (layers.0): Linear(input_dims=100, output_dims=128, bias=True)\n",
|
|
|
|
" (layers.1): BatchNorm(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
|
2024-07-30 07:06:52 +08:00
|
|
|
" (layers.2): LeakyReLU()\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" )\n",
|
|
|
|
" (layers.1): Sequential(\n",
|
2024-07-30 18:45:09 +08:00
|
|
|
" (layers.0): Linear(input_dims=128, output_dims=256, bias=True)\n",
|
|
|
|
" (layers.1): BatchNorm(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
|
2024-07-30 07:06:52 +08:00
|
|
|
" (layers.2): LeakyReLU()\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" )\n",
|
|
|
|
" (layers.2): Sequential(\n",
|
2024-07-30 18:45:09 +08:00
|
|
|
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
|
|
|
|
" (layers.1): BatchNorm(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
|
2024-07-30 07:06:52 +08:00
|
|
|
" (layers.2): LeakyReLU()\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" )\n",
|
2024-07-30 18:45:09 +08:00
|
|
|
" (layers.3): Linear(input_dims=512, output_dims=784, bias=True)\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" )\n",
|
|
|
|
")"
|
|
|
|
]
|
|
|
|
},
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 6,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"# Build a generator for 100-dim noise and display its layer structure.\n",
"gen = Generator(z_dim=100)\n",
"gen"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 7,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
|
"def get_noise(n_samples: int, z_dim: int) -> mx.array:\n",
"    \"\"\"Draw a 2-D batch of standard-normal noise.\n",
"\n",
"    Args:\n",
"        n_samples: number of noise vectors (rows).\n",
"        z_dim: length of each noise vector (columns).\n",
"\n",
"    Returns:\n",
"        mx.array of shape (n_samples, z_dim) sampled from N(0, 1).\n",
"    \"\"\"\n",
"    return mx.random.normal(shape=(n_samples, z_dim))"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 8,
|
2024-07-30 00:44:16 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
2024-07-30 21:59:35 +08:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWWklEQVR4nO3cbWzV9f3G8auuRW6r3I/QAgU2hxRoJ5MbLY7YSrMhE8wk6jYGToIbg2lmsmBMNjCaBXQ6HAnDG7xJxrYOKW7ApiKtyGQEKBRkgBQQhILAwEKrlPX8n32SPeq5vsnf7cH79fi8z4Fy6uXvyScnk8lkBACApKv+238AAMD/DkYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAITfbF65YscJ+86uu8jcnJyfHbiSpqanJbrZs2WI3M2bMsJvXX3/dboYMGWI3krRhwwa7GT58uN307NnTbjZv3mw3kvSd73zHbr761a/azd/+9je76dChg900NDTYTepn7d69224efPBBuzlz5ozdtLS02I0kvfbaa3Yzffp0u6mqqrKb0aNH240knThxwm7Ky8vtZtq0ae2+hicFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEHIymUwmmxfOmTPHfvOJEyfazcaNG+1GkgYPHmw3ly5dspsrV67YTV5ent3ceeeddiOlHcTr3r273fzlL3+xmzvuuMNuJKm1tdVuUg601dfX28348ePt5oMPPrAbSfroo4/sJuV3sKamxm4+/fRTu5k6dardSGnHGFMOTB48eNBuXnrpJbuRpAULFthNykHPefPmtfsanhQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAyM32hVOmTLHfPOXQWsrRL0kqKiqymwEDBthNbW2t3Zw6dcpujh49ajeSNGLECLvZs2eP3dx11112k3LITJI+/PBDu0n5Ofz617+2m2HDhtlNt27d7EaSZs+ebTfPPfec3RQXF9tNc3Oz3axatcpuJGnJkiV2k5ub9X/qwunTp+1m7NixdiNJb731lt0UFhYmfVZ7eFIAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAIScTCaTyeaFL7/8sv3mEyZMsJv6+nq7kaSqqiq7SbmCmHJ9s62tzW4GDx5sN5J09uxZu8nLy7Oby5cv201+fr7dSFJdXZ3d3HLLLXYzaNAguzl06JDdbN261W4k6fz583aT8nNobGy0m4kTJ9rN4sWL7UaSxo8fbzcp36HS0lK7GTlypN1I0gMPPGA3kydPtpvf/va37b6GJwUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQcv8/37yiosJuUo5QSdLw4cPtJuXPt23bNruZNWuW3dTW1tqNJB07dsxu7rvvPrtpaWmxm+rqaruRpDvvvNNuUo7HLViwwG7mz59vN9/+9rftRpIuXLhgN++++67dFBQU2M3tt99uN7t377YbSXr77bft5sCBA3aTciAx5fdCkqZPn243/fv3T/qs9vCkAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAELWB/FSjlAtW7bMbo4cOWI3UtpxqJQDY1/72tfsZtOmTXazZs0au5Gk1tZWu3nttdfsJuXAWOfOne1GknJycuwm5XuUcqBt+fLldjNu3Di7kaS8vLykzpXyfUhp3nnnHbuRpP3799tNWVmZ3aQcO0w5silJ3//+9+3m5MmTSZ/VHp4UAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQMj6IF7KobXPPvvMbn73u9/ZjSR961
vfspudO3fazQMPPGA3n3zyid3ceuutdiOlHZ07evSo3QwZMsRuUr4PkvT888/bzaJFi+zm9OnTdlNaWmo3H3zwgd1IUn19vd088sgjdlNeXm4377//vt28++67diNJjz76qN3U1tbazZw5c+xm5syZdiNJr7zyit2cOnXKbmbNmtXua3hSAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAACHrg3j79u2z3zzlsNZf//pXu5Gk6667zm5Sjs5VVVXZTUtLi93U1NTYjZT2c2hqarKb8ePH283BgwftRpLq6ursZvPmzXbzhS98wW6Kiorsprm52W4k6eOPP7ablCN/Z86csZvRo0fbTe/eve1Gko4dO2Y3GzdutJt77rnHbvbu3Ws3Utrv+8MPP5z0We3hSQEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEHIymUwmmxfu2bPHfvPVq1fbzYsvvmg3krRlyxa7Wbx4sd0sWLDAbn7yk5/YTepF0e9+97t207NnT7s5f/683Vx99dV2I0nr1q2zm6NHj9rNwoUL7ebxxx+3m2uvvdZuJGnYsGF2c/3119vNqlWr7Oaxxx6zmxtuuMFuJGnDhg12s23bNrvJ8j+N/+HSpUt2I0mtra12M2TIELuZO3duu6/hSQEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAACErA/ivfDCC/abd+rUyW5OnDhhN1L6sTVXx44dP5fPGThwYFJXU1NjN927d7ebJUuW2M1Pf/pTu5HSjrrt3LnTbo4fP243paWldvN5Hjs8e/as3TQ2NtpNys/70KFDdiNJ/fv3t5v6+nq76dy5s90UFRXZjSSNGTPGbkaNGmU3I0aMaPc1PCkAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkJvtC2tra+03X7t2rd2kHP2SpNOnT9tNyvGqlONsKce45s2bZzeSNHLkSLuZNm2a3TQ1NdlN6tHCl156yW5OnjxpN83NzXaT8rNLOVooSbt27bKb6upqu3niiSc+l89J+V2S0n5vy8vL7eaaa66xmxUrVtiNJPXu3dtuHnvsMbv5/e9/3+5reFIAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAIeuDeJ07d7bfvLKy0m5KSkrsRpLy8/Pt5g9/+IPdpBzeKy4utpu+ffvajSQ9/fTTdrNnzx67mTVrlt2kSvmslL9Tnz597Gbv3r12c+jQIbuRpH/96192c8stt9jNzJkz7eaRRx6xm5///Od2I0mTJ0+2m8bGRrt555137KaiosJuJKmurs5uRowYkfRZ7eFJAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQsr6SOm7cOPvNX331VbtpaGiwG0nKzc36rxLOnz9vNytXrrSbXbt22U2XLl3sRkq7/JryWSlXSN9++227kaS77rrLbl5++WW7mT17tt306NHDbr75zW/ajSTt27fPbkaNGmU3ZWVldpPyu1RYWGg3knTjjTfaTadOneympaXFbjp27Gg3krR161a7mTdvXtJntYcnBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABByMplMJpsXLlu2zH7zs2fP2s26devsRpIqKyvtJj8/327eeOMNu/nxj39sNynHzyRpwIABdlNdXW03ly9ftpvJkyfbjSStX7/eblIOOKYcaNu+fbvd9OnTx24kafXq1XYzYcIEu0n5vVi7dq3d1NTU2I0kLVq0yG6amprspqKiwm527NhhN6lOnDhhN88880
y7r+FJAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAISsD+JNnTrVfvN//vO
|
2024-07-30 00:44:16 +08:00
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"# Render one 28x28 noise sample as a grayscale image.\n",
"img = get_noise(28, 28)\n",
"fig, ax = plt.subplots()\n",
"ax.imshow(img, cmap='gray')\n",
"ax.axis('off')\n",
"plt.show()"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"## Discriminator 🕵🏻‍♂️"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 9,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def DisBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Discriminator building block: Linear -> LeakyReLU(0.2) -> Dropout(0.3).\"\"\"\n",
"    layers = [\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.LeakyReLU(negative_slope=0.2),\n",
"        nn.Dropout(p=0.3),\n",
"    ]\n",
"    return nn.Sequential(*layers)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 10,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"class Discriminator(nn.Module):\n",
"    \"\"\"MLP discriminator that scores flattened images.\n",
"\n",
"    Outputs one raw logit per input row, with no final sigmoid. The loss\n",
"    functions call `nn.losses.binary_cross_entropy(..., with_logits=True)`,\n",
"    which applies the sigmoid itself. Ending this network in nn.Sigmoid()\n",
"    would apply the sigmoid twice.\n",
"\n",
"    Args:\n",
"        im_dim: length of each flattened input image (default 784).\n",
"        hidden_dim: width of the last hidden block; earlier blocks are\n",
"            4x and 2x as wide.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, im_dim: int = 784, hidden_dim: int = 128):\n",
"        super().__init__()\n",
"\n",
"        self.disc = nn.Sequential(\n",
"            DisBlock(im_dim, hidden_dim * 4),\n",
"            DisBlock(hidden_dim * 4, hidden_dim * 2),\n",
"            DisBlock(hidden_dim * 2, hidden_dim),\n",
"            # Raw logit output; the sigmoid is applied inside the BCE loss.\n",
"            nn.Linear(hidden_dim, 1),\n",
"        )\n",
"\n",
"    def __call__(self, noise):\n",
"        # `noise` is the image batch to score (real or generated), not latent noise.\n",
"        return self.disc(noise)"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 11,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
|
|
|
"Discriminator(\n",
|
|
|
|
" (disc): Sequential(\n",
|
|
|
|
" (layers.0): Sequential(\n",
|
2024-07-30 18:45:09 +08:00
|
|
|
" (layers.0): Linear(input_dims=784, output_dims=512, bias=True)\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" (layers.1): LeakyReLU()\n",
|
2024-07-30 18:21:38 +08:00
|
|
|
" (layers.2): Dropout(p=0.30000000000000004)\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" )\n",
|
|
|
|
" (layers.1): Sequential(\n",
|
2024-07-30 18:45:09 +08:00
|
|
|
" (layers.0): Linear(input_dims=512, output_dims=256, bias=True)\n",
|
2024-07-30 07:37:09 +08:00
|
|
|
" (layers.1): LeakyReLU()\n",
|
2024-07-30 18:21:38 +08:00
|
|
|
" (layers.2): Dropout(p=0.30000000000000004)\n",
|
2024-07-30 07:37:09 +08:00
|
|
|
" )\n",
|
|
|
|
" (layers.2): Sequential(\n",
|
2024-07-30 18:45:09 +08:00
|
|
|
" (layers.0): Linear(input_dims=256, output_dims=128, bias=True)\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" (layers.1): LeakyReLU()\n",
|
2024-07-30 18:21:38 +08:00
|
|
|
" (layers.2): Dropout(p=0.30000000000000004)\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" )\n",
|
2024-07-30 18:45:09 +08:00
|
|
|
" (layers.3): Linear(input_dims=128, output_dims=1, bias=True)\n",
|
2024-07-30 18:21:38 +08:00
|
|
|
" (layers.4): Sigmoid()\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" )\n",
|
|
|
|
")"
|
|
|
|
]
|
|
|
|
},
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 11,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"# Build a discriminator with the default sizes and display its layer structure.\n",
"disc = Discriminator(im_dim=784, hidden_dim=128)\n",
"disc"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"# Model Training 🏋🏻‍♂️"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
2024-07-30 00:44:16 +08:00
|
|
|
"cell_type": "markdown",
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
|
"### Losses"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
|
"#### Discriminator Loss"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 12,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2024-07-26 21:36:29 +08:00
|
|
|
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
"    \"\"\"Discriminator loss: BCE on fakes (target 0) plus BCE on reals (target 1).\n",
"\n",
"    Args:\n",
"        gen: generator; called on a (num_images, z_dim) noise batch.\n",
"        disc: discriminator; its outputs are scored with with_logits=True,\n",
"            i.e. treated as raw logits.\n",
"        real: batch of real images in the layout `disc` expects.\n",
"        num_images: number of fake images to generate.\n",
"        z_dim: noise dimension passed to get_noise.\n",
"\n",
"    Returns:\n",
"        Scalar mx.array: mean fake loss + mean real loss.\n",
"    \"\"\"\n",
"    # NOTE(review): with_logits=True applies a sigmoid internally, so `disc`\n",
"    # must not end in nn.Sigmoid(); otherwise the sigmoid is applied twice.\n",
"    fake_images = gen(get_noise(num_images, z_dim))\n",
"    fake_logits = disc(fake_images)\n",
"    fake_targets = mx.zeros((fake_images.shape[0], 1))\n",
"    fake_loss = mx.mean(nn.losses.binary_cross_entropy(fake_logits, fake_targets, with_logits=True))\n",
"\n",
"    real_logits = disc(real)\n",
"    real_targets = mx.ones((real.shape[0], 1))\n",
"    real_loss = mx.mean(nn.losses.binary_cross_entropy(real_logits, real_targets, with_logits=True))\n",
"\n",
"    return fake_loss + real_loss"
|
|
|
|
]
|
|
|
|
},
|
2024-07-30 00:44:16 +08:00
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"#### Generator Loss"
|
|
|
|
]
|
|
|
|
},
|
2024-07-26 21:07:40 +08:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 13,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2024-07-26 21:36:29 +08:00
|
|
|
"def gen_loss(gen, disc, num_images, z_dim):\n",
"    \"\"\"Generator loss: BCE of the discriminator's scores on fakes against target 1.\n",
"\n",
"    The loss falls as the discriminator scores the generated images as real.\n",
"\n",
"    Args:\n",
"        gen: generator; called on a (num_images, z_dim) noise batch.\n",
"        disc: discriminator; its outputs are scored with with_logits=True,\n",
"            i.e. treated as raw logits.\n",
"        num_images: number of fake images to generate.\n",
"        z_dim: noise dimension passed to get_noise.\n",
"\n",
"    Returns:\n",
"        Scalar mx.array: mean BCE over the generated batch.\n",
"    \"\"\"\n",
"    # NOTE(review): with_logits=True applies a sigmoid internally, so `disc`\n",
"    # must not end in nn.Sigmoid(); otherwise the sigmoid is applied twice.\n",
"    fake_images = gen(get_noise(num_images, z_dim))\n",
"    fake_logits = disc(fake_images)\n",
"    # The generator wants its fakes classified as real, hence all-ones targets.\n",
"    real_targets = mx.ones((fake_images.shape[0], 1))\n",
"    return mx.mean(nn.losses.binary_cross_entropy(fake_logits, real_targets, with_logits=True))"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 14,
|
2024-07-30 00:44:16 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"# Load MNIST as NumPy arrays; keep the training images and labels, discard the rest.\n",
"train_images, train_labels, *_ = (np.array(split) for split in mnist.mnist())"
|
2024-07-30 00:44:16 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 15,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
|
"# Rescale pixels from [0, 1] to [-1, 1], the range of the generator's tanh output.\n",
"# Assumes mnist.mnist() yields pixels in [0, 1]. The rescale is not idempotent\n",
"# (re-running this cell would give [-3, 1]), so check the input range first.\n",
"assert train_images.min() >= 0.0 and train_images.max() <= 1.0, 'train_images must be in [0, 1]; was this cell already run?'\n",
"train_images = train_images * 2.0 - 1.0"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 16,
|
2024-07-29 06:24:50 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
2024-07-30 21:59:35 +08:00
|
|
|
"<matplotlib.image.AxesImage at 0x14478c610>"
|
2024-07-29 06:24:50 +08:00
|
|
|
]
|
|
|
|
},
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 16,
|
2024-07-29 06:24:50 +08:00
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
2024-07-30 00:44:16 +08:00
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaI0lEQVR4nO3df2jU9x3H8dfVH1d1lytBk7vUmGVF202dpWrVYP3R1cxApf4oWMtGZEPa+YOJ/cGsDNNBjdgpRdI6V0amW239Y9a6KdUMTXRkijpdRYtYjDOdCcFM72LUSMxnf4hHz1j1e975vkueD/iCufu+vY/ffuvTby75xueccwIAwMBD1gsAAHRfRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjpab2AW3V0dOjcuXMKBALy+XzWywEAeOScU0tLi/Ly8vTQQ3e+1km7CJ07d075+fnWywAA3Kf6+noNHDjwjvuk3afjAoGA9RIAAElwL3+fpyxCH3zwgQoLC/Xwww9r5MiR2rdv3z3N8Sk4AOga7uXv85REaPPmzVq8eLGWLVumI0eO6JlnnlFJSYnOnj2bipcDAGQoXyruoj1mzBg99dRTWrduXeyx73//+5o+fbrKy8vvOBuNRhUMBpO9JADAAxaJRJSVlXXHfZJ+JXTt2jUdPnxYxcXFcY8XFxertra20/5tbW2KRqNxGwCge0h6hM6fP6/r168rNzc37vHc3Fw1NjZ22r+8vFzBYDC28ZVxANB9pOwLE259Q8o5d9s3qZYuXapIJBLb6uvrU7UkAECaSfr3CfXv3189evTodNXT1NTU6epIkvx+v/x+f7KXAQDIAEm/Eurdu7dGjhypqqqquMerqqpUVFSU7JcDAGSwlNwxYcmSJfrpT3+qUaNGady4cfr973+vs2fP6tVXX03FywEAMlRKIjR79mw1NzfrN7/5jRoaGjRs2DDt2LFDBQUFqXg5AECGSsn3Cd0Pvk8IALoGk+8TAgDgXhEhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmelovAEgnPXr08DwTDAZTsJLkWLhwYUJzffv29Tzz+OOPe55ZsGCB55nf/va3nmfmzJnjeUaSrl696nlm5cqVnmfefvttzzNdBVdCAAAzRAgAYCbpESorK5PP54vbQqFQsl8GANAFpOQ9oaFDh+rvf/977ONEPs8OAOj6UhKhnj17cvUDALirlLwndOrUKeXl5amwsFAvvfSSTp8+/a37trW1KRqNxm0AgO4h6REaM2aMNm7cqJ07d+rDDz9UY2OjioqK1NzcfNv9y8vLFQwGY1t+fn6ylwQASFNJj1BJSYlmzZql4cOH67nnntP27dslSRs2bLjt/kuXLlUkEolt9fX1yV4SACBNpfybVfv166fhw4fr1KlTt33e7/fL7/enehkAgDSU8u8Tamtr05dffqlwOJzqlwIAZJikR+j1119XTU2N6urqdODAAb344ouKRqMqLS1N9ksBADJc0j8d9/XXX2vOnDk6f/68BgwYoLFjx2r//v0qKChI9ksBADJc0iP0ySefJPu3RJoaNGiQ55nevXt7nikqKvI8M378eM8zkvTII494npk1a1ZCr9XVfP31155n1q5d63lmxowZnmdaWlo8z0jSv//9b88zNTU1Cb1Wd8W94wAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQ
DMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAMz7nnLNexDdFo1EFg0HrZXQrTz75ZEJzu3fv9jzDf9vM0NHR4XnmZz/7meeZS5cueZ5JRENDQ0JzFy5c8Dxz8uTJhF6rK4pEIsrKyrrjPlwJAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwExP6wXA3tmzZxOaa25u9jzDXbRvOHDggOeZixcvep6ZPHmy5xlJunbtmueZP/3pTwm9Fro3roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT63//+l9DcG2+84Xnm+eef9zxz5MgRzzNr1671PJOoo0ePep6ZMmWK55nW1lbPM0OHDvU8I0m//OUvE5oDvOJKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAw43POOetFfFM0GlUwGLReBlIkKyvL80xLS4vnmfXr13uekaSf//znnmd+8pOfeJ75+OOPPc8AmSYSidz1/3muhAAAZogQAMCM5wjt3btX06ZNU15ennw+n7Zu3Rr3vHNOZWVlysvLU58+fTRp0iQdP348WesFAHQhniPU2tqqESNGqKKi4rbPr1q1SmvWrFFFRYUOHjyoUCikKVOmJPR5fQBA1+b5J6uWlJSopKTkts855/Tee+9p2bJlmjlzpiRpw4YNys3N1aZNm/TKK6/c32oBAF1KUt8TqqurU2Njo4qLi2OP+f1+TZw4UbW1tbedaWtrUzQajdsAAN1DUiPU2NgoScrNzY17PDc3N/bcrcrLyxUMBmNbfn5+MpcEAEhjKfnqOJ/PF/exc67TYzctXbpUkUgkttXX16diSQCANOT5PaE7CYVCkm5cEYXD4djjTU1Nna6ObvL7/fL7/clcBgAgQyT1SqiwsFChUEhVVVWxx65du6aamhoVFRUl86UAAF2A5yuhS5cu6auvvop9XFdXp6NHjyo7O1uDBg3S4sWLtWLFCg0ePFiDBw/WihUr1LdvX7388stJXTgAIPN5jtChQ4c0efLk2MdLliyRJJWWluqPf/yj3nzzTV25ckXz58/XhQsXNGbMGO3atUuBQCB5qwYAdAncwBRd0rvvvpvQ3M1/VHlRU1Pjeea5557zPNPR0eF5BrDEDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNrqkfv36JTT317/+1fPMxIkTPc+UlJR4ntm1a5fnGcASd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzAFvuGxxx7zPPOvf/3L88zFixc9z+zZs8fzzKFDhzzPSNL777/veSbN/ipBGuAGpgCAtEaEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpsB9mjFjhueZyspKzzOBQMDzTKLeeustzzMbN270PNPQ0OB5BpmDG5gCANIaEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gCBoYNG+Z5Zs2aNZ5nfvSjH3meSdT69es9z7zzzjueZ/773/96noENbmAKAEhrRAgAYMZzhPbu3atp06YpLy9PPp9PW7dujXt+7ty58vl8cdvYsWOTtV4AQBfiOUKtra0aMWKEKioqvnWfqVOnqqGhIbbt2LHjvhYJAOiaenodKCkpUUlJyR338fv9CoVCCS8KANA9pOQ9oerqauXk5GjIkCGaN2+empqavnXftrY2RaPRuA0A0D0kPUIlJSX66KOPtHv3bq1evVoHDx7Us88+q7
a2ttvuX15ermAwGNvy8/OTvSQAQJry/Om4u5k9e3bs18OGDdOoUaNUUFCg7du3a+bMmZ32X7p0qZYsWRL7OBqNEiI
|
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
2024-07-29 06:24:50 +08:00
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
|
"plt.imshow(train_images[0].reshape(28, 28), cmap='gray')\n",
"plt.axis('off')\n",
"plt.show()"
|
2024-07-29 06:24:50 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 17,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
|
"def batch_iterate(batch_size: int, ipt: np.ndarray):\n",
"    \"\"\"Yield shuffled mini-batches of rows from `ipt`.\n",
"\n",
"    A fresh permutation is drawn on every call. The final batch is smaller\n",
"    when len(ipt) is not a multiple of batch_size.\n",
"\n",
"    Args:\n",
"        batch_size: number of rows per batch.\n",
"        ipt: array that supports indexing with an integer index array\n",
"            (e.g. a NumPy array); a plain list does not.\n",
"\n",
"    Yields:\n",
"        Row subsets of `ipt`, one per consecutive slice of the permutation.\n",
"    \"\"\"\n",
"    perm = np.random.permutation(len(ipt))\n",
"    for start in range(0, len(ipt), batch_size):\n",
"        yield ipt[perm[start:start + batch_size]]"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 18,
|
2024-07-30 00:44:16 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"def show_images(imgs, num_imgs: int = 25):\n",
"    \"\"\"Plot up to `num_imgs` flattened 28x28 images on a square grid.\n",
"\n",
"    Args:\n",
"        imgs: batch indexable by row, each row reshapable to 28x28\n",
"            (NumPy or MLX array). Nothing is drawn when it is empty.\n",
"        num_imgs: maximum number of images to show. The grid is\n",
"            ceil(sqrt(num_imgs)) cells per side (5x5 for the default 25);\n",
"            unused cells stay blank.\n",
"    \"\"\"\n",
"    if imgs.shape[0] == 0 or num_imgs <= 0:\n",
"        return\n",
"    side = int(np.ceil(np.sqrt(num_imgs)))\n",
"    count = min(num_imgs, imgs.shape[0])\n",
"    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.\n",
"    fig, axes = plt.subplots(side, side, figsize=(side, side), squeeze=False)\n",
"    for i, ax in enumerate(axes.flat):\n",
"        ax.axis('off')\n",
"        if i < count:\n",
"            ax.imshow(mx.array(imgs[i]).reshape(28, 28), cmap='gray')\n",
"    plt.show()"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
|
"### Show the first batch of training images"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 19,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
2024-07-30 21:59:35 +08:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADseElEQVR4nOy953Nc15nn/73dt3POGTkRAEkwiaKoQCpYsmV7XS7P1Iy3ZndrpnarNvwju283/AG7M+WZ8fzGM7Y1ki0rWWIQM0ECJHLnnOPt/HvBOUfdAEiCJFI3z6eKRVFoNPoe3Hue86Tvw7Xb7TYYDAaDwdhFRAf9ARgMBoPRfzDjwmAwGIxdhxkXBoPBYOw6zLgwGAwGY9dhxoXBYDAYuw4zLgwGg8HYdZhxYTAYDMauw4wLg8FgMHYdZlwYDAaDsevwO30hx3F7+Tl6iucRNWDr9x1s/V6M5xXVYGv4HewefDF2sn7Mc2EwGAzGrsOMC4PBYDB2HWZcGAwGg7HrMOPCYDAYjF1nxwl9xuGA4zjIZDKIxWJIJBLwPA+O42iyURAE1Ot1NJtN1Ov1507+MhgMxovAjEuPoVAoMDs7C5PJBJfLBbfbDYlEAoVCgWaziVu3bmFjYwPpdBqBQACNRgONRoMZGQaDsa8cWuPSWfbHNsZHcBwHiUQCm80Gt9uN8fFxTE5OQiaTQaPRoNFooFQqQRAEtNttxGIxtNttNJtNtoYMBgPA9iXVe7E/HCrjIhaLMTw8DLvdDrPZDI/Hg0KhgD/84Q8IBoMH/fEOFKPRiKGhIZjNZly4cIH+t8ViAc/zkMlkaLVaOHfuHIaGhuD3+zEwMIB8Po+lpSXkcjlUKhUIgnDQl8JgMPYZsVgMmUwGhUKBEydOwO12o9lsotFoIJvN4vbt28hkMmg0Gmg2m7vyMw+VcZFIJJidncWJEycwNTWFt956C8FgECsrKy+9cbFYLDh79iycTie+//3vY2xsjOZaOk8ibrcbrVYLGxsbGBsbQzQaRbPZhNfrRTqdZsaFwXgJ4XkeGo0GBoMBP/rRj3D+/HnUajUIggCv14t4PI5KpYJKpdJfxkUikUCr1UKtVsPhcMDtdkMsFmNtbQ2hUAilUmnH7yUWi6HT6SCVSqFWq6HValGr1VAoFFCv15HJZFCpVPbwanYXnufB8zx0Oh3cbjccDgdkMhkAIJ/PI5/Pg+M4iMViiMViaDQaqFQqqFQq2Gw2iMViTE5OQqPRIBKJIB6Po1aroVwuo9lsolKpoFarHfBVMhiM/UAsFkOhUECj0aBWq0Eul0Oj0UAmk4HneYhEu1dAfCiMi06nw5kzZ2C1WvH222/j9OnTuHLlCv7H//gfSCaT8Pl8O34vlUqFkydPwmaz4fTp0zh+/DgSiQTu3LmDVCqFzz//HKurq3t4NbsHx3HU6E5NTeG9996DwWCASqVCPp/HrVu3cOPGDYjFYmi1WiiVSrzyyisYGxuD2WymN9Ds7CzK5TKWl5exsrKCdDpNQ2UbGxtIJBJot9ssL8Ng9Dkcx0Gj0cBsNqPVaqHVaqFUKkGn00GlUqFer+9adONAjQs5bavVathsNtjtduj1eiiVSlSrVaytrSGdTqNcLj/1vTiOg0gkglwuh8VigdPpxODgIMbHx6FUKhEMBtFoNCCVSvfhynYHkUgEpVIJvV4Po9FIDYYgCCiXy4jFYvD5fOB5Hnq9HiqVCul0GsViETzPQ6lUQqFQQKFQ0KqxWq0GlUqFbDYLuVxO17fRaNDS5d1yi/sREoIUiUTgOA48z0MsFu9Yd6rRaKDVaqHZbLJ1fk46114ikQAALVoh93m/Q+43ErUga0I8D3JY3G49SBsDeZ1cLu9qa9gtDtS4jI2NYWpqCm63G++++y60Wi0CgQDm5+dx9+5dxGIxuvE9DaPRCJfLBafTibfffhtDQ0PgeR4bGxtYWVnBF198gUQigXg8vg9X9mKIRCJqFN555x2cPX
sWHo8HEokEuVwOv//97+H1erGysoKVlRUAgFwuB8/zuHfvHux2O8bHx3Hu3Dmo1WqYTCbI5XIMDg5Cp9NBEATMzc2hUqlgbW0N8Xgcfr8f9+/fR7FYRDgc7qnQ4X4hk8noH3LSe+211zA6OgqJRAKpVPrEh7NarWJ+fh7hcBh+vx9LS0svxUa4m4jFYjidTuh0OszMzODcuXNoNpuIRqMoFAq4evUq7ty5c9Afc88gYXIS3pJKpRgbG4PVaoXRaITD4UCr1UI2m0WlUsHt27dx//59iEQiVKtVCIJADzgikQgikYgaaalUCrFYvHufddfe6RnhOA5OpxOvvPIKBgYG8Nprr0EqleL27dv47LPPEAqFkM1mUa/Xd/R+Go0GQ0NDGBwcxIkTJzA6OorV1VVsbGxgY2MDd+/eRTKZ7In8glgshlwuh1qtxokTJ/DDH/4QYrEYPM8jnU7j6tWruHbtGlKpFJLJJA1ncRyH+/fvQy6X44033oDD4YDZbIZer4dUKqXeIaHRaCAQCCCVSuHmzZsoFApIpVJIpVLMuGyDRCKBUqmESqWC3W6HyWTC9773Pbz++uuQy+VQKpVbCiza7Tb9d7FYxG9+8xvcu3cP7XYbq6urrEz8GRGLxTCZTHA4HDh//jz+/b//92g0GlhcXEQymUQ0Gu1r40K8DplMBrVaDbVajSNHjmB8fBwDAwOYmZlBs9lEMBikOdmlpSW0223U63XaYE2MCzFW5E9P51x4nofVaoVarcbExAQmJyehUqkQDAZRr9cRCAQQi8WQz+fRarV2/L5yuRwmkwkGgwEKhQI8z6Ner6NQKKBUKtFF7YUHmSTmjUYjNBoNlEolEokEVldXEY1GEQqFUCgUUK1Wu66HuMHVahWBQADffPMN9Ho9IpEI9Ho9DAYD9Ho95HI5bDYbRCIRVCoV2u02hoeH8eqrryIWi6FUKiESiaBYLD5TMUU/oVAooNfrwfM8VCoVpFIptFot9Ho9dDodhoaGoNfr4XK5aDJ0M+R3QwwMz/PweDzgOA6CICAYDKJQKCAajaJare73Je4qPM9DrVZDJBKhVCrt+vXIZDLo9XpoNBqcOHEC4+PjGB8fh0QiQavVomEdEqLshed8p5C1lclk8Hg8GBwchFwuh8FggFwux8TEBOx2O/13q9Wi0QqXy4WhoSHq6ZhMJphMJkgkEnqYXF9fRzKZRC6X29Xf274bF7lcjmPHjsHtduPChQu4ePEikskkLl26hGg0ips3b2J5eZkag52iVqsxODgIl8sFrVYLmUwGQRAQj8eRzWZRrVZ7wmsBHj1IdrsdNpsNVqsVer0e9+7dw9/8zd8gFothYWEBqVRqW+NbrVZp+GV5eRkKhQLj4+MwGAw4cuQIpqam4HQ6odVqaWLPaDTCYrFgZmYGwWAQlUoFS0tL2NjYeGmNC1kvjUYDj8cDtVoNq9VKe4tmZmagVqtpvPppTb/tdhtSqRSnTp3C3Nwc9Ho9BEFAJBLBN9980/PGRS6Xw+12QyqVwu/37/r1qNVqTE5Owmaz4Wc/+xleeeUVyOVyyGQyNJtN6jluZ+R7FeIFy2QyuN1u6PV6vP/++3jvvfegUqlgsVggkUhoeIv0rQCPIjn1eh3Hjx9HPp+HUqmEw+GATqfD8PAwFAoF1tfXceXKFfh8PqytrSEcDj/Tgf5p7OtvgnSYGwwG2O12WjLcbreRSqUQj8dRLBZRq9WeOdkpEolo6Ij8rGazSY3Kbi7aXkNuKFJ80Gq1IAgC0uk07VV53PqQja1Wq6FWq6FarSKRSEAQBHrya7Va8Pv99BSuVCppHqFUKsFisSCbzSKRSEAkEvXU2j0vJBmq0WigUCjgdDrhcrmgUqno3xaLBUajEUajkRaedLK54m67/AspIyd/9/oAKrlcDoVCQUvlyWYvkUggCAIKhQKtSnpeyPNgMplgNpthMBig0+loErvdbtOejX7KYUmlUqhUKqjVarhcLpjNZjgcDlgsFr
rmYrEYhUIB5XIZlUoF+XyeesqkEgx49HuyWq3Q6XRQKBTgOA6VSgWJRAKpVArVanXXC0z2zbiQ+J5Wq8Urr7yCU6d
|
2024-07-26 21:07:40 +08:00
|
|
|
"text/plain": [
|
2024-07-29 06:24:50 +08:00
|
|
|
"<Figure size 500x500 with 25 Axes>"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
|
"X = batch_iterate(25, train_images)\n",
"x = next(X, None)  # first shuffled batch, or None if there are no images\n",
"if x is not None:\n",
"    show_images(x)"
|
|
|
|
]
|
|
|
|
},
|
2024-07-30 07:06:52 +08:00
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"### Training Cycle"
|
|
|
|
]
|
|
|
|
},
|
2024-07-26 21:07:40 +08:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 20,
|
2024-07-27 06:09:51 +08:00
|
|
|
"metadata": {},
|
2024-07-29 06:24:50 +08:00
|
|
|
"outputs": [],
|
2024-07-27 06:09:51 +08:00
|
|
|
"source": [
|
2024-07-30 18:21:38 +08:00
|
|
|
"# Training hyper-parameters.\n",
"lr = 2e-6  # NOTE(review): ~100x below the 2e-4 commonly used with Adam for GANs; confirm intentional.\n",
"z_dim = 128\n",
"\n",
"# Fresh generator/discriminator pair for training; replaces the display instances above.\n",
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())  # materialize the lazily computed initial weights\n",
"gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])\n",
"\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
"disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])"
|
2024-07-29 06:30:08 +08:00
|
|
|
]
|
|
|
|
},
|
2024-07-27 06:09:51 +08:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 21:59:35 +08:00
|
|
|
"execution_count": 21,
|
2024-07-27 06:09:51 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2024-07-30 18:45:09 +08:00
|
|
|
" 0%| | 0/200 [00:00<?, ?it/s]"
|
2024-07-30 18:24:53 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2024-07-30 21:59:35 +08:00
|
|
|
"Epoch: 0, iteration: 468, Discriminator Loss:array(1.32117, dtype=float32), Generator Loss: array(0.462009, dtype=float32)\n"
|
2024-07-30 18:24:53 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
2024-07-30 21:59:35 +08:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9WY+cWXYdgK74Yp7nOSNnJpNDcSqyukrqrio1JLUMy4IhWPAAPRiCAL/6xX/Cf8APBvRqwYAAw5IMCN0lqbtL1c2ayOKQzGTOkTHP8zzch+i180R220oGeXGBizxAoopkZkbE+c7Ze+21195bN51Op7haV+tqXa2rdbXe4dL+f/0GrtbVulpX62r9/9+6ci5X62pdrat1td75unIuV+tqXa2rdbXe+bpyLlfral2tq3W13vm6ci5X62pdrat1td75unIuV+tqXa2rdbXe+bpyLlfral2tq3W13vm6ci5X62pdrat1td75unIuV+tqXa2rdbXe+TJc9hv/23/7bxiPxzg6OkI6ncZwOESn08FkMsF0OsV0OoXD4YDL5YLP58P9+/fhcDiQzWZRq9UwGAzk+00mEzRNQzAYRDAYRLVaxZMnT9BsNtHtdjEYDDCdTjEej+FwOHD79m14vV5kMhlks1kUCgU8ffoUAPAnf/InePToEc7OzrC3t4dOp4NisYjBYIBoNIpQKASj0QiLxQKz2YyNjQ0EAgF88803+Oyzz2CxWHDnzh14PB44nU7YbDY0m02k02m0Wi3s7e2hVCrhD/7gD/Anf/InMJlM+PDDD994o//yL/8So9EI3333HV6/fg2n04loNApN09DtdjEajRAKhRCLxTAej9FutzEcDnF2doZyuQybzQaPxwOdTnf+8AwGGAwGWK1WeL1eGAwGFItFNBoNGI1GWK1WtFotfPHFFyiVSrhz5w5u376NTqeDTCaDfr8vz+/27dv4wQ9+gE6ng5///OcoFov49NNP8dFHHyGfz+Pbb79Fv99HOByG0+lEvV5HqVSCwWCAz+eDpmn49ttvsbu7CzZ90Ov18Hg8sFqtuHv3Lj766CMYDAbcvXv3jffvv/7X/4rJZIJ0Oo1CoQBN02AwnB/f6XSKWq2GcrkMr9eLR48ewePxIBgMwu124+XLl/i7v/s7tFotDIdDjMdj3L9/Hw8ePECz2cTe3h56vR58Ph8cDgeMRiOMRiP6/T4ODw/RbDaxurqKlZUVAMB4PMZ0OsVoNMJkMoHNZoPP50Ov18PBwQGazaacOYPBAJPJBJPJhKWlJbjdbqRSKbx+/Rrj8Rh6vR4A0Gg05OfcbjfG4zGOj49RrVbx8OFDfPrppzAYDPjTP/3TN94/APg3/+bfYDKZYDAYYDgcwmw2w+VyYTKZoFwuo9PpYDgcot/vw+v14uHDh3C5XBgOhxgOh0in03j27Bl6vR4mkwkmkwni8ThisRiMRiNMJhOMRiO2t7cRj8eRyWTw6tUrVKtVfP311yiXy7h16xZu3LgBi8UCr9cLTdNQr9fR6/WwtraGO3fuoFKp4K/+6q9wdnaG7e1tbG5uotFo4Pj4GABw584dRKNRZDIZHB0dYTAYoNvtyjMBgHq9jlQqBYPBgEePHiEejyMQCMidW2QP/+Iv/gKj0Qg/+9nP8PTpUwQCAWxtbcFkMsk9slqtsNlssFgs8Pv9MBgMqNfraLfbAACdTofRaIRCoYBut4tYLIZ4PI5Op4OTkxMMBgN4vV44HA4Mh0OxmzyDa2trWF1dxWg0QrvdlvM3nU5ht9vhdrvRbrfx7Nkz1Ot1fPzxx/jwww9RLpexs7OD4XAoZ7BQKCCZTGIymcBoNELTNNjtdthsNhSLRbx69QrD4RB2ux0mkwk3btzA3bt3YTQa8W//7b/9Z/fr0s7l+fPn0Ov1CIfD2NzcxGg0Qq/Xw2AwQC6Xkw86Go1QrVbx7bffQq/Xo9vtotfrwWAwwGg0wmw2w+PxwGazQafToV
aroVKpoFqtotvtYnV1FaFQCKPRCOPxGDqdDpqmodlsotfrYTQaQa/Xw+12YzqdIp1OQ6fTyYE3GAwIBALys+VyWT6DXq/HcDiEx+PByckJ6vU6BoMBWq0WjEajHAA6zvF4DLvdjslkgna7jRcvXkCv1y/kXEqlEgDgvffew4cffohqtYqTkxOMx2NxgAaDAb1eD91uF9lsFr1eD41GA4PBAEajEePxGFarFUtLS3A4HBiPxxiNRhgMBqhUKhiPx7BYLAiHw2i32yiXyxiNRrh58yam0ylWV1exvLyMVqslhzOVSqFSqYhx0Ov1cDqdcDgcSCaTsr+apsnzarVayOfzOD09hdFoRDQahclkwng8hs/ng16vh9FoBAB0Oh00Gg10u10xSIss7p9OpxNHarFY5LU0TcNgMECv15P/LxaLODs7w2g0QrfbRSgUQiQSgcfjgcVigdPpxGg0wnA4lPNmMBhgs9lgt9vh9XrRbreRTCYxHo+Rz+fRarVgNpvhdrvlfLVaLXi9XvT7fQAQkNLpdAQkdDodecYAUK1WUavVoGka/H4/zGaznFGLxQKPx4PpdIpOpwOLxYJGo4Gf/exnALCwc/H7/dDpdPB4PHA4HKjVajg7O8NwOITT6YTL5ZLv1TQNuVwOhUJB3vdwOITX64Ver0cikYDb7YbdbofD4UCn00Eul0Ov10Mul8NgMMBoNBIw0uv1UKvV4HQ60Wq10O120el05IwMBgOMx2P0ej0xaPF4HOVyGZlMBuPxWBxio9GA0+kUcKTX6+FyuaDT6eByuWC329Hr9bC+vo7JZCI/Y7fbMRqNoGmLETY/+9nPoNPpEI1Gsb29jcFggEajIcDaYDCg3++j0WgIkNY0DeVyGfV6XQCf0WjEysoKLBYLhsMharUaxuOxOHqea9473hnuD21hoVDAeDyW92e32wWo2mw2GI1G5HI5/PSnP8VkMsFwOIROp5N9z+VyODo6gl6vh8/ng9FoRDqdRrPZhN1ux/r6utgk2t1utyvn/J9bl3YuR0dHMBqNuHbtGh48eCBGrdPpYHd3F8ViEfV6HZVKRQ4YN0Wn08FutyMQCMBkMsHhcMDj8aDRaKBer8vXcDhEKBTC7du35dL3ej2JIvr9vjgNl8uF0Wgkr2uz2eB0OmE0GmGz2aBpGmq1Gur1ujwUvV6P6XQKl8uFXC6HVquF6XSKbrcLi8WCTqcjxkmv18uDBiCIlCjzTVetVoPRaMSDBw9w//59vHr1CplMBr1eDx6PBy6XC+12G61WC+12G4VCQQ7oZDIRx61pGmKxGMLhsBgvOqrBYIBEIgGv1yuX2WAwYHNzE263G8FgEIFAALVaDdVqFXq9Hv1+H+VyWSIil8uF7e1tOBwOPHv2DDs7OwgEArh16xYsFos4pnQ6jcPDQxiNRoxGI3HCLpcLJpMJNpsNo9FI3qMaJS2y6vU6AIhTMBqNcDqdErmZTCY5a51OB6lUShxDsViEz+fD6uoqHA4H1tfX4fF45Nxxb+lczGYznE4nQqEQGo0G9Ho9xuMxyuUy8vk8nE4ndDoddDodstksSqUSOp0OzGYzzGYzvF4vjEaj3I92u418Pg+DwYBQKAS73Y5Wq4Vmswmz2SxGZzqdQtM0WK1WeDweTCYTcUqVSgVHR0cL7x8AeDweaJqG9fV1RKNRHB8fC/L3+/2w2WzirDudDs7OzjAYDOB0OmG1WsUA2mw23L9/H9FoFMAsaiyXy+JcqtUqer0enE4n/H4/nE4nhsMh2u22fE2nUzHM/X5fHHCj0RAHb7PZ8Pz5c7x69QoWiwU+nw92u11+R6/Xkz2z2+0CdILBoIDffr+P09NTNBoNsR+L7uE333wDs9mMP/mTP8Hv/u7vIplM4vHjxxgMBhKtEGyodiKfz6NarcJqtcLtdsPpdCISiSAYDOLs7AwnJyfQNE3sFjBz7vwajUYAgNFohH6/j06ng1arhWKxiOFwKI6IjpkAxWazoVwu4+joCA
6HA9FoFBaLRfatWq0ik8lIlG6xWHB6eopkMonNzU08ePAANpsNhUIBrVZLWJbLOudLOxdeAiLYwWAgBrBYLKJcLou
|
2024-07-30 18:24:53 +08:00
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 500x500 with 25 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2024-07-30 21:59:35 +08:00
|
|
|
" 25%|██▌ | 50/200 [04:00<11:30, 4.60s/it]"
|
2024-07-30 18:24:53 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2024-07-30 21:59:35 +08:00
|
|
|
"Epoch: 50, iteration: 468, Discriminator Loss:array(1.04, dtype=float32), Generator Loss: array(0.682963, dtype=float32)\n"
|
2024-07-30 18:24:53 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
2024-07-30 21:59:35 +08:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz913OlaXbmhz7bewdseCCR3pTtKk5zWuKMhhwGIxSh0Y10p4v5H/Qf6V43MjGjUIgackQ2u8nq7nJZVZmVDgm/AWzv/bmAfgtr785uIoE6cSJO4I1AZBWwzfe933qXedaz1gpMp9OpbtbNulk362bdrJ9wBf9/fQE362bdrJt1s/7/b90Yl5t1s27WzbpZP/m6MS4362bdrJt1s37ydWNcbtbNulk362b95OvGuNysm3WzbtbN+snXjXG5WTfrZt2sm/WTrxvjcrNu1s26WTfrJ183xuVm3aybdbNu1k++bozLzbpZN+tm3ayffIUv+8JYLGb/RiIRxWIxZTIZTSYTnZycqNPpKJlMKp1Oazweq9VqaTqdanFxUZlMRq1WSycnJ5KkRCKhcDisbrerbrereDyulZUVRaNRDQYDjUYjBQIBBQIBxWIx3bp1S6lUSru7u9rd3VUkElEqlVIoFNJwONRoNFI6ndbCwoJCoZCCwaCCwaDa7bY6nY4Gg4EajYbG47Emk4mm06ny+bxWV1clSb1eT6PRSKenp6pUKorH41pYWFAgENBkMtFkMlEsFlMikZAkffnll++90fF4XMFgUMvLy1pcXNRoNFK/35/Zj1QqpWQyqcFgoNPTU/X7fbVaLdujdDqtyWSis7MzdbtdxWIxxWIxjUYjdTodjcdjhUIhBQIBBYNBhcPnj5e9TCaTSiaTM9fV7XY1HA5tz3ivv/fhcKhOp6NAIKDl5WVlMhk1m03VajVNp1P7nuFwqOFwqEgkomQyqWAwqGg0qmAwqMFgoG63q+l0qtPT0/fev83NTU2nU7VaLXU6HW1vb+sXv/iFJpOJfvWrX+nw8FAffPCBPv30U1WrVf3jP/6j6vW63ddkMtF4PFY0GtWtW7eUyWR0enqqk5MTjcdjDQYDTadTkx/WZDLRaDTSZDKxfUylUtrY2FAoFNLp6amazabtVTAYVCqVUjgcVqFQUC6XU6/XU6VS0XA4NBnkrAwGAx0dHanX6ymRSCgejysUCikSiSgSiWhxcVGJREJnZ2c6Pj7WZDJRo9F47/2TpL/6q7+y5zQajdRsNnV0dKTJZKJMJqNYLKZ2u61ms6lkMqmtrS0lk0m7rl6vp3q9rn6/r5OTE3W7XeXzeRWLRaXTaW1tbSkWi6nX66nf7yudTqtQKGgwGOjly5dqNBpKJpNKpVIql8v69ttv1e/3Tab6/b663a7S6bQ+++wzFQoFvXjxQq9fv9b6+rp+8YtfKBAI6Je//KVev35tsh6LxZTP5xWJRNRoNNRutzUYDNRutzUej23fHzx4oJ/97GcKhUL6n/6n/+m99297e3vm/5PJpBYXFxWPx7W0tGT3dXZ2ptFopOl0qvF4rOPjY1WrVSWTSRUKBU0mE9XrdQ0GAxWLRS0tLanVaml3d1e9Xk+S7L6SyaQ988FgYHIRCAQUCoXseY7HY02nU02nU8ViMa2vryuRSGg0Gmk0GtmzQ5an06nS6bSy2axGo5GOj4/V7/eVSqWUSCTsvEynU3s91yXpUjJ4aeMSiUTswyeTifr9vobDof0ukUjY4WdjAoGAxuOxPWyUEAc4HA7b5zYaDQWDQcXjcUUiEQ2HQ/V6PQ0GA+3t7SkWi2k4HGphYUGTyUSSbNMGg4ECgYB9Vr/f12QyMUOFcpxOp0qlUmbE9vb2FIlEVCwWZxSCX71eT+PxWP1+X+Px+LLb9Xsrk8lIkhkO7j8QCKjZbNrrECauezQa2QHp9/umzDE2mUxGo9FIsVjMrm86nSoajSqRSJihHw6HGgwGtv84CwgLByEYDC
oSiZhSDofDGo1GSqVSJmgcDElmFMPhsB2AXC6nJ0+eKBAIaG9vz55tLpe78v7FYjFNJhNFo1ENh0M1Gg199dVXdv3FYlHNZlNfffWVKTeMBEYjGo0qFAqp1WqZIguFQgqHw0omkwoEAiZ/3W7XDhBGhTUcDlUulxUOh02BNhoNnZ2dKRgMKpPJKJFIaDgcqlQq2UGVZIe9WCzq9u3bJl+tVkvj8Vij0cieQSgUMueo0+loOBzqOt2aDg8PFQwGtbi4qMXFRaXTaTNwjUbDzhJyUCgUlMlk7LxOJhOFw2FNJhOlUilFIhFTTPF4XOPxWJFIxM6dJDO2jx8/VqFQ0Js3b/TmzRt1Oh3l83lzSCqVira3t/Xxxx+bA1KtVtVoNBSNRlWr1fT3f//3dh23b99WrVbT2dmZJCmXy5libLfbarVaJqM8f+ncmfLOw/ssf76m06mdzV6vp52dHdNLOGS1Wk3D4dD2S5JqtZrJCHt5fHysZDKpzz77TMFgUAcHByqXy0qlUsrn8xoMBhoMBubgeP0XCoW0srKiTCajWq2mk5OT35PXYDCodDqtYrGo0Wikw8ND1Wo1hcNh5fN59Xo9O1+TyUS9Xk/BYNCeeywWUzgcth//2X9sXdq4YBjYPJSfdO6VR6NRO7DhcFi5XE7hcFjj8VjdbleTycQ8DX44RJPJxIwSitMrUxRFLpdTLpfTaDQyr6Tf79tmRCIRTadT1et1O4gozV6vp0AgoHQ6rVgsplarpVqtplQqpdXVVfPq+T4MIsKEwFx18dndblfNZtOMg3RuDEejkRloBNcbF7/fPPRUKmX7EQqF7LWTyUTxeFzZbFbD4VDdbleDwcA8nGg0qmg0OiMkKK3pdGpGJRqNKhaL2f1PJhPVajWLQPCe0um0RaIY8Pv37ysYDKpcLttByWazlxbM+cWzxSHpdDr68ccfFQ6Htby8rGw2q0ajoYODA0my6Iv95JBIMkcDpRMOh2eix3g8rlqtpna7/U5lPhwOVa/XFQ6HtbCwoGKxKEk6OzuzyCaZTOr09FTlclmhUEjxeFyS7DnG43Gtra2p1+upWq0qEAhYpM19BgIBDQYD9Xo9M0LXMS5nZ2cKhUJaXFxULpfTZDJRoVBQr9fTixcv1Gw2zbMNBoPKZrPK5XL2vcgZxjgcDqter6tWq5mc4BhyNtvttlZXV/Vf/9f/tR49eqRGo6Evv/zSPOfxeKxyuax2u618Pq9/9a/+lVqtlv73//1/1/7+vumEVqulV69eKRgM6uOPP9ba2ppGo5EODg4MySB6icViJnsoYvbNOx3vuzAuePLs1Wg00snJiVqtlrLZrPL5vPr9vsrlsgaDgbLZrJLJpPr9vprNpsLhsEV7Z2dnqlarisVievz4sTmXvV5PmUxGhUJB/X5f1WrVohq+n+tZWFjQ+vq6wuGwyuWynTEcRM7k2tqaJpOJKpWKoRw4hjjcOEJESDhQOF3I8WXWpY2LD0G5kFarNQOdcCiAzBCKbrerSCRiypSDgnIKh8MKhUIKhUIqFovK5/OmxFCW/LCpeEcoDe9tr66uKhgMqtVqqd1um5JGIHhIRDF4bclkUvF43AQG70u6MDBXPdx8FhFWMBg04eAw9no9izLYcww0npc3zjxsPHUOP5/daDRMmY3HY4XDYRO2fr9vXpCPUvhvH1ki7AidJPOmJNlr8XIjkYj29/ftWWP4rgrn8Dn+QPl9xAnhejy0x79E0cFg0AwJr0XGeD9KkdcvLS2ZR+4PIN9NJIQxw2sOBALK5/P2euQHZ+ro6MgUTq/X02QysUPNPuPkANtdx7gMBgMFg0HV63WdnJyYnEvnkd/i4qJarZYajYbi8bgqlYpFVNwD+x+NRhUOh83JCAQC5tzE43GlUik1Gg2DMZ89e6ZOp6Pd3V11Oh1FIhFlMhmTR5COg4MDdTqdd579RCJhUDh7HAqFLNLpdrvq9/saDAbq9/uGFqBflpaWtLCwcGXjAozNM2QvJc
1EdVwH8ofeQT+GQiFzxIgQI5GIDg4OFA6HZwyJdOGQIPPRaNT0WCAQUKvVUqVSUbvdtmtB76EPgM7Zy1AoZAYQI+l
|
2024-07-30 18:24:53 +08:00
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 500x500 with 25 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2024-07-30 21:59:35 +08:00
|
|
|
" 50%|█████ | 100/200 [07:56<08:08, 4.88s/it]"
|
2024-07-30 18:24:53 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2024-07-30 21:59:35 +08:00
|
|
|
"Epoch: 100, iteration: 468, Discriminator Loss:array(1.01769, dtype=float32), Generator Loss: array(0.685513, dtype=float32)\n"
|
2024-07-30 18:24:53 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
2024-07-30 21:59:35 +08:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9x5Nd63UejD8n55w6d6OBRiNf3ETyihSvLFGU6bJl2i4H2VX+Qzzx2APPNfLAHlmSXZYtlWVmirwkbwIuMtCNzunknNP+DdrP6nU2TgMdQH1ffb9eVSik0/vs/e73XeFZz1rLYhiGgQu5kAu5kAu5kLco1v+nb+BCLuRCLuRC/r8nF8blQi7kQi7kQt66XBiXC7mQC7mQC3nrcmFcLuRCLuRCLuSty4VxuZALuZALuZC3LhfG5UIu5EIu5ELeulwYlwu5kAu5kAt563JhXC7kQi7kQi7krcuFcbmQC7mQC7mQty72k37wo48+AgC0Wi10Oh20223U63UMh0P0+30Mh0PY7XY4HA4Mh0N0Oh0Mh0OwAYBhGPILACwWC9xut/yKxWKwWq3Y399HqVTCcDjEYDAY+RkKP+90OhEIBOB2u1EqlZDJZDAcDuFwOGC1WuF0OuFwONDtdlGv12EYBsLhMDweD7rdLnq9HgaDAVqtFgaDAbrdLvr9PiwWC6zWQ7s7HA4xHA7h9/sRCARgsViwu7t76oX+8MMPYbVaMTk5iWg0ina7jWKxiG63i1wuh3q9jmazObKmAODz+eB2u+U+hsMh2u02BoMBPB4PvF4vQqEQrl27BrfbjXw+j1qthkKhgL29PfT7/VfW0Ol0IhgMwm63w+12w+FwoFqtIp/Pw+Px4Pr16wgEAtje3sbe3h78fj9mZmbgdrvh9XrhcDhQKpWQzWbR6/XQaDQwGAxgtVrlF/dBtVpFu90eWYter3fq9fsX/+JfoN/vY3V1Fdvb2+j1emi32zAMAxaLBRaLRT47bs+ME5vNBrvdDpfLhXA4DKvVikKhgFqt9saf5f7g83K/2mw2+Hw+2O12dLtddLvdkffpcrlgt9vlHofDoXxGC59n3HOctamGx+MBANn3b/qecWKxWMZ+lu8cAPr9vlz/bYj+TrfbjQ8++ADz8/NYXV3FvXv3ZG3NP6PPsPmez7KG/+E//Af0ej38+Mc/xmeffSb7zu12Y3l5GbFYDKFQCJFIBKVSCffv30etVkOn05E173a7sFgsCAQCcDgc6PV66HQ6cLlciEajcDgcsNlssFqtcLlc8Hq9AIBut4vBYIBcLodcLie6wGKxwOFwwG634/3338c/+Af/AL1eD48ePUK5XIbNZoPNZkMul8PDhw8BAL/3e7+HpaUlNJtN1Go1lEolfPrppygWiwgGg/D7/bJGWi+22220Wi0YhnGiM3xi41IqlQAcHUgAaLfbGA6HcLlccLlconz0QW+323LIWq0WgEPlZrPZ4HQ64XQ6YRgGMpkMBoMBGo2GKHh+jzZMVCYAMBgMUKvVUK/X5V642DabTZRvvV5HuVxGv99HpVJBs9nEYDAY2ZQWiwV+v1+U4mAwEEXe6/Vgt9vhdDpPulyvCL9ra2sLOzs7IwZPGw6r1Qq3241IJCKbxmazoVaryabyer3y7zabDe12GysrK7BYLGL8+/0+3G63bBDg0FB5vV54vV5MTEzAarVie3sbpVIJbrcbS0tL8Hg8mJubg8/nQ6lUgsViQbvdxu7urqyp3W6XjWmxWOR+bt26haWlJWxvb+OXv/wlms0mnE4n3G63vP+zKsaHDx/CMAwUi0V519pR0XvOarXCZrON/LxW8PwZ4HAPtdttedZOp3PsPfBnnE4n/H6/rE2v14PVapX7oTIxGwzDMMQRMztaVITjvlN/7jzC59fvgOfJbrePPAP3JXC4nhaLRc7FOOGzAXjlucc9k81mg8PhgMViEeVl/gw/Z7PZMBwO0ev10Ov1sLGxgXw+j3K5/N
rvMu+1867fo0ePMBgMkM/nxaiEQiF4PB4Eg0F4vV6USiVsbGyg3W6jWq2i3+9jZmYGiUQCpVIJOzs7sgfogFssFtFlNpsNXq8XTqcTzWYTe3t7I/fg9Xpx6dIl9Pt91Ot1DAYDMVx7e3v40Y9+JH9utVrw+/3wer2oVqtiEPb392EYBiqVipynwWAAt9uNXq+HcrmMZDKJGzduwGazIZ1Oi96s1WonPsMnNi7ValUsLpUsPS6v1wuXywWfzwe/3y+K2GKxoFKpoN1uo9lsysN5PB5RnA6HA61WC4VCAd1u9+jG/q/y5KYGMGKteThbrZZ4SoZhwGq1wm63i1fu9XrF6vN3s6dJb9Lr9cLn86Hf74uC5rVpVM1K66TCey8Wi7LptBfjcrnk/t1uNyYmJuD1euU5uYkMw4Db7YbH45Frdrtd5PP5EcXldrvh8/lgsVgkegmHw4jFYggGg1hYWIBhGEin06hWq/D7/ZidnYXP58Pk5CTcbrco0G63i0qlAuDQc7Tb7SMeVzAYRCgUwu3bt/F7v/d7+PTTT/Hzn/8czWYT8XgcHo8HjUZDIo2zyOrqKgCMKD/KOONCxcXPmfcIr8F9oPfeOOF30HkJBAKw2WzybrgvAch7ogeq73cwGMh9UPg58zPx586rFCk0DOPeAZW4GWHQ98d7HyfaidFijoy4htznvK/jjAt1Sb/fH1Gir5Nx66Xf31nl5cuXGAwGqFarACDRBs+K2+3G/v4+VldXZU/Y7XZMTExgeXkZu7u7KJfLaLVaog95T8PhEK1WSwyq1WpFrVZDOp2W69hsNiwuLmJ6ehqdTgf5fB6dTkfQjmw2i0KhIIaq3+8jFoshHA4LOmMYBnK5HLrdLorFIjKZjBhKl8uFZrOJdruN6elpLC8vw+VyweFwwO12o1KpHOsEjZMTG5fXvZhutyuHgL/TuHQ6HTm43Ew6rKLV1lGJ3pDj4DTDMNDpdGC1WkWhasPDw03jwH+nR0vlow8TPVjDMGC32+Hz+eQw8Z7epIBeJ9VqFYZhyDUYuQEQL4aHiV5Fu92WtdB/pqHRnp0+tMChhxOPx2UzcRPabDbxzrkWiUQCdrtdILVOpwOHw4FisQgAogi0otYGnxHp6uoqXC4X1tbWYLVa4ff7MTk5iUgkgoODAzQajTd6tW8S3rP2os1Kme/eDJWNU+D8NzNMdJzw2s1mE1arVRwWvUc15MW/m+EkfVbMyvxtGZNx937cv+t71c+g71Eb0Ddd83Wf0XAgzz/F/M60Z35S0XrErE/Os//cbjcGg4FE44QB9bWdTifC4bA4j1arFa1WS6BvRmxer1ecGp59Oq482/zd6/Xi6tWrCIfDgk60Wi3UajWB9w3DgM/nQzQaFZSg0+mg1WrJdy4uLorjDWDE8aHzxb3c6/VQr9fR6XRQKpVQKBTQbrdfgbhfJyc2LtpjNyv6RqMBi8WCRqMhHiMVHi2uw+FAOBwWy9/tdsVb1x43MWx94HV4zhCZ4ZnejLwfHnwqXm5iekGEd/x+P3q9nngAVPzRaBTT09MSBdXrdfR6PfFYziL7+/uHC/5/4QdGVoy+er2eRE6GYeDg4AAAJMLjBgQgUZXP55MNznXhYYpEIrh27Rra7Tby+bxslHK5DKvVio2NDdjtdszMzODq1avI5XJYX18f+y4cDofcF40P17vf76NQKMBisaBUKuEXv/iFrHUwGMSdO3cwOzuLBw8eIJvNjsXHTyIaknI6nZJz4b1qB0JDYNrwcg/zAGnIZdz+G2cwgENnqlAoyL9TzFGV2TkxK07+fVy+hfdt/o7fllCRjfu+18F34/IZ437W/Hedl9HXp8PH6/Z6vbF5w+NyP1oI59GR7Ha7p1KOZgmFQjAMA+12W2BnrgffdSAQwNzcnOiLwWCAcrmMWq0m+WBtlIrFIur1+sgeabVaArfabDbEYjH803/6T7G8vIxf/OIX+OSTTySq6fV6cLvdcDqdiMViuHPnDmq1GrLZLC
qVCqrVKur1Oi5duoQ/+IM/gMfjwYsXL5DL5UQXaceaEVez2ZSoZmtrC5lM5tT78MTGhZGIGQM2bzh6jOZDxcOkDQd
|
2024-07-30 18:24:53 +08:00
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 500x500 with 25 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2024-07-30 21:59:35 +08:00
|
|
|
" 75%|███████▌ | 150/200 [12:03<03:53, 4.68s/it]"
|
2024-07-30 18:24:53 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2024-07-30 21:59:35 +08:00
|
|
|
"Epoch: 150, iteration: 468, Discriminator Loss:array(1.0305, dtype=float32), Generator Loss: array(0.684713, dtype=float32)\n"
|
2024-07-30 18:24:53 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
2024-07-30 21:59:35 +08:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9V2/keXYe/FTOOTHHZufpnqDRBs9qZ70rS14JkmHYsGHYgAXBgO/8IfwFfGX4zgZ8IcMWbMmytZIhK+zu7ISdns7sbrKbsarIyjmn94J4Dk/9p7qHLLbwAu/LAxDsJov/8AsnPOc552cajUYjXMqlXMqlXMqlvEUx/7/9AJdyKZdyKZfy/z25NC6XcimXcimX8tbl0rhcyqVcyqVcyluXS+NyKZdyKZdyKW9dLo3LpVzKpVzKpbx1uTQul3Ipl3Ipl/LW5dK4XMqlXMqlXMpbl0vjcimXcimXcilvXS6Ny6VcyqVcyqW8dbGe9YM2mw0AMBgMwKJ+s/nENg2Hw4l/YzKZ4PP54Ha70el0UKvVMBwOMRqNMBqNYDab5Rq8Jn9vvA4AuN1uuFwu9Ho91Ot1DAaD87yrXIei72MymWCz2WC1WjEcDtHv9zEajSY+zzRNDW7cuIHRaIR6vY5Wq4Ver4dWq4XRaASHwwGLxSL3Gg6H6Ha7AIBwOAy/349Wq4VKpYLBYIB+v4/hcAiv1wufz4der4dKpYJ+vw+z2QyTyQSTySRjOxgMxubIYrHIu/p8PjidTgAnYz8cDtHpdDAYDNDtdtHtduXzANDpdNDr9WC32+F0OjEYDFCv1zEcDrGwsICZmRmUy2Xs7+/LO4xGI3kmAOj3++cev9/93d+FyWSC1WqFxWKR8XM4HHjnnXcQj8fR6/XQ7XbR6/XQbDble7vdRrPZRKlUgt1ux7Vr1xAMBvHixQtsbm4COFnLdrsdS0tLiMViqNVqyOVysFqtmJmZgdvtRqlUQqlUQrvdRj6fBwAsLS0hEomgUqkgm83CZDLB6/XCZrMhEAjA7/ejVqshlUrJsw0GA8TjcczPz8NsNmM4HGIwGODLL7/E48ePMTs7i/feew8+nw9zc3Pw+Xwy7wDwb//tvz33+AFAJBLBaDRCu91Gp9OBxWKB3W6XOeEa4VyZTCZYLBasrq5ibm4O2WwW29vb6PV68jmLxQKr1YrRaCS6gfua82G1WhGJROBwONBsNtFoNGCxWGTd1et1tNttWbtms1n2hNVqhdVqRb/fR7vdBgB4PB44HA7UajWUy2VYrVaEw2HY7Xb0ej30ej3ZwyaTCW63GzabTT5P/XNeeeeddwAAVqsVdrsdnU4H9Xpd9l+320U0GkUikUC73ZY9sLy8jHg8jlqthqOjIxmXwWCA5eVlrKysoNVqIZ1Oo9vtyh7jWHBste50Op2IRCJwOp1YWVlBJBLBzs4OHjx4AJPJhFgsJuPdarVkvDk3JpNpbL+Uy2UMh0PcunUL165dQ7FYxM7OjqyVfr+Pfr8ve7rRaHzjeJ3ZuBgNiF6YHCgtVCR8sNFoJIuw1+vJIHGS+XmTyTRx4k0mE7rdrihA4/MYDQd/xutx0XNhdrtdtNtt+Tnfsdvtjj3b2+qOUygUAEAmczgcyvNx4lwuFzweD3q9Hmq1Gvr9PjqdDiqVCjweD65evYp+v49kMolarSZGlu9gs9nkmtxcwInitFqtMibhcBi3bt2C0+lErVZDu91GrVZDqVTCYDCQ56Oy4PPRADscDgwGA1lgdrsdZrMZzWZTlCgXs91uh8VikWtMK+l0GhaLBbOzswiFQiiXyygWi7DZbKhUKnC5XEgmk9jb24PH48H6+jq8Xi/a7Tb6/T6q1SoODg5gtVrhdDpRrVYxGo2wuLiIfr+PZrOJ0WiERqOBwWAAs9kMv98PACiVSigWi7BarQgEAhgOh6jVauh2u/D7/bDZbGg2m7IH6vU6TCYT2u02KpUK7HY7VlZWMBqNkM1m5f
e8J/fPwsICEokELBYLzGYzOp0OisUiGo0GPB4PfD7f1OMHAHfv3sVwOMTOzg5SqRQ8Hg9mZmZgsVhQrVZFkTSbTTGOdrsdbrdblBK/qOycTie8Xi/6/b6MHfenxWKBx+ORd6WS83g84sRwj3F9Aie6xeVywWazod/vYzAYwG63IxAIiLGhM1av18f2r3ZSqWd4zU6nM1FPnFe63e6Yg6gNM/Viv9+HzWaTnw+HQ3g8HqytraHb7SKZTIqjWSgU5Jp0EF0ul4w118hwOITP54PP54Pf78fi4qKs/1QqBafTie9+97uw2WyIRCKw2+1IJpM4OjpCs9lEoVDAcDhEKBSCy+VCq9US4zw7Owun04m1tTUsLS3J2LbbbcRiMbjdblSrVZRKpdcGE0Y5s3EBxj12KjMA4slQ9AR2Op2TG1mtcDgcoqz0NbVXy7+fpNT7/f4b76X/zetycobDoRgXj8cjyoCKmcrgIgrwTVKtVuW5jAaVG5ILp9PpSFRDYxQIBLC0tIRer4disYh6vS4bjx69NtA0Dvwd3xEAgsEgbt++DZ/Ph52dHWSz2TEPRUdQ/N7r9WTRu1yuMW/T5/PBarVKdMW1QUVis9nQ6XTEmE8j+XweVqsViUQCHo8H1WoVlUoFFosF9XodzWYTyWQS9+/fRyKRwPr6uihFKr7j42OYzWYEAgF0u11RroxEuMGbzSZ8Ph+i0SgGgwGOjo7QarUQi8UQDAbRaDREWTYaDYnMqVSpCJrNJkwmExKJBK5fvw6r1YputyuGut1ui5GmoZudnUW1WkUymUSv10O1WoXFYoHFYkE4HBZFOo3QOSkUCkin03C5XJidnYXZbIbFYkGj0UClUkGj0YDZbIbX6xUHQUfD3J/D4RBWq1XQhFarNeZh00jQGPf7ffj9fnGguAd1lM396HA4YLfbRYm7XC5Eo9ExI9RqtcYUsN5bRieWe4L7YxrhM7bbbdEVet7152gA+bzD4RAulwuRSATtdhuFQgGNRgOdTgflcnnsWi6XCw6HQ+aF70DDMDc3h2AwiI2NDZjNZnz11VfI5XJYXFzEjRs34HK5EAqFYLPZYLPZ0Ov1kMvlZA58Pp84fLVaDR6PB9FoVK49MzODQqGAVquFdrsNr9eLWCwGs9ksa/UscmbjoieN36mILRbLmELjZBr/XnvS+uccOJPJBIfDAavVil6vJ56Nvu/rnmuSgdJRi1aSnU5nbJPbbLYxiE6L0WBNKxqi4sbjptUeV6PREKOhF2yr1UI2m5UQ2Wazwel0wu12o9fridfIueDi1mNBKGwwGGB3dxculwvtdhtOp1PgQMJ02tjyOWjsgHFoi5Gpfmb+jM4APdlpx3BlZUWiLpfLBb/fj5mZGZhMJtjtdtmU0WgUgUBAnpWKKhAIYGFhYWycqdTMZjPi8TgAoFwui1FgJOL3++FyuWCxWCTS4xgR1qJBJZyolZrZbIbT6ZR1xrWulaL+GZXSYDBAu93GcDiEw+GA2+2+kOedTqfR7/fRarUAnDh+hP54Dzoio9FobP8RzmKUqt+DUBTnn9fgdwAy99yDGloDTtaWy+WCz+eTaL7X68n6oxLWkQvXItc7FSkjQd6bz8F5uogQEeh2u6KvgNP1DpzCfHz+ZrMJs9ksz9vr9cTxCoVCiEajaDabArFxrLxeL4LBoMxZt9sVOMxmsyGTyWAwGGAwGMDr9aLT6WBnZwc2mw3BYBA2mw2pVAqlUgmtVkscUEbU3W5X4EemLPb399HpdHB0dCRQMXU34ce/NeNCoQKkN2u1WsXjnvR5LjwqOYa8nBR6LH6/Hz6fTzBvvUnf9FIaJ+YXB57PMxwOxVugkWHegV7l6657UXiMUR6VMO+rIwqGyDrnw/sWi0UJ6wlHJBIJzM/Po1Kp4Pnz5xJdmEwmOJ1OiRSpoDweD7xeL7rdLv7mb/4GZrMZd+7cwcLCAorFohgMr9cLu92OQqGAer0u4wCcKJlmsyk/4/X5bhxnhvcARC
FwDKaR73znO2MeLdcQldJgMBBvzul0Yjgcot1uj+WVwuEwOp0Okskk8vm8eKDhcBjvvPMOnE4nnj17hnQ6LTg/IQO
|
2024-07-30 18:24:53 +08:00
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 500x500 with 25 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2024-07-30 21:59:35 +08:00
|
|
|
"100%|██████████| 200/200 [16:04<00:00, 4.82s/it]\n"
|
2024-07-29 00:18:35 +08:00
|
|
|
]
|
2024-07-26 21:07:40 +08:00
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
|
"# Set your parameters\n",
|
2024-07-30 18:45:09 +08:00
|
|
|
"n_epochs = 200\n",
|
2024-07-30 00:44:16 +08:00
|
|
|
"display_step = 5000\n",
|
2024-07-28 06:10:19 +08:00
|
|
|
"cur_step = 0\n",
|
2024-07-30 00:44:16 +08:00
|
|
|
"\n",
|
2024-07-30 07:44:41 +08:00
|
|
|
"batch_size = 128\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
"\n",
|
2024-07-28 06:10:19 +08:00
|
|
|
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
|
|
|
|
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
|
|
|
|
"\n",
|
|
|
|
"\n",
|
2024-07-30 00:44:16 +08:00
|
|
|
"for epoch in tqdm(range(n_epochs)):\n",
|
2024-07-28 06:10:19 +08:00
|
|
|
"\n",
|
2024-07-30 18:21:38 +08:00
|
|
|
" for idx,real in enumerate(batch_iterate(batch_size, train_images)):\n",
|
2024-07-28 06:10:19 +08:00
|
|
|
" \n",
|
2024-07-30 00:44:16 +08:00
|
|
|
" # TODO Train Discriminator\n",
|
2024-07-30 07:06:52 +08:00
|
|
|
" D_loss,D_grads = D_loss_grad(gen, disc,mx.array(real), batch_size, z_dim)\n",
|
2024-07-28 06:10:19 +08:00
|
|
|
"\n",
|
|
|
|
" # Update optimizer\n",
|
|
|
|
" disc_opt.update(disc, D_grads)\n",
|
|
|
|
" \n",
|
|
|
|
" # Update gradients\n",
|
|
|
|
" mx.eval(disc.parameters(), disc_opt.state)\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
"\n",
|
2024-07-30 00:44:16 +08:00
|
|
|
" # TODO Train Generator\n",
|
|
|
|
" G_loss,G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
|
2024-07-28 06:10:19 +08:00
|
|
|
" \n",
|
|
|
|
" # Update optimizer\n",
|
|
|
|
" gen_opt.update(gen, G_grads)\n",
|
|
|
|
" \n",
|
|
|
|
" # Update gradients\n",
|
2024-07-30 07:37:09 +08:00
|
|
|
" mx.eval(gen.parameters(), gen_opt.state) \n",
|
2024-07-30 07:44:41 +08:00
|
|
|
" \n",
|
2024-07-30 18:21:38 +08:00
|
|
|
" # if (cur_step + 1) % display_step == 0:\n",
|
|
|
|
" # print(f\"Step {epoch}: Generator loss: {G_loss}, discriminator loss: {D_loss}\")\n",
|
|
|
|
" # fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
|
|
|
|
" # fake = gen(fake_noise)\n",
|
|
|
|
" # show_images(fake)\n",
|
|
|
|
" # show_images(real)\n",
|
|
|
|
" # cur_step += 1\n",
|
|
|
|
" \n",
|
2024-07-30 18:45:09 +08:00
|
|
|
" if epoch%50==0:\n",
|
2024-07-30 18:21:38 +08:00
|
|
|
" print(\"Epoch: {}, iteration: {}, Discriminator Loss:{}, Generator Loss: {}\".format(epoch,idx,D_loss,G_loss))\n",
|
2024-07-30 07:44:41 +08:00
|
|
|
" fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
|
|
|
|
" fake = gen(fake_noise)\n",
|
|
|
|
" show_images(fake)\n",
|
2024-07-30 18:45:09 +08:00
|
|
|
" # show_images(real) likjmnh jy,t\n",
|
2024-07-30 18:21:38 +08:00
|
|
|
" \n",
|
|
|
|
" # print('Losses D={0} G={1}'.format(D_loss,G_loss))"
|
2024-07-27 05:19:08 +08:00
|
|
|
]
|
2024-07-26 21:07:40 +08:00
|
|
|
}
|
|
|
|
],
|
|
|
|
"metadata": {
|
|
|
|
"kernelspec": {
|
|
|
|
"display_name": "base",
|
|
|
|
"language": "python",
|
|
|
|
"name": "python3"
|
|
|
|
},
|
|
|
|
"language_info": {
|
|
|
|
"codemirror_mode": {
|
|
|
|
"name": "ipython",
|
|
|
|
"version": 3
|
|
|
|
},
|
|
|
|
"file_extension": ".py",
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
"name": "python",
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
"version": "3.10.10"
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"nbformat": 4,
|
|
|
|
"nbformat_minor": 2
|
|
|
|
}
|