mlx-examples/gan/playground.ipynb

579 lines
323 KiB
Plaintext
Raw Normal View History

2024-07-26 21:07:40 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Library"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 517,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 518,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
2024-07-30 07:17:12 +08:00
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 519,
2024-07-30 07:17:12 +08:00
"metadata": {},
"outputs": [],
"source": [
"# mx.set_default_device(mx.gpu)"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 520,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim:int,out_dim:int):\n",
" \n",
" return nn.Sequential(\n",
" nn.Linear(in_dim,out_dim),\n",
2024-07-30 07:06:52 +08:00
" nn.BatchNorm(out_dim, 0.8),\n",
" nn.LeakyReLU(0.2)\n",
2024-07-26 21:07:40 +08:00
" )"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 521,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"\n",
2024-07-30 18:21:38 +08:00
" def __init__(self, z_dim:int = 10, im_dim:int = 784, hidden_dim: int = 64):\n",
2024-07-26 21:07:40 +08:00
" super(Generator, self).__init__()\n",
2024-07-30 18:21:38 +08:00
"\n",
2024-07-26 21:07:40 +08:00
" self.gen = nn.Sequential(\n",
" GenBlock(z_dim, hidden_dim),\n",
" GenBlock(hidden_dim, hidden_dim * 2),\n",
" GenBlock(hidden_dim * 2, hidden_dim * 4),\n",
"\n",
2024-07-30 07:56:13 +08:00
" nn.Linear(hidden_dim * 4,im_dim),\n",
2024-07-26 21:07:40 +08:00
" )\n",
" \n",
" def __call__(self, noise):\n",
2024-07-30 07:06:52 +08:00
" x = self.gen(noise)\n",
" return mx.tanh(x)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 522,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-30 18:21:38 +08:00
" (layers.0): Linear(input_dims=100, output_dims=64, bias=True)\n",
" (layers.1): BatchNorm(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-30 18:21:38 +08:00
" (layers.0): Linear(input_dims=64, output_dims=128, bias=True)\n",
" (layers.1): BatchNorm(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-30 18:21:38 +08:00
" (layers.0): Linear(input_dims=128, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-30 18:21:38 +08:00
" (layers.3): Linear(input_dims=256, output_dims=784, bias=True)\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-07-30 18:21:38 +08:00
"execution_count": 522,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gen = Generator(100)\n",
"gen"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 523,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# make 2D noise with shape n_samples x z_dim\n",
"def get_noise(n_samples:list[int], z_dim:int)->list[int]:\n",
" return mx.random.normal(shape=(n_samples, z_dim))"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 524,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-30 18:21:38 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWOElEQVR4nO3ce2zV9f3H8VddsaXcxv2mrGVcSikyAUHaiaaAMiZzEDux27Lhls0lI8HpnMuyq0u2bMsuDHfLInO6i7JBcNqJdhDRtdOWKYJcC2JhtCW2lYLY0nLO7793suyPntcnmb9ffnk+/j7Pc5Ce8vL7zzsvm81mBQCApMv+t/8AAID/OxgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAAhPxcX/jwww/bb97V1WU3DzzwgN1I0vz58+2moqLCbk6fPm03L7/8st1MnTrVbiRp/PjxdlNSUmI3tbW1djNkyBC7kaRBgwbZTX5+zl/tcPz4cbsZPXq03RQWFtqNJHV0dNjN9OnT7SYvL89uWlpa7CblZyRJ8+bNs5t//OMfdrNs2TK72bJli91I0vLly+1m1KhRdvPJT35ywNfwpAAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAABCzhepCgoK7DdPOeA1ePBgu5GkTCZjN48++qjdnDhxwm6+9KUv2c3+/fvtRpJ6e3vtZt++fXaTzWbt5vbbb7cbKe0YY8p3L+UwYMoBwr6+PruRpBtvvNFuvv/979vNpz71KbtJOX75rne9y26ktAOJM2fOtJuUQ5ZLly61G0l6/fXX7Sbldz0XPCkAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkPNBvKqqKvvNL1y4YDfXXHON3UhpR8a+/e1v282uXbvspqmpyW4mTpxoN1LasbA//elPdlNTU2M3mzdvthtJGjJkiN0sXLjQblKOkqV8xysqKuxGkh5//HG7STkw2dbWZjcPPfSQ3XzrW9+yG0m655577GbZsmV2M3nyZLtJOcQopR1WbGlpSfqsgfCkAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIOV9J3bp1q/3m9fX1dpN6HTTlKubJkyft5s9//vM78jmf+9zn7EaSent77Wb58uV2U1BQYDerVq2yG0lqbGy0m87OTrsZNmyY3ZSUlNhNe3u73UjSmTNn7OaLX/yi3Tz99NN2s2HDBrs5ffq03UhSeXm53aRcPE25Zvvggw/ajSStX7/ebrq7u5M+ayA8KQAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAICQl81ms7m88Ac/+IH95q+++qrdLFmyxG4k6ciRI3aTciQrxaOPPmo31dXVSZ81atQou3n++eftpr+/325SDs5J0syZM+3m+PHjdtPV1WU358+ft5tUdXV1drN58+b/wp/kP/X09NhNyqFIKe07/sILL9jN6NGj7ebFF1+0G0kqLCy0m+uvv95uvvCFLwz4Gp4UAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQMj5IN4vf/lL+83Lysrs5t5777UbSbp48aLdVFRU2M2UKVPs5vLLL7eb5uZmu5HSjrrNmTPHbm644Qa7ueeee+xGkqZNm2Y3VVVVdrNr1y67mTBhgt2MGDHCbiTpwIEDdtPZ2Wk3p06dspuxY8fazQc/+EG7kdK+rym/g08//bTdpB5IrKystJuHHnrIbh5//PEBX8OTAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAj5ub7w7Nmz9punNCmH1qS0Q3V9fX12k3KoLuUYV3Fxsd1IUlFRkd0sWLDAbu6//367KS8vtxtJysvLs5umpia7STngWFdXZzcrV660G0nq6Oiwm5tuuslu3nrrLbtJOYj32muv2Y0ktbW12c0jjzxiN3fccYfdtLS02I0kZTIZu0k5FJkLnhQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAyPkg3uLFi+03X7dund2sWrXKbqS0g1x79uyxm3379tlNTU2N3fT09NiNlHbE65vf/KbdtLe3282QIUPsRpIqKyvtJuUIYW1trd1cdpn//1UpRxUlac2aNXaT8nN66aWX7CblmOA///lPu5GkRYsW2U1FRYXdvPHGG3YzdOhQu5HSvkezZ89O+qyB8KQAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAg5X0ltamqy33zZsmV2s3btWruRpJ/97Gd2M3PmTLtZuXKl3Tz22GN2k81m7UaSjh49ajfjxo2zm5RrrF/5ylfsRkq7BtnY2Gg3KZeABw8ebDc7d+60G0k6e/as3RQVFdnNe97zHrvp6Oiwm5tuusluJKm1tdVu+vr67Gb06NF28+qrr9qNJJ0/f95u+vv7kz5rIDwpAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgJDzQbze3l77zZcuXWo3DzzwgN1I0rp16+xm4sSJdrNx40a7GTlypN3cdtttdiNJ+/fvt5vJkyfbzdatW+0m5bCdJF133XVJnSuTydjN7t277Sb1ENy0adOSOte2bdvs5tOf/rTdHDx40G4kacSIEXZTWlpqN7/73e/spqqqym4kafjw4XazZcuWpM8aCE8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIORls9lsLi+888477TefN2+e3fztb3+zG0nq6uqym5SDeIMHD7ab5557zm6qq6vtRpIWL15sN4cPH7abxsZGu5k0aZLdSNK5c+fekebqq6+2m7y8PLtpaWmxG0lqa2uzm5KSErtJOVSXclTx9OnTdiNJ48ePt5umpia7KS8vt5tLly7ZjSQVFBTYTUdHh93kcsiSJwUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQ8nN9Ycqhp4aGBruZPXu23UjSlClT7ObXv/613VRVVdlNUVGR3Rw5csRuJCk/P+cfaUg5rDV27Fi7GTNmjN1IaQfazp49azeLFi2ym/7+frtpbW21G0nK8Xblvxk+fLjdpBx9TDlu96Mf/chuJGnnzp12U1FRYTfHjh2zm0wmYzeS9MYbb9jNihUrkj5rIDwpAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAABCXjbH04tPPPGE/ebbt2+3m5Trm5J05syZd+SzUi59plxw7erqshsp7b9p0KBBdvOTn/zEbsrKyuxGkk6ePGk3w4YNs5vJkyfbzYIFC+zmwIEDdiNJLS0tdnPttdfaTXd3t9288MILdjNy5Ei7kaSenh67KS0ttZslS5bYTV1dnd1I0o4dO+zma1/7mt0sX758wNfwpAAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAABCzgfxqqqq7DdPOQT397//3W4kqbm52W7uvfdeu7n11lvt5qWXXrKbV155xW4kaerUqXaTyWTekc/56U9/ajeSNGvWLLuZMWOG3Tz55JN2s27dOrupr6+3G0kaP3683bz++ut2c+LECbu5++677Wbjxo12I0nV1dV2c/HiRbvZtGmT3dx22212I6X9vvf399vNtm3bBnwNTwoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAg5Of6wgULFthvPmrUKLvZu3ev3UjSr371K7s5cOCA3aQ
2024-07-30 00:44:16 +08:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"img = get_noise(28,28)\n",
"plt.imshow(img, cmap='gray')\n",
"plt.axis('off')\n",
"plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 525,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim:int,out_dim:int):\n",
" return nn.Sequential(\n",
" nn.Linear(in_dim,out_dim),\n",
2024-07-30 00:44:16 +08:00
" nn.LeakyReLU(negative_slope=0.2),\n",
2024-07-30 18:21:38 +08:00
" nn.Dropout(0.3),\n",
2024-07-26 21:07:40 +08:00
" )"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 526,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"\n",
2024-07-30 18:21:38 +08:00
" def __init__(self,im_dim:int = 784, hidden_dim:int = 64):\n",
2024-07-26 21:07:40 +08:00
" super(Discriminator, self).__init__()\n",
"\n",
" self.disc = nn.Sequential(\n",
2024-07-30 07:37:09 +08:00
" DisBlock(im_dim, hidden_dim * 4),\n",
" DisBlock(hidden_dim * 4, hidden_dim * 2),\n",
2024-07-26 21:07:40 +08:00
" DisBlock(hidden_dim * 2, hidden_dim),\n",
2024-07-30 07:37:09 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" nn.Linear(hidden_dim,1),\n",
2024-07-30 18:21:38 +08:00
" nn.Sigmoid()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" \n",
2024-07-30 18:21:38 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" def __call__(self, noise):\n",
2024-07-30 18:21:38 +08:00
" return self.disc(noise)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 527,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-30 18:21:38 +08:00
" (layers.0): Linear(input_dims=784, output_dims=256, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-30 18:21:38 +08:00
" (layers.0): Linear(input_dims=256, output_dims=128, bias=True)\n",
2024-07-30 07:37:09 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-30 07:37:09 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-30 18:21:38 +08:00
" (layers.0): Linear(input_dims=128, output_dims=64, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-30 18:21:38 +08:00
" (layers.3): Linear(input_dims=64, output_dims=1, bias=True)\n",
" (layers.4): Sigmoid()\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-07-30 18:21:38 +08:00
"execution_count": 527,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"disc = Discriminator()\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
2024-07-30 00:44:16 +08:00
"cell_type": "markdown",
2024-07-26 21:07:40 +08:00
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### Losses"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"#### Discriminator Loss"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 528,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
2024-07-30 00:44:16 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" noise = mx.array(get_noise(num_images, z_dim))\n",
" fake_images = gen(noise)\n",
2024-07-27 06:09:51 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" fake_disc = disc(fake_images)\n",
" \n",
2024-07-27 05:19:08 +08:00
" fake_labels = mx.zeros((fake_images.shape[0],1))\n",
2024-07-30 18:21:38 +08:00
" \n",
" fake_loss = mx.mean(nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True))\n",
2024-07-30 00:44:16 +08:00
" \n",
2024-07-30 18:21:38 +08:00
" real_disc = mx.array(disc(real))\n",
2024-07-27 05:19:08 +08:00
" real_labels = mx.ones((real.shape[0],1))\n",
2024-07-30 18:21:38 +08:00
" \n",
" real_loss = mx.mean(nn.losses.binary_cross_entropy(real_disc,real_labels,with_logits=True))\n",
" \n",
" disc_loss = (fake_loss + real_loss)\n",
2024-07-26 21:07:40 +08:00
"\n",
" return disc_loss"
]
},
2024-07-30 00:44:16 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generator Loss"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 529,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def gen_loss(gen, disc, num_images, z_dim):\n",
2024-07-26 21:07:40 +08:00
"\n",
" noise = mx.array(get_noise(num_images, z_dim))\n",
2024-07-30 07:06:52 +08:00
" \n",
2024-07-26 21:07:40 +08:00
" fake_images = gen(noise)\n",
2024-07-30 18:21:38 +08:00
" fake_disc = mx.array(disc(fake_images))\n",
2024-07-26 21:07:40 +08:00
"\n",
2024-07-27 05:19:08 +08:00
" fake_labels = mx.ones((fake_images.shape[0],1))\n",
2024-07-30 00:44:16 +08:00
" \n",
2024-07-30 18:21:38 +08:00
" gen_loss = nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True)\n",
" \n",
" return mx.mean(gen_loss)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 530,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"# Get only the training images\n",
2024-07-30 18:21:38 +08:00
"train_images,train_labels,*_ = map(np.array, mnist.mnist())"
2024-07-30 00:44:16 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 531,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# Normalize the images to fall between -1,1\n",
"train_images = train_images * 2.0 - 1.0"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 532,
2024-07-29 06:24:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-07-30 18:21:38 +08:00
"<matplotlib.image.AxesImage at 0x15ad80190>"
2024-07-29 06:24:50 +08:00
]
},
2024-07-30 18:21:38 +08:00
"execution_count": 532,
2024-07-29 06:24:50 +08:00
"metadata": {},
"output_type": "execute_result"
2024-07-30 00:44:16 +08:00
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaI0lEQVR4nO3df2jU9x3H8dfVH1d1lytBk7vUmGVF202dpWrVYP3R1cxApf4oWMtGZEPa+YOJ/cGsDNNBjdgpRdI6V0amW239Y9a6KdUMTXRkijpdRYtYjDOdCcFM72LUSMxnf4hHz1j1e975vkueD/iCufu+vY/ffuvTby75xueccwIAwMBD1gsAAHRfRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjpab2AW3V0dOjcuXMKBALy+XzWywEAeOScU0tLi/Ly8vTQQ3e+1km7CJ07d075+fnWywAA3Kf6+noNHDjwjvuk3afjAoGA9RIAAElwL3+fpyxCH3zwgQoLC/Xwww9r5MiR2rdv3z3N8Sk4AOga7uXv85REaPPmzVq8eLGWLVumI0eO6JlnnlFJSYnOnj2bipcDAGQoXyruoj1mzBg99dRTWrduXeyx73//+5o+fbrKy8vvOBuNRhUMBpO9JADAAxaJRJSVlXXHfZJ+JXTt2jUdPnxYxcXFcY8XFxertra20/5tbW2KRqNxGwCge0h6hM6fP6/r168rNzc37vHc3Fw1NjZ22r+8vFzBYDC28ZVxANB9pOwLE259Q8o5d9s3qZYuXapIJBLb6uvrU7UkAECaSfr3CfXv3189evTodNXT1NTU6epIkvx+v/x+f7KXAQDIAEm/Eurdu7dGjhypqqqquMerqqpUVFSU7JcDAGSwlNwxYcmSJfrpT3+qUaNGady4cfr973+vs2fP6tVXX03FywEAMlRKIjR79mw1NzfrN7/5jRoaGjRs2DDt2LFDBQUFqXg5AECGSsn3Cd0Pvk8IALoGk+8TAgDgXhEhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmelovAEgnPXr08DwTDAZTsJLkWLhwYUJzffv29Tzz+OOPe55ZsGCB55nf/va3nmfmzJnjeUaSrl696nlm5cqVnmfefvttzzNdBVdCAAAzRAgAYCbpESorK5PP54vbQqFQsl8GANAFpOQ9oaFDh+rvf/977ONEPs8OAOj6UhKhnj17cvUDALirlLwndOrUKeXl5amwsFAvvfSSTp8+/a37trW1KRqNxm0AgO4h6REaM2aMNm7cqJ07d+rDDz9UY2OjioqK1NzcfNv9y8vLFQwGY1t+fn6ylwQASFNJj1BJSYlmzZql4cOH67nnntP27dslSRs2bLjt/kuXLlUkEolt9fX1yV4SACBNpfybVfv166fhw4fr1KlTt33e7/fL7/enehkAgDSU8u8Tamtr05dffqlwOJzqlwIAZJikR+j1119XTU2N6urqdODAAb344ouKRqMqLS1N9ksBADJc0j8d9/XXX2vOnDk6f/68BgwYoLFjx2r//v0qKChI9ksBADJc0iP0ySefJPu3RJoaNGiQ55nevXt7nikqKvI8M378eM8zkvTII494npk1a1ZCr9XVfP31155n1q5d63lmxowZnmdaWlo8z0jSv//9b88zNTU1Cb1Wd8W94wAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAMz7nnLNexDdFo1EFg0HrZXQrTz75ZEJzu3fv9jzDf9vM0NHR4XnmZz/7meeZS5cueZ5JRENDQ0JzFy5c8Dxz8uTJhF6rK4pEIsrKyrrjPlwJAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwExP6wXA3tmzZxOaa25u9jzDXbRvOHDggOeZixcvep6ZPHmy5xlJunbtmueZP/3pTwm9Fro3roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT63//+l9DcG2+84Xnm+eef9zxz5MgRzzNr1671PJOoo0ePep6ZMmWK55nW1lbPM0OHDvU8I0m//OUvE5oDvOJKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAw43POOetFfFM0GlUwGLReBlIkKyvL80xLS4vnmfXr13uekaSf//znnmd+8pOfeJ75+OOPPc8AmSYSidz1/3muhAAAZogQAMCM5wjt3btX06ZNU15ennw+n7Zu3Rr3vHNOZWVlysvLU58+fTRp0iQdP348WesFAHQhniPU2tqqESNGqKKi4rbPr1q1SmvWrFFFRYUOHjyoUCikKVOmJPR5fQBA1+b5J6uWlJSopKTkts855/Tee+9p2bJlmjlzpiRpw4YNys3N1aZNm/TKK6/c32oBAF1KUt8TqqurU2Njo4qLi2OP+f1+TZw4UbW1tbedaWtrUzQajdsAAN1DUiPU2NgoScrNzY17PDc3N/bcrcrLyxUMBmNbfn5+MpcEAEhjKfnqOJ/PF/exc67TYzctXbpUkUgkttXX16diSQCANOT5PaE7CYVCkm5cEYXD4djjTU1Nna6ObvL7/fL7/clcBgAgQyT1SqiwsFChUEhVVVWxx65du6aamhoVFRUl86UAAF2A5yuhS5cu6auvvop9XFdXp6NHjyo7O1uDBg3S4sWLtWLFCg0ePFiDBw/WihUr1LdvX7388stJXTgAIPN5jtChQ4c0efLk2MdLliyRJJWWluqPf/yj3nzzTV25ckXz58/XhQsXNGbMGO3atUuBQCB5qwYAdAncwBRd0rvvvpvQ3M1/VHlRU1Pjeea5557zPNPR0eF5BrDEDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNrqkfv36JTT317/+1fPMxIkTPc+UlJR4ntm1a5fnGcASd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzAFvuGxxx7zPPOvf/3L88zFixc9z+zZs8fzzKFDhzzPSNL777/veSbN/ipBGuAGpgCAtEaEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpsB9mjFjhueZyspKzzOBQMDzTKLeeustzzMbN270PNPQ0OB5BpmDG5gCANIaEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gCBoYNG+Z5Zs2aNZ5nfvSjH3meSdT69es9z7zzzjueZ/773/96noENbmAKAEhrRAgAYMZzhPbu3atp06YpLy9PPp9PW7dujXt+7ty58vl8cdvYsWOTtV4AQBfiOUKtra0aMWKEKioqvnWfqVOnqqGhIbbt2LHjvhYJAOiaenodKCkpUUlJyR338fv9CoVCCS8KANA9pOQ9oerqauXk5GjIkCGaN2+empqavnXftrY2RaPRuA0A0D0kPUIlJSX66KOPtHv3bq1evVoHDx7Us88+q7a2ttvuX15ermAwGNvy8/OTvSQAQJry/Om4u5k9e3bs18OGDdOoUaNUUFCg7du3a+bMmZ32X7p0qZYsWRL7OBqNEiI
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
2024-07-29 06:24:50 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"plt.imshow(train_images[0].reshape(28,28),cmap='gray')"
2024-07-29 06:24:50 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 533,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"def batch_iterate(batch_size: int, ipt: list[int])-> list[int]:\n",
2024-07-29 06:24:50 +08:00
" perm = np.random.permutation(len(ipt))\n",
" for s in range(0, len(ipt), batch_size):\n",
2024-07-26 21:07:40 +08:00
" ids = perm[s : s + batch_size]\n",
2024-07-30 00:44:16 +08:00
" yield ipt[ids]"
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 534,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"def show_images(imgs:list[int],num_imgs:int = 25):\n",
" if (imgs.shape[0] > 0): \n",
" fig,axes = plt.subplots(5, 5, figsize=(5, 5))\n",
" \n",
" for i, ax in enumerate(axes.flat):\n",
" img = mx.array(imgs[i]).reshape(28,28)\n",
" ax.imshow(img,cmap='gray')\n",
" ax.axis('off')\n",
" plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### show first batch of images"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 535,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-30 18:21:38 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADWZUlEQVR4nOz915PcV3rfj7865xynpycnDgaJIAiGXZKbg2RLWrnkKgW7LN/4wv+M73xhl8t2OciutWRJ1mq/u8vd5YI5gMjA5Ng559z9u8DvHPYAAxIgJvQ0Pq+qKe5ienr6c+ac85zzhPej6vV6PRQUFBQUFA4R9Ul/AAUFBQWF4UMxLgoKCgoKh45iXBQUFBQUDh3FuCgoKCgoHDqKcVFQUFBQOHQU46KgoKCgcOgoxkVBQUFB4dBRjIuCgoKCwqGjGBcFBQUFhUNH+6QvVKlUR/k5ThVfR9RAGb8vUMbv2fi6ohrKGH6BMgefjScZP+XmoqCgoKBw6CjGRUFBQUHh0FGMi4KCgoLCoaMYFwUFBQWFQ+eJA/oKCgoPUKlUaLVa1Go1NpsNk8lEr9ej2+3S6XQolUo0m0263S7dbvekP66CwomgGBcFhafEaDQSCARwOp388R//MS+99BLNZpN6vU4ikeCv//qv2d7eJp/PUygUTvrjKiicCIpxUVB4SrRaLVarFa/Xy5UrV/jhD39IvV6nUqmwubnJhx9+SDabpVqtnvRHVVA4MZSYi4LCM/Bwvr9KpZL1EEpdhMLzjGJcFBSegV6vJ7/ggUFRq9VoNBrFuCg81yhusSFHpVJhNBrRarX4fD68Xu+Bm16lUqFQKNDpdGi1WnQ6HWq1GrVaDY1Gg8lkQq1WYzAY0Gg08uc6nQ7lcpl2u02n03kuAtidTod6vU65XKZYLJLP5wHkOI2Pj9NsNun1epRKJdrtNq1W62tX1isonEYU4zLk6HQ6QqEQDoeDP/zDP+QP/uAPUKsfXFj7XTj37t3j/fffp1wuk06nqdVqrK6usrGxgcViYXx8HIvFwsjICHa7HUBunrdu3aJQKFCpVKjVaif2rMdFo9EgGo1SqVS4f/8+IyMj+P1+JiYmCIVC/PN//s/J5/P8zd/8DdVqlUqlQiqVot1un/RHV1A4NhTjMsSIlFmbzYbL5WJiYoKzZ88eaFw6nQ67u7uUSiU0Gg2VSoV4PI5Op8NkMuH1erFarYyMjOB0OoEHxiWfz7O7u0u73abZbA6lcREuLvEl3F5qtZp2u02tVqPVaqFWqzEajYyMjOBwOPD5fFgsFtrttuIiU/hS1Go1er1+X7xOrKnTimJchhS9Xo/ZbMbj8fDWW28xPT3N/Pz8vtf0u2lCoRDf+ta3qFarJBIJKpUKLpcLm83GxMQE3/72t3E4HLhcLkwmEwDdbpdSqcT4+DipVIqrV6/y6aefHutzHjUGg4GxsTGsVitOp1OOwdzcHHa7nfPnzxMOh7FardJdaLPZ0Ov1BAIBxsfHSSQSJBIJWq3WCT+NwqAyMjLCD37wA9xuN2q1GrVaze3bt3n77bep1+sn/fG+FopxGVK0Wi0WiwWPx8OFCxdYWloiFAqhUqn2BZ/F//Z4PLjdbprNJolEgnK5TKPRoNFoMD8/z3e/+13cbreM38AXbjGbzUYqlWJzc3PojItWqyUYDOLz+RgdHWVkZITR0VHeeOMNHA4HBoMBnU4nNwR4UAej0+lwOp14vV4ajYb8noLCQbhcLr797W8zPj6ORqNBq9ViMBj43e9+pxiXfsQVT6fTMTMzw+joKOVymUwmQ71eJxaLUalUjuJXP/fodDo0Gg2hUIhz584RCoUIh8O43W5543gcwuVjNpvRaDRMT08DD05VZrNZVqX3v77X68mAdafTOdJnOwk0Gg02mw2n00k4HGZ+fh63243FYkGv1wM8ciMxGo2oVCr8fj+Li4uo1Wp0Ot1JfPxjRbgLfT4fNpsNh8OB3+9Hq9VKAyxoNptUq1UajQZbW1tkMhlardap3UifFeG+djqd0v1qMpnkeu52u6cuIeRIjItWq8XhcGCz2fhn/+yf8f3vf5+dnR0+++wzkskkb7/9tmJcjgAxIU0mExcuXOAv/uIv8Hq9zM/P43A4nuj0rNVqcblc9Ho9PB4Ply9fRqPRYDAYUKvVj8QOer0e9XqdWq02lAFrrVaL1+tlbGyMF198kddee01ulgC1Wo1ms0m73abRaKDVauX3Z2ZmpDzMP/zDP5zwkxwtKpUKnU6HwWBgaWmJ+fl55ufn+eY3v4nZbMZqtWIwGORri8Ui29vbZLNZfvrTn/LJJ59QLBalbM7zhsFgwO/3MzIyImWD7HY7RqORer1OvV5XjAvsP+15PB4CgQDNZpPR0VG0Wi1ut5tisShPvN1ul3a7vW/wVCqVDKQ+fGI+aLMTi7vb7cr3fJ60ncR4WSwW7HY7brcbr9eLx+PZ58oSCC2sgwr+xFiLcX745+DBplqtVsnn86TTaZlhdtoRc8poNGKz2faNo81mw2AwSIPa6XTI5XJUq1U574xGI3a7Hb1eL9/DarViNBrR6/W02+2hm5MqlQqDwYDb7cZsNhMKhaQL0e/3YzKZsNls6HQ6Oc/EpqnT6QgEAgSDQbRaLZVKZd8NWKS4DyvCnarT6eSX2MfEPDltRkVwJMbFZrNx4cIFRkZGmJ2dxev1yjTWbDaLVqtlY2ODeDxOJBKhUqkQi8X2ZUYYjUbp0/b5fJjNZvk9t9vN0tLSPldNNBplfX2dUqnE7u4ulUqFUqlEuVw+ikccKET9idFo5MyZM0xNTXHp0iWmp6exWCwYjcZ9r+90OjQajX2xF61WK908X0ar1aLdbnP37l0+/PBD0uk0165dI5PJsL29fSTPd1yoVCo5XmfPnuUHP/gBHo+HpaUl6QoT82p7e5tSqcSdO3eIRqNyA/V6vfz4xz/GbDZjNBpl3GViYgKtVksikaBUKp30ox4aIj4QDof5p//0nxIKhTh//jwTExNYLBbcbjcajUYmO/R6PWm8w+GwHK+lpSVu377Nb3/7W5rNJmq1mm63SyqVIpPJnPBTHh1ms1keYoQbLJfLUSgUyOfz8mb8tAeSg7wUx32oORLjotPp8Pl8BAIBHA6HdNW43W7cbjezs7NywtVqNbRaLel0ep9bRa/XyxNfIBDAarXK742MjHDhwgXsdjtarRaNRsPGxgbtdpt8Pi8Xb6PRkCel02r9nwRx8jEYDHi9XsbHxwkGg9jt9gPjLL1eT8ZI+sdHq9V+acqs+Ll2u00ymeTevXukUinu378vT/CnGZVKhV6vx2QyEQqFeOmll3C73YyPj2O1WqnX6zQaDYrFIvF4nFwux/LyMhsbG9jtdnw+H41Gg1qtJsfTaDRiNptxOByUy2Wy2exJP+ahotFoZPLCmTNnmJiY4IUXXmB0dHTf68RNWaBWq+X6np6exm63U6lUuHXr1r51W6lUyOVy+1QQhgmdTofFYpGGRa1W02q1qFQqNBqNr3Vz6/f6CB5WkjgOjsS4VKtVlpeXyefzXLp0SW5iItC/tLTEyMgICwsLpNNpqtUqyWRyX2BUuBT0er107QhsNhvj4+OyWlycOL1eL7VajZdeekkWuK2urpLP59nc3KTRaBzF4544IiXW4/HwjW98gzNnzhAIBB4JInc6HdrtNtFolPfee2/fCVrUZXxZXKbfDXTnzh2uX79OqVQil8ud6piLRqPBaDRiMBiYnZ0lFApx5swZmWKsUqmo1+vcvn2bu3fvkslkuH//PqVSib29PbLZLFarVRaSrq2t0ev1ZNzR7/fzxhtvyMLLYTAw4nB4/vx5Ll++zOjoKEtLS3i9Xmw2GwD1ep1SqUStVmN3d3dfBqLP5+Oll17CZDLhdDrR6/Xy/7daLTQaDZ1Oh/v377OysiIPMcMW8LfZbASDQbxer6xzSSQSrK2tEYvFnsqwiAOm2+3mzTffJBgMSjfvzs4Ov/nNbygWizQajWNxNR6JcREbeywW44c//CHtdlsWnZlMJs6dOwc88KeKr1qttu+BRVBULHyNRrPP6j68CY6Pj3PhwgW63S7NZpNms8n
2024-07-26 21:07:40 +08:00
"text/plain": [
2024-07-29 06:24:50 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-26 21:07:40 +08:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
2024-07-30 00:44:16 +08:00
"X = batch_iterate(25, train_images)\n",
"for x in X: \n",
" show_images(x)\n",
2024-07-26 21:07:40 +08:00
" break"
]
},
2024-07-30 07:06:52 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training Cycle"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 536,
2024-07-27 06:09:51 +08:00
"metadata": {},
2024-07-29 06:24:50 +08:00
"outputs": [],
2024-07-27 06:09:51 +08:00
"source": [
2024-07-30 18:21:38 +08:00
"lr = 2e-6\n",
"z_dim = 128\n",
2024-07-30 07:06:52 +08:00
"\n",
2024-07-27 06:09:51 +08:00
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
2024-07-30 07:44:41 +08:00
"gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999]) #,betas=[0.5, 0.9]\n",
2024-07-27 06:09:51 +08:00
"\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
2024-07-30 07:44:41 +08:00
"disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])"
2024-07-29 06:30:08 +08:00
]
},
2024-07-27 06:09:51 +08:00
{
"cell_type": "code",
2024-07-30 18:21:38 +08:00
"execution_count": 537,
2024-07-27 06:09:51 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-30 18:21:38 +08:00
" 0%| | 0/50 [00:00<?, ?it/s]"
2024-07-30 07:56:13 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-07-30 18:21:38 +08:00
"Epoch: 0, iteration: 468, Discriminator Loss:array(1.33504, dtype=float32), Generator Loss: array(0.460247, dtype=float32)\n"
2024-07-30 07:56:13 +08:00
]
},
{
"data": {
2024-07-30 18:21:38 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9yW+kWXodDp+Y53lmMEgG55wzqzJralV1Vw8otWzZsgT3ygJsGPDCGwOGF174f/DCS3tjaGFJsABbgoyfWi2p1V2lqu7MyjmT8xyMeZ7HN74FdR7eYLd/YkbWhw/4wAsksopkBiPue+8znOc859GNx+MxrtbVulpX62pdrW9w6f9//Qau1tW6Wlfrav3/37pyLlfral2tq3W1vvF15Vyu1tW6Wlfran3j68q5XK2rdbWu1tX6xteVc7laV+tqXa2r9Y2vK+dyta7W1bpaV+sbX1fO5Wpdrat1ta7WN76unMvVulpX62pdrW98XTmXq3W1rtbVulrf+DJe9gf/43/8jxiNRnj69Cm2trYQj8dx9+5d2Gw2GI1G6PV62O12OBwO9Pt9FAoFDIdDWK1WGI1GbG5u4ic/+QmMRiO+9a1vIR6Pw263w+l0wm63IxqNwmg0YjQaYTQaodfrodVqYTwew2AwAACazab8yefz0DQNy8vLCIfD0DQNo9EIer0eVqsVBoMBbrcbDocDmUwGv/jFLzAajfD+++9jfn4enU4HrVYLBoMBXq8XBoMB+/v7OD4+Rr1eRyqVQr/fh8FggE6nw0cffYTPPvsMJpMJ9+/ff+ON/t3f/V0AgMvlgs1mg9lsht1uBwDU63UMBgOsra3h1q1bqFarePToEZrNJlZXVzE7O4tHjx7hD//wD6FpGt555x1EIhFYLBbZf7fbDaPRCKvVCpPJhEqlgkwmA7vdjvfeew/BYBCvX7/GxsYG8vk8njx5Ak3T8PHHH2N1dRXNZhPVahWBQAD/+B//Y8TjcaRSKaTTaWSzWTx58gSj0Qg3btxALBZDKpXC7u4u+v0+2u02RqMRut0uer0ezGYznE4nxuMxKpUKut0url+/jgcPHsBoNOL3f//333j//v2///fQNA2bm5s4PDxELBbD7du35VkbDAZ4PB74fD5UKhU8fvwY7XYbN2/eRDKZxKNHj/BHf/RHMBqN+M3f/E0kk0l4vV74fD4MBgNUq1UMh0P0+32MRiP5b4PBAJfLBbPZjFarhXa7jVarhUKhgPF4jMXFRYRCIfR6PfR6Pbjdbrz//vvw+/342c9+hq+++goGgwEWiwV2u13OH/d7MBig1WrJnSmVShgOh+h2uxgOhygWi2i320gmk7h27Rr0ej3+w3/4D2+8fwDwb//tvwUAVKtVtFot6PV6mM1mGAwGBAIB2Gw2BAIBRCIR1Go1vHjxAt1uFzdu3MD8/Dw2NjbkDn/22WdYWFiA2WyGxWLBeDyWfTs9PUW5XEYgEMDs7CwMBgPG4zE0TUMmk0E2m0Wr1UI+n8d4PEYikUAgEEC/30e324XNZkMymYTL5UK9Xkej0UC1WsX+/j50Oh3u3LmDeDyO4XCIwWCAfr+ParWKXq+HXC6HYrEIq9UKj8eDbreLn/zkJzg4OMDv/M7v4F/9q38Fi8WC73znO2+8f7/9278NvV6PeDyOSCQCg8EAs9kMABiNRhiPx3IOAECn0wGA2JCNjQ385V/+JYxGI7797W8jkUjIvpjNZng8HlgsFkSjUXi9XjSbTZTLZQCA0+mEwWDA48eP8fjxY/T7fbRaLZhMJnz88cdYWVlBJpPB/v4+/H4/fvSjH2Fubg6np6fIZrMoFot4/fo1RqMR1tbWEA6HUa1WUSgU0Gg0sLW1hWaziXA4DL/fj263i2q1in6/j2aziV6vh5WVFdy4cQMGg+FSZ/DSzqXVakGn0+HWrVu4d+8eAECv10PTNPR6PQwGAxiNRjidTphMJng8HmiaBpfLBavVitFohHq9Dp1Oh/n5efh8Pnnjmqah2WzCZDKhUCigVquh2+2i2WxCp9PBbDZDr9ej0+mg3W7DZDIhFArBZDLBZrNB0zRYrVY4HA7odDqMx2OMx2NkMhlUq1V0Oh0AgMViEUfV7XZRLBZhMBgwGo1gNpsRDAYRDofRaDQQj8cxGo1gsVhgMpng8/mQTqeh1+unci7Hx8cwGAxYXFxEIBCAx+NBLBaDpmk4PDxErVaTSwlAnB4/t91ux40bNzAejxEIBGC1WjEcDlGtVmE0GtHtdmEymcR59ft96PV6tFot/M3f/A0AyGU0mUy4du0aNE1Du93G69evYbVa4XK5oNfrkU6n0el0kM/nUSqV0O/3MT8/j/F4jH6/j1QqhXa7DY/HA7PZLI6uUCigXC5D0zQMh0MAQDQahU6ng9VqxenpKfT66ZLl4XCI8XiMWCwmv7ff70On0yESicDhcMBqtcJsNsNqtcLr9cJisaDf7yOXy8HhcOD73/8+AMBms4nz7XQ6MJlM8Hq9MJlMGAwG0DQNer1ezkaxWISmaSgWiyiXyzAajRIgNBoNDIdDeL1ezM7Owmw2o1KpoF6vi8Mym83w+Xyw2WzodDrI5XKoVqsoFovo9/toNBoYjUawWq2IRqPiXMbjMeLxOADI5512/wCg0+lAp9PB6/UiGo1Cr9dLYEgDOBqN5H3H43GMx2P4fD65Hzdu3JA7n0qloGkaNE2DyWSCw+EAAJTLZXGcfEbdbleCxl6vB71ej1gsBr1eD51Oh2q1CqfTifn5eTnzvO+dTgeapmFmZkbud6VSkYBSp9PB5XLB7XbD7XZjfn5eftdgMMAPfvADMZxbW1vQ6/VTORe/3y/BKwBYrVYEg0F5v4PBAM1mE4PBAACgaRp0Oh3sdjvsdjuSySTef/996PV6zM/Pw+/3I5VK4fj4GFarFeFwWILF8XiMRqOBQqEgTns4HOLk5ASdTgcOhwOLi4uwWCxot9t4+fIlbDYblpeX4XK50Gw2cXp6ikKhgEqlgtFoJM9Tp9OhVquh0Wig3W7DaDTi3XffhU6nQ6fTQbfbhcViERteq9XQ6/VgtVpRKpXEaf5D69LOpdvtwmAw4Nq1a1heXkapVMLOzg46nQ4ajQa63S4GgwH0ej30er0Yeq/XK1EsLww3sVQqoVarAQB6vR6GwyGy2Syy2eyEc7HZbDCZTOh2u2i32wgEAkgkErDb7fIgLBYL/H4/AMiDLpVK2NragsFggNPpFCel0+nQ6/VQrVblYlmtVszNzWFmZgatVgt+vx+aponRojHQNO2NDiRXNpuF0WiU9+3xeDAzMwNN01AoFNBsNpHL5VCpVOB0OrGysgKn0ykX02azYWlpSYyQwWAQ56zX69HtdmE0GuXBDwYD+bevXr1CuVxGPB5HLBaDxWLBwsICxuMxjo6OcHp6img0Cr/fD51OJ++nWq2iXq/DYrEgFosBAE5PT1EsFmE2m+FwOOD1enH79m14vV4cHh6Kwa5WqwAAr9crhzKXy2FaKTs6q0AggFgshn6/j1qtJufD5/NBp9NJ9uZ2u8VZlMtl2Gw2vP/++xJZNxoNnJ6e4vj4GB6PB9euXYPdbsdoNJowlpqmSYDCLM7tdmNpaQlGoxGtVgutVgtutxvhcBg6nQ71eh29Xg+NRgOapkkWbbVa0ev1UC6XUSqVkM1m0e/3Ua/XoWkaFhYWEAwGxbnw/litVlQqFRSLRdmHaRZfMxQKIRKJyNkfj8dyZzRNQ6PRmAjgmBV7vV4sLS1JttpqtdDtdtHpdGCz2cRhMdvo9XriGJrNpjhaq9UKq9WKUCgEvV6PUqmEer0Or9eLSCSC0WiEdDqNdruNfr+PXq8Hi8UihpzvkUGkyWRCIBCAxWKB2WyG2WxGs9lEOp2GpmlIJpOwWq3IZDI4PDyc+gx6PB4AZ0GqTqcTm8PPTGfGoIS/x2azwev1ijMEzoIum82Gg4MDpFIpWK1W6HQ6OBwOCdAZoPT7fUEA+LfX68XKygrMZjM2NzeRTqexsrKCa9euwWq1otPpyL9Tnydwlig0m0202230ej04HA6sra3B6XTK+zGbzXC5XAAAh8Mh56NWq11
2024-07-30 07:56:13 +08:00
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-30 18:21:38 +08:00
" 6%|▌ | 3/50 [00:14<03:49, 4.89s/it]"
2024-07-29 00:18:35 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"# Set your parameters\n",
2024-07-30 18:21:38 +08:00
"n_epochs = 50\n",
2024-07-30 00:44:16 +08:00
"display_step = 5000\n",
2024-07-28 06:10:19 +08:00
"cur_step = 0\n",
2024-07-30 00:44:16 +08:00
"\n",
2024-07-30 07:44:41 +08:00
"batch_size = 128\n",
2024-07-26 21:07:40 +08:00
"\n",
2024-07-28 06:10:19 +08:00
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"\n",
2024-07-30 00:44:16 +08:00
"for epoch in tqdm(range(n_epochs)):\n",
2024-07-28 06:10:19 +08:00
"\n",
2024-07-30 18:21:38 +08:00
" for idx,real in enumerate(batch_iterate(batch_size, train_images)):\n",
2024-07-28 06:10:19 +08:00
" \n",
2024-07-30 00:44:16 +08:00
" # TODO Train Discriminator\n",
2024-07-30 07:06:52 +08:00
" D_loss,D_grads = D_loss_grad(gen, disc,mx.array(real), batch_size, z_dim)\n",
2024-07-28 06:10:19 +08:00
"\n",
" # Update optimizer\n",
" disc_opt.update(disc, D_grads)\n",
" \n",
" # Update gradients\n",
" mx.eval(disc.parameters(), disc_opt.state)\n",
2024-07-26 21:07:40 +08:00
"\n",
2024-07-30 00:44:16 +08:00
" # TODO Train Generator\n",
" G_loss,G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
2024-07-28 06:10:19 +08:00
" \n",
" # Update optimizer\n",
" gen_opt.update(gen, G_grads)\n",
" \n",
" # Update gradients\n",
2024-07-30 07:37:09 +08:00
" mx.eval(gen.parameters(), gen_opt.state) \n",
2024-07-30 07:44:41 +08:00
" \n",
2024-07-30 18:21:38 +08:00
" # if (cur_step + 1) % display_step == 0:\n",
" # print(f\"Step {epoch}: Generator loss: {G_loss}, discriminator loss: {D_loss}\")\n",
" # fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
" # fake = gen(fake_noise)\n",
" # show_images(fake)\n",
" # show_images(real)\n",
" # cur_step += 1\n",
" \n",
" if epoch%10==0:\n",
" print(\"Epoch: {}, iteration: {}, Discriminator Loss:{}, Generator Loss: {}\".format(epoch,idx,D_loss,G_loss))\n",
2024-07-30 07:44:41 +08:00
" fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
" fake = gen(fake_noise)\n",
" show_images(fake)\n",
2024-07-30 18:21:38 +08:00
" # show_images(real)\n",
" \n",
" # print('Losses D={0} G={1}'.format(D_loss,G_loss))"
2024-07-27 05:19:08 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}