mlx-examples/gan/playground.ipynb

603 lines
514 KiB
Plaintext
Raw Normal View History

2024-07-26 21:07:40 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Libraries"
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 1,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 2,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
2024-07-30 07:17:12 +08:00
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 3,
2024-07-30 07:17:12 +08:00
"metadata": {},
"outputs": [],
"source": [
"# mx.set_default_device(mx.gpu)"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 4,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Generator building block: Linear -> BatchNorm -> LeakyReLU(0.2).\"\"\"\n",
"    # NOTE(review): BatchNorm's second positional argument is `eps`, so this block\n",
"    # uses eps=0.8 (the printed model confirms it). That is far above MLX's default\n",
"    # of 1e-5 and may originally have been meant as momentum -- confirm before\n",
"    # changing it, since it affects training dynamics.\n",
"    layers = [\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.BatchNorm(out_dim, eps=0.8),\n",
"        nn.LeakyReLU(0.2),\n",
"    ]\n",
"    return nn.Sequential(*layers)"
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 5,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"    \"\"\"MLP generator mapping a batch of noise vectors to flattened images.\n",
"\n",
"    Three GenBlocks widen z_dim -> hidden_dim -> 2*hidden_dim -> 4*hidden_dim,\n",
"    then a final Linear projects to im_dim; tanh bounds the output to [-1, 1].\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, z_dim:int = 32, im_dim:int = 784, hidden_dim: int = 256):\n",
"        super().__init__()\n",
"        widths = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4]\n",
"        blocks = [GenBlock(d_in, d_out) for d_in, d_out in zip(widths, widths[1:])]\n",
"        self.gen = nn.Sequential(*blocks, nn.Linear(widths[-1], im_dim))\n",
"\n",
"    def __call__(self, noise):\n",
"        # tanh keeps samples in [-1, 1], the same range the training images are rescaled to\n",
"        return mx.tanh(self.gen(noise))"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 6,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=100, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
" (layers.1): BatchNorm(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
" (layers.1): BatchNorm(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
2024-07-30 07:06:52 +08:00
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-31 00:50:02 +08:00
" (layers.3): Linear(input_dims=1024, output_dims=784, bias=True)\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-07-30 21:59:35 +08:00
"execution_count": 6,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gen = Generator(z_dim=100)\n",
"gen"
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 7,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"def get_noise(n_samples: int, z_dim: int) -> mx.array:\n",
"    \"\"\"Sample a 2D batch of standard-normal noise.\n",
"\n",
"    Args:\n",
"        n_samples: number of noise vectors (rows).\n",
"        z_dim: length of each noise vector (columns).\n",
"\n",
"    Returns:\n",
"        An mx.array of shape (n_samples, z_dim).\n",
"    \"\"\"\n",
"    return mx.random.normal(shape=(n_samples, z_dim))"
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 8,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-31 00:50:02 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWKElEQVR4nO3cbWzWd73H8U8L1NoCLRToHRtrKeWmYEdXEES0G0zAEUZmJJrhlJktmmiyRYxxTjbNYuLMHG6a6IOxLS6LzjnGYCCBMTZQwrhvhQHtMOWmFAorZdzYsnH57Jscz0l6fX45ek7M+/X4el9lvfv0/2DfnEwmkxEAAJJy/6//AQCA/z8YBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAISB2b7wwQcftN885f+LKy8vtxtJqqqqspuDBw/azS233GI3x48ft5ujR4/ajST19PTYTX5+vt3s2bPHbgoLC+1GksaPH283S5YssZubbrrJbjZs2PBvaSRp4cKFdvOnP/3Jbmpra+2mvr7ebtrb2+1Gkjo7O+0m5fdDyvfdoUOH7EaSdu3aZTdLly61mzvvvLPf1/CkAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAEJOJsurdQ8//LD95s3NzXZTWlpqN5J08eJFu6msrLSb7u5uu0k5/JVy0E1KO1TX2NhoNylH3ZYvX243knThwgW76erqsptXX33VboYNG2Y3CxYssBtJWr16td3MnTvXbtavX283BQUFdpP6Pf7iiy/aTXV1td2kHPQsKiqyG0m6cuWK3aT8Lnr66af7fQ1PCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAMzPaFfX199punHP7au3ev3aS6du2a3Vy/ft1u6urq7Cb18zB27Fi72bZtm92kHC5MOaooSd///vft5qmnnrKblANtKT8XW7dutRtJqqiosJv9+/fbzdWrV+3mgQcesJtx48bZjSS1trbazQ033GA37777rt2cPXvWblK7mpqapI/VH54UAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAAAhJ5PJZLJ5Ycqlz9raWrspLCy0G0nq7e21m09/+tN2k3IV87333rObxYsX240kHT9+3G4GDx5sNzt37rSbESNG2I0knTt3zm4mTJhgN+fPn7eb3Fz/76qU71VJ+vjHP243+fn5dtPd3W03Kd93jz/+uN1I0uc+9zm7ueeee+zmzJkzdjNp0iS7kaTm5ma7GT9+vN3cd999/b6GJwUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQsj6I98Mf/tB+8+HDh9vNpk2b7EaSRo4caTeDBg2ym5ycHLspLS21m87OTruRpMuXL9vNV77yFbv53ve+ZzfLly+3G0mqqKiwm5Qjeinfe6NGjbKbpqYmu5GkZ5991m4KCgrs5ujRo3Zz5coVu5k9e7bdSNKsWbPspri42G7eeustu5kxY4bdSNKuXbvs5uWXX7abd999t9/X8KQAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAwsBsX5hyaG3ZsmV2s379eruRpDlz5thNysG+q1ev2s3GjRvtpqamxm4kqaioyG5+97vf2c3UqVPt5re//a3dSNLkyZPtprKy0m7y8/PtJpsDY/9s9OjRdiOlHZBL+R6/44477KalpcVusrzF+b9iw4YNdpPyOy/laKEk3XvvvXZz/vz5pI/VH54UAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQMjJZHmV6oUXXrDffMuWLXZz6dIlu5GkPX
v22E1FRYXdrFq1ym7eeOMNu6mvr7cbSXrttdfs5siRI3bT0NBgN3/961/tRpLefvttu2lsbLSb2tpauxk7dqzdvPnmm3YjSbfeeqvdrFy50m5S/puWL19uNy+99JLdSFJVVZXdDBkyxG7ef/99u0k5QChJbW1tdtPc3Gw3mzdv7vc1PCkAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAMDDbF546dcp+84KCArspLi62G0lauHBhUudas2aN3Rw4cMBuhg4dajdS2pGxefPm/VuaCRMm2I2UdqiupKTEbnJycuzmb3/7m920t7fbjSR1dHTYzVe/+lW7yeZo2j87fPiw3aR87iRp48aNdjNjxgy7KS8vt5uysjK7kaQnn3zSbn784x8nfaz+8KQAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAhZX0lNuTp56NAhu7l48aLdSNKHH35oNylXMfPy8uzmk5/8pN2kXJ2UpNWrV9vN7t277eZHP/qR3YwbN85upLSvU29vr91cuHDBbj766CO7yc1N+1vsa1/7mt1s2LDBboqKiuwm5TpoZWWl3UjSxIkT7aa6utpufv/739vN9evX7UaS5s+fbzddXV1JH6s/PCkAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkPVBvCNHjthvnnKw6fTp03YjSdOnT7ebp59+2m4++9nP2s2lS5fsZtasWXYjSS+//LLdDB482G4mTZpkN6kH0DZv3mw3Kcf3Uv6b6uvr7ea1116zG0nat2+f3aT8DA4YMMBu7rrrLrtJPX75l7/8xW6GDx9uNyk/6+PHj7cbSRo1apTdnDx5Mulj9YcnBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABCyPoh35swZ+82bmprs5tlnn7UbSaqpqbGbO++8025uuukmu8nLy7Ob5557zm4kadGiRXazatUqu5k5c6bd7N+/324k6bvf/a7dvPXWW3YzcuRIu2ltbbWb+fPn242UdhBv4sSJdtPT02M3jzzyiN0sWLDAbiRp+/btdpPy++Htt9+2m4aGBruR0n6vHDt2LOlj9YcnBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABByMplMJpsX7tixw37zlMNVp0+fthtJmjJlit289NJLdjNo0CC7uf/+++2mpKTEbiQpyy/nf/Hzn//cbvLz8+2mrq7ObiTpww8/tJsbb7zRboYPH243Kf+21MOAHR0ddtPd3W03BQUFdlNdXW03fX19diNJo0aNspuUg54jRoywm5aWFruRpMbGRrspLCy0m2XLlvX7Gp4UAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAAAh6yupDzzwgP3mp06dspuvf/3rdiNJW7ZssZt58+bZzRtvvGE3KZdLKyoq7EaSfvGLX9jNrbfeajfjx4+3m56eHruRpPLycrvp6uqymytXrthNVVWV3Tz66KN2I0nTpk2zmxUrVtjNT37yE7tpaGiwm9SLyHl5eXYzY8YMu3n44YftZvHixXYjpV3bPXv2rN28+OKL/b6GJwUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQBmb7wtLSUvvNjx49ajfDhg2zG0mqq6uzm+eee85uFixYYDcph7/y8/PtRpJGjBhhN21tbXZTU1NjN++8847dSGnH1ubMmWM327dvt5urV6/aTcqBP0mqr6+3m/3799vNpk2b7KasrMxuTpw4YTeStGTJErvZuHGj3SxatMhuUg9ZFhUV2U1u7r
/mb3qeFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEDI+iDetWvX7De///7
2024-07-30 00:44:16 +08:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"img = get_noise(28, 28)\n",
"fig, ax = plt.subplots()\n",
"ax.imshow(img, cmap='gray')\n",
"ax.axis('off')\n",
"plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 9,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Discriminator building block: Linear -> LeakyReLU(0.2) -> Dropout(0.3).\"\"\"\n",
"    layers = [\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.LeakyReLU(negative_slope=0.2),\n",
"        nn.Dropout(0.3),\n",
"    ]\n",
"    return nn.Sequential(*layers)"
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 10,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"    \"\"\"MLP discriminator scoring flattened images.\n",
"\n",
"    Three DisBlocks narrow im_dim -> 4*hidden_dim -> 2*hidden_dim -> hidden_dim,\n",
"    then Linear(hidden_dim, 1) and a Sigmoid. Because of the final Sigmoid the\n",
"    outputs are probabilities in (0, 1), not logits.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,im_dim:int = 784, hidden_dim:int = 256):\n",
"        super().__init__()\n",
"        widths = [im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]\n",
"        blocks = [DisBlock(d_in, d_out) for d_in, d_out in zip(widths, widths[1:])]\n",
"        self.disc = nn.Sequential(*blocks, nn.Linear(widths[-1], 1), nn.Sigmoid())\n",
"\n",
"    def __call__(self, noise):\n",
"        # despite its name, `noise` is the batch of (real or generated) images to score;\n",
"        # its last axis must have size im_dim\n",
"        return self.disc(noise)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 11,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=784, output_dims=1024, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=1024, output_dims=512, bias=True)\n",
2024-07-30 07:37:09 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-30 07:37:09 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-31 00:50:02 +08:00
" (layers.0): Linear(input_dims=512, output_dims=256, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
2024-07-30 18:21:38 +08:00
" (layers.2): Dropout(p=0.30000000000000004)\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-31 00:50:02 +08:00
" (layers.3): Linear(input_dims=256, output_dims=1, bias=True)\n",
2024-07-30 18:21:38 +08:00
" (layers.4): Sigmoid()\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-07-30 21:59:35 +08:00
"execution_count": 11,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"disc = Discriminator()\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
2024-07-30 00:44:16 +08:00
"cell_type": "markdown",
2024-07-26 21:07:40 +08:00
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### Losses"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"#### Discriminator Loss"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 12,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
"    \"\"\"Discriminator BCE loss: real images are labelled 1, generated images 0.\n",
"\n",
"    `disc` ends in `nn.Sigmoid()`, so its outputs are already probabilities and\n",
"    must be scored with `with_logits=False`. Treating them as logits applies a\n",
"    second sigmoid, which squeezes predictions into (0.5, 0.73) and means this\n",
"    loss can never drop below about 0.50.\n",
"\n",
"    Args:\n",
"        gen: generator used to produce the fake images.\n",
"        disc: discriminator being trained.\n",
"        real: batch of real flattened images.\n",
"        num_images: number of fake images to generate.\n",
"        z_dim: noise dimension expected by `gen`.\n",
"\n",
"    Returns:\n",
"        Scalar average of the mean fake-image loss and the mean real-image loss.\n",
"    \"\"\"\n",
"    noise = get_noise(num_images, z_dim)\n",
"    fake_images = gen(noise)\n",
"\n",
"    fake_disc = disc(fake_images)\n",
"    fake_labels = mx.zeros((fake_images.shape[0], 1))\n",
"    fake_loss = mx.mean(nn.losses.binary_cross_entropy(fake_disc, fake_labels, with_logits=False))\n",
"\n",
"    real_disc = disc(real)\n",
"    real_labels = mx.ones((real.shape[0], 1))\n",
"    real_loss = mx.mean(nn.losses.binary_cross_entropy(real_disc, real_labels, with_logits=False))\n",
"\n",
"    return (fake_loss + real_loss) / 2.0"
]
},
2024-07-30 00:44:16 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generator Loss"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 13,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def gen_loss(gen, disc, num_images, z_dim):\n",
"    \"\"\"Generator BCE loss: the discriminator's scores on fresh fakes vs. label 1.\n",
"\n",
"    `disc` ends in `nn.Sigmoid()`, so its scores are probabilities and must be\n",
"    scored with `with_logits=False`. Treating them as logits applies a second\n",
"    sigmoid and confines this loss to roughly (0.31, 0.69), hiding whether the\n",
"    generator is actually improving.\n",
"\n",
"    Args:\n",
"        gen: generator being trained.\n",
"        disc: discriminator used to score the fakes.\n",
"        num_images: number of fake images to generate.\n",
"        z_dim: noise dimension expected by `gen`.\n",
"\n",
"    Returns:\n",
"        Scalar mean loss over the generated samples.\n",
"    \"\"\"\n",
"    noise = get_noise(num_images, z_dim)\n",
"    fake_images = gen(noise)\n",
"    fake_disc = disc(fake_images)\n",
"\n",
"    # the generator wants its fakes classified as real\n",
"    fake_labels = mx.ones((fake_images.shape[0], 1))\n",
"    loss = nn.losses.binary_cross_entropy(fake_disc, fake_labels, with_logits=False)\n",
"\n",
"    return mx.mean(loss)"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 14,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"# Load MNIST, keeping only the training images and labels; the other returned arrays are discarded\n",
"train_images, train_labels, *_ = (np.array(split) for split in mnist.mnist())"
2024-07-30 00:44:16 +08:00
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 15,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# Rescale pixels to [-1, 1] to match the generator's tanh output\n",
"# (assumes mnist.mnist() returns pixels in [0, 1])\n",
"train_images = 2.0 * train_images - 1.0"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 16,
2024-07-29 06:24:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-07-31 00:50:02 +08:00
"<matplotlib.image.AxesImage at 0x117457fd0>"
2024-07-29 06:24:50 +08:00
]
},
2024-07-30 21:59:35 +08:00
"execution_count": 16,
2024-07-29 06:24:50 +08:00
"metadata": {},
"output_type": "execute_result"
2024-07-30 00:44:16 +08:00
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaI0lEQVR4nO3df2jU9x3H8dfVH1d1lytBk7vUmGVF202dpWrVYP3R1cxApf4oWMtGZEPa+YOJ/cGsDNNBjdgpRdI6V0amW239Y9a6KdUMTXRkijpdRYtYjDOdCcFM72LUSMxnf4hHz1j1e975vkueD/iCufu+vY/ffuvTby75xueccwIAwMBD1gsAAHRfRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjpab2AW3V0dOjcuXMKBALy+XzWywEAeOScU0tLi/Ly8vTQQ3e+1km7CJ07d075+fnWywAA3Kf6+noNHDjwjvuk3afjAoGA9RIAAElwL3+fpyxCH3zwgQoLC/Xwww9r5MiR2rdv3z3N8Sk4AOga7uXv85REaPPmzVq8eLGWLVumI0eO6JlnnlFJSYnOnj2bipcDAGQoXyruoj1mzBg99dRTWrduXeyx73//+5o+fbrKy8vvOBuNRhUMBpO9JADAAxaJRJSVlXXHfZJ+JXTt2jUdPnxYxcXFcY8XFxertra20/5tbW2KRqNxGwCge0h6hM6fP6/r168rNzc37vHc3Fw1NjZ22r+8vFzBYDC28ZVxANB9pOwLE259Q8o5d9s3qZYuXapIJBLb6uvrU7UkAECaSfr3CfXv3189evTodNXT1NTU6epIkvx+v/x+f7KXAQDIAEm/Eurdu7dGjhypqqqquMerqqpUVFSU7JcDAGSwlNwxYcmSJfrpT3+qUaNGady4cfr973+vs2fP6tVXX03FywEAMlRKIjR79mw1NzfrN7/5jRoaGjRs2DDt2LFDBQUFqXg5AECGSsn3Cd0Pvk8IALoGk+8TAgDgXhEhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmelovAEgnPXr08DwTDAZTsJLkWLhwYUJzffv29Tzz+OOPe55ZsGCB55nf/va3nmfmzJnjeUaSrl696nlm5cqVnmfefvttzzNdBVdCAAAzRAgAYCbpESorK5PP54vbQqFQsl8GANAFpOQ9oaFDh+rvf/977ONEPs8OAOj6UhKhnj17cvUDALirlLwndOrUKeXl5amwsFAvvfSSTp8+/a37trW1KRqNxm0AgO4h6REaM2aMNm7cqJ07d+rDDz9UY2OjioqK1NzcfNv9y8vLFQwGY1t+fn6ylwQASFNJj1BJSYlmzZql4cOH67nnntP27dslSRs2bLjt/kuXLlUkEolt9fX1yV4SACBNpfybVfv166fhw4fr1KlTt33e7/fL7/enehkAgDSU8u8Tamtr05dffqlwOJzqlwIAZJikR+j1119XTU2N6urqdODAAb344ouKRqMqLS1N9ksBADJc0j8d9/XXX2vOnDk6f/68BgwYoLFjx2r//v0qKChI9ksBADJc0iP0ySefJPu3RJoaNGiQ55nevXt7nikqKvI8M378eM8zkvTII494npk1a1ZCr9XVfP31155n1q5d63lmxowZnmdaWlo8z0jSv//9b88zNTU1Cb1Wd8W94wAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQ
DMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAMz7nnLNexDdFo1EFg0HrZXQrTz75ZEJzu3fv9jzDf9vM0NHR4XnmZz/7meeZS5cueZ5JRENDQ0JzFy5c8Dxz8uTJhF6rK4pEIsrKyrrjPlwJAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwExP6wXA3tmzZxOaa25u9jzDXbRvOHDggOeZixcvep6ZPHmy5xlJunbtmueZP/3pTwm9Fro3roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT63//+l9DcG2+84Xnm+eef9zxz5MgRzzNr1671PJOoo0ePep6ZMmWK55nW1lbPM0OHDvU8I0m//OUvE5oDvOJKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAw43POOetFfFM0GlUwGLReBlIkKyvL80xLS4vnmfXr13uekaSf//znnmd+8pOfeJ75+OOPPc8AmSYSidz1/3muhAAAZogQAMCM5wjt3btX06ZNU15ennw+n7Zu3Rr3vHNOZWVlysvLU58+fTRp0iQdP348WesFAHQhniPU2tqqESNGqKKi4rbPr1q1SmvWrFFFRYUOHjyoUCikKVOmJPR5fQBA1+b5J6uWlJSopKTkts855/Tee+9p2bJlmjlzpiRpw4YNys3N1aZNm/TKK6/c32oBAF1KUt8TqqurU2Njo4qLi2OP+f1+TZw4UbW1tbedaWtrUzQajdsAAN1DUiPU2NgoScrNzY17PDc3N/bcrcrLyxUMBmNbfn5+MpcEAEhjKfnqOJ/PF/exc67TYzctXbpUkUgkttXX16diSQCANOT5PaE7CYVCkm5cEYXD4djjTU1Nna6ObvL7/fL7/clcBgAgQyT1SqiwsFChUEhVVVWxx65du6aamhoVFRUl86UAAF2A5yuhS5cu6auvvop9XFdXp6NHjyo7O1uDBg3S4sWLtWLFCg0ePFiDBw/WihUr1LdvX7388stJXTgAIPN5jtChQ4c0efLk2MdLliyRJJWWluqPf/yj3nzzTV25ckXz58/XhQsXNGbMGO3atUuBQCB5qwYAdAncwBRd0rvvvpvQ3M1/VHlRU1Pjeea5557zPNPR0eF5BrDEDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNrqkfv36JTT317/+1fPMxIkTPc+UlJR4ntm1a5fnGcASd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzAFvuGxxx7zPPOvf/3L88zFixc9z+zZs8fzzKFDhzzPSNL777/veSbN/ipBGuAGpgCAtEaEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpsB9mjFjhueZyspKzzOBQMDzTKLeeustzzMbN270PNPQ0OB5BpmDG5gCANIaEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gCBoYNG+Z5Zs2aNZ5nfvSjH3meSdT69es9z7zzzjueZ/773/96noENbmAKAEhrRAgAYMZzhPbu3atp06YpLy9PPp9PW7dujXt+7ty58vl8cdvYsWOTtV4AQBfiOUKtra0aMWKEKioqvnWfqVOnqqGhIbbt2LHjvhYJAOiaenodKCkpUUlJyR338fv9CoVCCS8KANA9pOQ9oerqauXk5GjIkCGaN2+empqavnXftrY2RaPRuA0A0D0kPUIlJSX66KOPtHv3bq1evVoHDx7Us88+q7
a2ttvuX15ermAwGNvy8/OTvSQAQJry/Om4u5k9e3bs18OGDdOoUaNUUFCg7du3a+bMmZ32X7p0qZYsWRL7OBqNEiI
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
2024-07-29 06:24:50 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"first_image = train_images[0].reshape(28, 28)\n",
"plt.imshow(first_image, cmap='gray')"
2024-07-29 06:24:50 +08:00
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 17,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"def batch_iterate(batch_size: int, ipt: np.ndarray):\n",
"    \"\"\"Yield `ipt` in shuffled mini-batches of up to `batch_size` rows.\n",
"\n",
"    A fresh permutation is drawn on each call, so one full pass visits every\n",
"    row exactly once in random order. The last batch is shorter when\n",
"    len(ipt) is not a multiple of batch_size. `ipt` must support NumPy fancy\n",
"    indexing with an integer index array.\n",
"    \"\"\"\n",
"    perm = np.random.permutation(len(ipt))\n",
"    for start in range(0, len(ipt), batch_size):\n",
"        ids = perm[start : start + batch_size]\n",
"        yield ipt[ids]"
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 18,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"def show_images(imgs, num_imgs: int = 25):\n",
"    \"\"\"Plot up to `num_imgs` flattened 28x28 images from `imgs` in a near-square grid.\n",
"\n",
"    Args:\n",
"        imgs: array (NumPy or MLX) with a leading batch axis; each row must\n",
"            reshape to 28x28.\n",
"        num_imgs: maximum number of images to draw. With the default of 25 and\n",
"            at least 25 rows this is a 5x5 grid.\n",
"\n",
"    Does nothing when there is nothing to draw. Batches smaller than\n",
"    `num_imgs` are drawn in full instead of raising IndexError.\n",
"    \"\"\"\n",
"    n = min(num_imgs, imgs.shape[0])\n",
"    if n <= 0:\n",
"        return\n",
"    ncols = int(np.ceil(np.sqrt(n)))\n",
"    nrows = int(np.ceil(n / ncols))\n",
"    # squeeze=False keeps `axes` 2D even for a single row/column\n",
"    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols, nrows), squeeze=False)\n",
"    for i, ax in enumerate(axes.flat):\n",
"        if i < n:\n",
"            img = mx.array(imgs[i]).reshape(28, 28)\n",
"            ax.imshow(img, cmap='gray')\n",
"        ax.axis('off')\n",
"    plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### Show the first batch of training images"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 19,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-31 00:50:02 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADUW0lEQVR4nOy9V3CcV5qf/3TOOaIbjQzmKJHiSBppFCZodnZndr01413v1rocrnztS1/7wr7xre2yXXZ5y+ta22WPJu7OaDSKpCRGMIFERid0zjn8L/Q/Rw0SDCIBAmh+TxWKEtFofn1wvu89b/q9qn6/30dBQUFBQWEbUe/2BSgoKCgoDB+KcVFQUFBQ2HYU46KgoKCgsO0oxkVBQUFBYdtRjIuCgoKCwrajGBcFBQUFhW1HMS4KCgoKCtuOYlwUFBQUFLYdxbgoKCgoKGw72sd9oUql2snr2Fc8iaiBsn5foazf0/GkohrKGn6FsgefjsdZP8VzUVBQUFDYdhTjoqCgoKCw7SjGRUFBQUFh21GMi4KCgoLCtqMYFwUFBQWFbUcxLgoKCgoK285jlyIrPD+IkktljtyDubcsVVkrhaflYaXO+3F/KcZFAQCn00kgEMBisTA+Po7RaGRubo5bt27R6/X25ebebrRaLVqtlkAgwNmzZzGbzRQKBWq1GtFolLt379Ltdnf7MhX2IS6Xi+PHj2O32wkGgzidTvm9crnMF198wcbGBqVSiVKptHsX+jVQjIsCAB6Ph+PHjxMIBHjzzTfxer385//8n7lz5w6A8tAEdDodJpOJqakp/uqv/gq/38/i4iLpdJrz58+zsrKirJPCE+HxePjOd77D+Pg4L774IlNTU/J7sViMf/fv/h1XrlxhbW1NMS6Pg0ql2vSlVqsxm80YDAb0ej0mkwm1Wo3FYkGrfbxL7Xa7tFotut0upVKJcrlMp9OhXq8rJ/BHoFar0el02O12XC4XTqcTh8NBo9GgWq3S6/V2+xKfGSqVCo1Gs2l/BgIBwuEwMzMzeL1eHA4HLpeLTqeDyWRSOrgVnpher0ez2aRWq9Fqteh0Ouh0OvR6PWazmUAgQCQSoVAooFar98W9uKvGRavVotPpUKvV6PV69Ho9Bw8eJBQK4ff7mZqawmazMTMzg8PhkDc5fBmD3MpQNJtN4vG4dCWvXr1KPp9neXmZRqNBp9PZF7+Y3UKr1RIMBhkfH2dmZoaDBw+Sy+VYXFyk2Wzu9uU9M/R6PRaLBbVaLcNh3/nOd3jnnXfw+XwcPHgQvV5Pv9/HZrNx8+ZN1GqlPkbhyajX6ywsLFCpVPD7/ZhMJux2Oz6fD4vFwrlz5xgbG6PVanH37l06nc6e95KfmXEZPAFqNBrUajUGgwGDwYBWq0Wv12M0GvH7/YRCIUZGRpiamsLhcHDkyBFcLtdjGZdGo4HNZqNUKpFKpUgkEqhUKlKpFCqVinq9TqvVelYfe0fY6oS8XR6Z+L2YTCbMZjMWi4VarTYUp3K1Wr1p/z0Ms9mM3W6XhkWr1RIKhZiensZut2OxWNBoNHKd9Hr9M/oU+wMRidBoNJv+vtfryQiCEkX4ik6nQ7FYxGAwUCwWKZfL6HQ6+v0+arUaj8cDfJkbNRqNtFotGo3Gnl7DZ2ZcLBYLXq8Xi8XC7OysDLu43W6MRiMulwu9Xo/H48Fms2GxWHA6ndItFIjFfNCiarVa/H4/TqeTN954g8OHD5NOp5mbmyObzfLpp5+yvLxMt9ul0+k8k8++najVaux2O3q9nm63S6/Xo9Pp7EjYSrz3Xj8hPQ4WiwWPx4PZbObAgQObEqZbEQwGmZ2dRafTodFo0Gg0TE9PEw6H0ev1aLXa+0IZe/lGf5aYTCb0ej2BQIDp6Wl0Oh3w5T0bi8WIxWI0m03K5fJQ7K3toFwuc/PmTdbX1zEYDOTzeQ4fPszIyAgGg4HR0VE8Hg8vvfQShUKBjY0NLl++TLVa3e1LfyDPzLgYjU
YCgQBut5uXXnqJ0dFRgsEgIyMjWCwWgsHgI09/9968W93MGo0Gh8MBgN/vByCVSuH3+0kmk6ysrBCLxWRcc78hclAmk0k++JvNpswpbSf9fp9er0e32933D06j0YjP58PtdnPmzBlGR0cf+vqpqSlefPFFjEajzL3A5j3X6/Vot9s0m8197w1vFyqVCoPBgNlsJhQK8eKLL2IwGIAv18tgMFCtVuWXYly+pNFosLq6il6vx+l00m63sdlsdLtduXd7vR4HDx4km81y9+5dbt68+XwbF3Hq83q9HD58GI/Hw/T0tCy3czgcGAyGR4Yp+v2+3JAPMirilGkymTYVABiNRoLBIBqNhkAggMfjoVQq7cvTptls5syZM4yNjUnjkkqlOH/+PMViURYzfF1EGEOEjvr9Po1Gg3w+T61W2/d5KlGo4Ha7iUQiTExMPPT1brebbrdLo9GQYRwRxhWItV9YWCCTyez7NdoONBoNY2NjsvDh4MGDGAwGNBqNDPGYTCbS6TS9Xo96vY7FYkGv18swmkqlQqvVbgqpVatVUqkU7XZbeuxbIcJt+zXs1uv1yOfzRKNRcrncps+pUqnk87PRaGCxWKhWq7Tb7T2593bcuIgbcmJigu985zv4/X4OHTqE2+2Wse/BXMqD6Pf7ZDIZ1tbWttw4JpMJm80mDcmgcbFarRw6dIhAIMDs7CzJZJJoNEqhUNh3G9DhcPDnf/7nvP7663S7XbrdLlevXpVeWbFYpF6vf+33FYnrwVN6qVQikUjQaDT2/QlT7IvR0VFOnjzJ4cOHH/r6drstC0AqlQqdTgev17vJuLTbbebn5/n000+VHpf/H51OxwsvvMCrr75KJBLhxIkTGAwGdDodKpWKw4cPs7a2xt27d6nVahSLRUZHR3G73Wi1WpmDtdlsMpwGsL6+zscff0ypVKLRaNBut7f894XhESHd/Ua322V1dZWNjQ1mZmY27SmVSsX4+Dhutxu9Xs9Pf/pTeeDei8U2O2pcVCoVOp0Oo9GI1WrF7Xbjcrkwm80YjcZNrxWnw2azuSncIP6+1+sRjUalcYHNIQqz2YzNZsNqtWK1WjdV+YhqNIPBgNFoxGQyYTAY9m2Sut/vS3dZ5KtEnqpWqz2RcdFqtVgsFsxms/QiRUhsGB6a/X6fVqtFs9mk0+nQ6XSkp9bpdDZ5KP1+n3q9TqlUot1uUyqV6HQ6Mic4uG/a7Tb1ep12u73vDipPgzgUintK3KPi/hLJaNEK4PV6MRqNsgKqVCoRiURwOByEw2FcLpcs7NFqtVit1k3GBWB8fJxyuSyN/r2I50en06FWq5HNZvfd3u33+3Q6HRmBuHdP6XQ6LBYLRqNRPtv26nNsx42Ly+UiGAwyOTkpE6kWi2XT60T4od1us7q6SjqdJpFIMD8/T6PRkN9LJBIkEoktXUCj0ShzNz/5yU+YnJzE7/fj8/nka9RqNS6Xi1AoRLFY3LO/lIdRrVZ57733iMVinD17lnPnzuF0Ojlw4AAGg0E+DL8uXq+XkydPEgwGsVqtO3Dlu0sul5Me3sGDBymXy7IaLpfLce3aNcrlMs1mk3a7TbVaJZfL0Wq1KJVKdLtd/vIv/5KJiYl9uW+2E9GPptfrOXz4MEePHqXb7VIul4Ev78X19XWWlpb4zW9+g9fr5cc//jFTU1N4PB6sVqtM9nc6HYxGo/RshMG/t6Kv0Wjwve99T7YSbHXA7Pf7pFIpstks169f52/+5m8oFovPdnF2GNG6IXKB91bj7SV23LgYjUZsNht2ux2Hw4HNZkOlUm0yEKLxsV6vk8lkiMViLC0tcenSpU0nyHQ6TSqV2vLfEk2XkUiE1157DafTid1uf+D17FfPpdVqSe9tcnJSlg673W4qlcp9HuHjIpKGHo9HnjqHiWazSSaTASCRSGC32+VXMpnk5s2b5PN56vU6zWaTSqVCNpul1WpRLBbp9Xq89dZbQ7cuT4LYcyLUeOjQIdrtNtlslna7jV
arpVqtks1mWV9fJxAIUKlUAKS3LTyWJ2Hwvr23wGJ9fZ1kMkmj0dj35eGDBlR8ZmF8RQj7ufVc+v0++Xyefr/PpUu
2024-07-26 21:07:40 +08:00
"text/plain": [
2024-07-29 06:24:50 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-26 21:07:40 +08:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
2024-07-30 00:44:16 +08:00
"first_batch = next(batch_iterate(25, train_images), None)\n",
"if first_batch is not None:\n",
"    show_images(first_batch)"
]
},
2024-07-30 07:06:52 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training Cycle"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 20,
2024-07-27 06:09:51 +08:00
"metadata": {},
2024-07-29 06:24:50 +08:00
"outputs": [],
2024-07-27 06:09:51 +08:00
"source": [
2024-07-30 18:21:38 +08:00
"# Hyperparameters\n",
"# NOTE(review): 2e-6 is ~100x smaller than the 2e-4 usually paired with Adam(beta1=0.5)\n",
"# for GANs; confirm this learning rate is intended.\n",
"lr = 2e-6\n",
"z_dim = 128\n",
"\n",
"# Generator: build, force parameter materialization, attach its optimizer\n",
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
"gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])\n",
"\n",
"# Discriminator: same recipe, independent optimizer state\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
"disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])"
2024-07-29 06:30:08 +08:00
]
},
2024-07-27 06:09:51 +08:00
{
"cell_type": "code",
2024-07-30 21:59:35 +08:00
"execution_count": 21,
2024-07-27 06:09:51 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-30 18:45:09 +08:00
" 0%| | 0/200 [00:00<?, ?it/s]"
2024-07-30 18:24:53 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-07-31 00:50:02 +08:00
"Epoch: 0, iteration: 468, Discriminator Loss:array(0.657422, dtype=float32), Generator Loss: array(0.469956, dtype=float32)\n"
2024-07-30 18:24:53 +08:00
]
},
{
"data": {
2024-07-31 00:50:02 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9yY/k2XXejT8xz3NERuSclTVXz90kmiIpvSYgAbZheGEY9sbeeOWN/wEDhpf2ykvDOy8MGF57J0OSIVMUJTbZJLuru2uunMeY5zneRb6fkzeSLSkrqo0f8ENeoMBmVWbE93vvuWd4znPO8cxms5lu1s26WTfrZt2s73B5/3/9ADfrZt2sm3Wz/v9v3RiXm3WzbtbNulnf+boxLjfrZt2sm3WzvvN1Y1xu1s26WTfrZn3n68a43KybdbNu1s36zteNcblZN+tm3ayb9Z2vG+Nys27WzbpZN+s7XzfG5WbdrJt1s27Wd75ujMvNulk362bdrO98+a/7gx9++KE8Ho/W19e1tLSkwWCgZrMpSYpGowoGg8rlciqVSup0Onr58qX6/b7S6bQSiYR6vZ6azabG47Ha7bZGo5E8Ho88Ho9CoZDy+bz8fr+63a6Gw6G8Xq+8Xq/6/b6eP3+uRqMhn88nn88nv9+vSCQiSep2uxoMBlpdXdX9+/c1nU51dnamfr8vv98vv9+veDyu1dVV+f1+jUYjTSYTjUYjDYdDe9Zer6eHDx/q9u3bqtfrevXqlfx+v370ox9pa2tLOzs7evr0qSaTif77f//vb7zR//pf/2tJUr/f13A41HQ61WQy0Ww202g00nQ6VTQaVSwW03g8Vr1e13A4VK/X02AwUDweV6FQkMfjUbfb1WQy0fr6ujY3N9Vut/XixQuNx2M9ePBAa2trevXqlX71q1/J6/Xqzp07SqVSqlQqKpfLmk6nGg6HkqRUKqVIJKJms6lyuaxgMKiNjQ1Fo1H1ej37/kajoel0qkQioXA4rEajoUqlYs/o8Xj0k5/8RJ9++qn29/f1s5/9TJL0wx/+UBsbG/rqq6/0y1/+UpPJRH/+53/+xvv3n/7Tf5IkjcdjTSYTTSYTjcdj0WBiNpup3++r1+spGo1qfX1dPp9Pf/mXf6nHjx9rfX1dH3zwgSaTiX7zm9+oUqkol8spl8spkUhoY2NDPp9PX3/9tfb39zUYDNTr9RSLxfTxxx8rl8tpNBppNBpJknw+n8bjsXZ3d1Uul+25IpGItre3FY/H9ezZM718+VJra2v64Q9/qGAwqKdPn+r8/FypVEr5fF7D4VDHx8fq9/vK5XJKp9Pqdruq1WoKBAJ69OiRisWims2m6vW6ZrOZ/uN//I9vvH+S9M//+T+X1+tVKpVSIpGwe8Q9nM1mev36tZ4/f65cLqff+73fUywW0+eff66XL18qEAgoEoloNpup2+1qNBppa2tLt27d0mAw0NnZmQaDgZ3PdDrVeDxWMBjUysqKYrGYwuGwQqGQms2m9vf37cyGw6GWl5e1vb0tSfb5JycnqlQqisViWl1dVSAQUDAYlNfrVafTUbPZlMfjkd9/ocpOTk50fn6uUCikZDKpbDarf/pP/6nee+89/dVf/ZX+5E/+ROPxWP/1v/7XN96/f/Nv/o0kKRKJKBQKqVaraXd3V16vV9vb28pkMqrX66rVagqFQlpaWlIwGFSr1VK32zWd5vF4FIvF5Pf79erVK7148UKBQEDpdFrhcFhra2vK5/OqVCo6ODgw2ZKkpaUlFQoFjcdjdTod26NGo6F4PK50Oi2fz6dQKCSv16vxeKzxeKzhcKhWq6XRaKRGo6Fut6twOKx4PG5ygL4MBAKm330+n+7cuaNcLie3mcu/+3f/7u/cr2sbl48//nju/0ejUaVSKY3HYx0eHqrdbqvX62k8HmswGJgCDAQCisfj6vf7qlarmkwmJmCFQsEu2Pn5ubrdrhqNhnq9nikPv9+vW7duKRQK2c
v3ej1VKhW76Ag9hxePxxUMBm1j+ePxeOxn2HCMDQan3+9rNpspkUhIkl6+fKmDgwMNh0PNZjN5vYsFe9FoVJLsuSSZgWk2mxoMBmZoJpOJXRYEqNPpyOPxSJJGo5Fms5l8Pp8KhYIkqdfrqd1ua2dnR41GQ8PhUCsrK/J4PJpMJmo0GgoGg1pdXdVkMlG/3597vkgkoqWlpTlBCwQCmkwmCgaDth/9ft/2qdFo2L55PB6dnJzo+fPnarVaisVikqRarSbpwiisrKxoOp0utH/JZFKz2UydTkf9fl+TyUSdTkeTycTOZDQaaTweq1qtan9/X+PxWP1+X8ViUalUypyLarWqw8NDra6u6v3339dsNjMj6vF4lEgklE6nTZF5PB7bL5Ta8fGxxuOxMpmMCoWCjo+P9fLlS7u0OD+rq6tKp9NmCAOBgGKxmILBoGazmcbjsVqtljqdjmKxmMlEr9dTv9/XycmJer2e/f3brGKxKI/Ho0KhYM/Ena1Wq+r3+/L5fLp165YikYg6nY4Gg4HG47F8Pp/y+by2t7c1m820u7urZrOpQqGgjY0N9ft9eb1eDQYDO59+v69Op6NgMKh8Pq90Oq1YLKZ4PK6zszMdHR3J4/FoeXlZkUhEw+FQOzs78vl8SqfT5nRGo1H5/X75fD5JMseo1+uZEeLeLi8v69GjRzo+Ptbnn3+u/f19bW9vq9/v69WrV+p0OgvL4OHhoSQpnU4rmUxqMBjM6ZN+v6/xeGzOW71el8/nM4Pr8/kUDoc1HA715MkT1Wo15fN53bt3T7PZzOSPu4V8S1I4HDZHOR6PazgcajAYaDqdKhaLyePxaDwe6/T0VF6vV8Fg0IwM/51OpzWdTuX3+xWLxcxRH4/HOj4+1nQ61fr6unK5nDwej5aWliRd3KujoyOFw2GFw+Fr68BrG5cHDx5oOp2qXC6rXq8rHA4rn8+r1+vpxYsXOjk5MQU+nU7tMrreUaPR0Gw2UywWUywW0/Lysu7evatqtaqTkxPzRIhw+v2+4vG43n33XZVKJfl8PgWDQVWrVQ2HQ3W7Xc1mMw2HQ/NmPB6PKUeiIC6x1+tVIBCQz+czQUAhudGMJMXjcY1GI+3v76vT6SidTiuTyZiAv+kKhUJz/4sXgJBgMFzDwUJgMUqz2Uwej8cEBoXbarU0nU5Vq9WUTCZVKpUkSc1mU/1+X5lMRplMxs4HBT0YDBQOhxWNRk0gfT6feZ/BYFDxeFyz2cy87MFgoHa7bRfV6/WaUp/NZiaEnU7HIlEM4SIrFotpMploMBhIkilgV1HMZjMz1l999ZV6vZ42NzdVLBYVj8fl9Xo1nU5Vr9fNu71z547a7ba++eYbdTodSReGNh6PK5PJmHHu9/sKBALm1WG8VlZWtLW1pW63a/J4enqqcDisVCqlQqGgaDRq0anP51MkElEgELCz7XQ6arfbJqu852w2U7lcVrfbVSgUUigUMgdjkZXNZuX1elUsFpXP59VqtXR6eqrRaKRWq6Vms6l0Om1OCXdkPB6bZ33v3j1Np1OTtWw2q1KppG63q36/r36/bwofxRoKhZROp5XP55VMJpVKpcwp8Hq9yufzyufz2t/f1+vXr82ZwYGNx+OmsDGy7BHGrFwuazKZ6OHDh/roo480m810dHSk0WikJ0+eSJLOz8/NCC2yzs/PJcmU8nA4tKgP/cFdnc1marVac+fF+w6HQ3399dfa29vT7//+7+uTTz7RYDDQ8fGxhsOhhsOh2u227b/X6zUjEYlEFI1G7f4jTz6fT7VaTdVqVZLsDqfTaXMUw+GwJNnZImvj8VjlclmDwUClUkmxWEyhUEjxeFzj8Vg7Ozuq1+tKJpPy+/3XlsFrG5d6vS5JtqH9ft8EE8GLx+Py+XyazWa2wcALkpTL5TSdTi1k6/V6KpfLqtVqpiCJPPCQ8Yb4DI/Ho+FwqHw+r36/r0qlYvCcq3yx3h6PR+Fw2MJ+nt81NqlUSuFwWD6fT71ez9
5nPB6r2+2a0XTD0zddu7u78vl8KpVKBrGgzBBISWYg8dQSiYR8Pp+8Xq/tLR5Lp9PR/v6+ms2mstmsKR8Ebzqdyuv
2024-07-30 18:24:53 +08:00
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-31 00:50:02 +08:00
" 25%|██▌ | 50/200 [05:08<15:11, 6.07s/it]"
2024-07-30 18:24:53 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-07-31 00:50:02 +08:00
"Epoch: 50, iteration: 468, Discriminator Loss:array(0.505264, dtype=float32), Generator Loss: array(0.691265, dtype=float32)\n"
2024-07-30 18:24:53 +08:00
]
},
{
"data": {
2024-07-31 00:50:02 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz914+lV3YejD8n55xP5diZ3Uwz1JCcIAkaSLJlW9CVYTgABixAF/5XfOsbYwwbFgwbsjXSJ2kGk8fkkGySnbu6qrty1amTc47fRX3PqnUOm1L1qTZ+wA+1gQbZoc55373XXuFZz1rLMBqNRrhcl+tyXa7Ldble4zL+//oBLtflulyX63L9/9+6NC6X63Jdrst1uV77ujQul+tyXa7Ldble+7o0Lpfrcl2uy3W5Xvu6NC6X63Jdrst1uV77ujQul+tyXa7Ldble+7o0Lpfrcl2uy3W5Xvu6NC6X63Jdrst1uV77ujQul+tyXa7Ldble+zKf+x+aT/+pwWAAANjtdvh8PlgsFjidTlgsFhQKBWSzWVitVoRCIZjNZtRqNTSbTQwGA3S7Xfksk8mE+fl5LC0todVq4fDwEK1WC5VKBc1mc+y7uAKBAPx+PwBgNBrBaDRidnYWgUAAx8fH2NzcRK/XQ7/fx3A4xGg0wmg0gtPpRDgcBgAUi0U0Gg2YTCZYLBYYDAZYLBaYzWbcvHkTa2trOD4+xmeffYZOpwOv1wubzYZ+v49erwcAOD4+fuWNNhpP7bjNZoPFYoHVaoXT6YTRaJT3HA6HGAwGMJlMcDgcMJlMcLvdcDgcyOfz2NnZkWcwGo34N//m3+BP//RP8ezZM/yH//AfcHBw8Pfun9/vh8/nQzAYxPr6OrxeL9bX1xGLxbC5uYlPPvkE1WoVu7u7qNfrGAwGGA6H8Hq9mJubg8lkQj6fl7/r9/swGAwwm80wGo1wu91wOp2w2+0IBAIAgEqlgna7jX6/j36/DwDY3t5+5f17//33MRwOcXR0hGw2i0QigZs3b8JkMiGdTqNer8uZm0wm2O12GAwGtNttdDod+Hw+xONxmM1mmM1mGAwGbG1tYWNjA1arFcFgECaTCa1WC91uF6FQCIuLiwgEAnjvvfcQDofxt3/7t/i7v/s7jEYjeWfKjsPhgMfjAQB5T+53LpfDgwcP0O12EQ6H4XK54Ha74ff70Wq1sLOzg0ajAbvdLrLWarXkfQAgHA4jFovBaDTi008/feX9A87uMD+Td8BqtWJ+fh5erxdHR0c4PDyUv+e/17+MRqPc+U6ng263C6PRCJvNBoPBgGazKXfdYDDAYDDAaDTCaDTKHW6328jn8zCZTPjwww+xvr6Op0+f4qOPPkKn08FgMIBuHsLPACD7YjAYYDKZRN9YrVb5961WC8ViEcCp3nA4HOj1euh0OgCAdDr9yvv3R3/0RxgOh9je3sbx8TGMRiNMJhNsNhuWl5cRCASwt7eHFy9ewOFwYGFhAXa7HdlsFuVyGSaTSWRlfX0dfr8fL168wPPnzwEAFosFRqMRVqsVZrMZPp8P0WgUAFCtVtHr9VCtVkXW9R4ZDAa5kzwzg8EAt9sNl8sl7zAYDFCtVtFutzE7O4uVlRX0ej0cHh6i3W5jdXUV8/PzSKfTuHfvHtrtNux2u9wZ6pT9/f1/cL/ObVyoBHk5+cImkwnD4RBWq1UuTK/XQ6lUwnA4xPLyMhKJBDKZjCh/flaj0cDh4SF6vR5arZYoTpPJBKfTCZ/Ph+FwiFKphG63i3a7jWKxCLPZDLvdDqPRiMPDQ6RSKQyHQ0QiEZhMJng8HpjNZhwdHeHk5EQOzGg0IhgMwu12w2QyyeXh+5VKJTx9+lSUS7/fR6PRQKVSgc1mg9PpPO92fXWj1cWm8ev1emJILBaLvLvZbIbNZoPZbMZwOES73ZZ943MbjUY8efIEP/jBD1Aul9HpdOBwONBqtdDpdGCz2eB2uzEcDlGv19Hv92Gz2e
ByudDtdvHs2TOYzWak02n4/X5kMhkcHh5iNBohmUzCYDAgm80im83KudvtdiwvL8NgMCCTyWBvbw8WiwWzs7Ow2+0olUooFApwOByiCMrlMtrt9pjcTLNWV1cxGAxQr9dRKBRQr9fx/PlzGAwG1Ot19Ho9UX4WiwU+nw8mk0mcm+FwCJvNBqPRKPtfLpflTAaDAQCIwmo0GtjZ2YHT6US73Ybb7cbe3h5sNptcXL57KBRCs9lErVZDr9dDoVAQeaXDNBwOx96nXq+j2Wyi3+9jNBrBarViNBqJ8qOi5M/1+3153mmXw+GQz7ZYLOh2u2g0GhgOh8hkMqhUKhgMBohEIuj1eqjVahiNRohEIvB6vWi326jVarJXdIRoyPm8RqNRjK/JZBKZHwwGqNVq6Ha7MJlMcka7u7tIp9MYDAZYXV1Ft9tFsVhEu90W54DnxH03Go1wuVzwer1yR9rtNpxOJ1wuF/r9vtwf6paLdro6OTkBcKorotEout2uyF65XJY7tr6+jn6/j1KpBACIxWJYW1sbc7CbzSaazSbq9brsk9PphNlshtVqFSe00+lgOBzK/vV6PblPdGYajYY4UKFQSGS70+mIg0WjRj3jcDjG9BkNR6fTQb1ex2g0kndstVpitLSB+YfWKxkX4NS62mw2eQgqnuFwiGg0ikQigUqlgnQ6jU6ng2QyibfffhtbW1s4OjpCq9WSh2u1Wmi1WgAw5qVxoyORCPr9PprNpngdnU5HLLvBYEChUECr1UI4HMbc3BxcLheSySTsdjva7TZOTk7kGa1Wq1xaelLD4VB+1Wo15PN52O12hEIhuQy1Wm3MoE2z+HO8aNoLZCSnDTjfke89GAxE8VksFphMJuzs7CCdTosHTQ/DaDTCbrfD7/eLEaPStdvtYtT7/T7S6TRcLheazSaq1SqcTifm5uZEmeRyuTEPMZlMwuv1YjQaYXd3FyaTCbFYDB6PB/V6HZVKBb1eT4x5vV4X74eR2jRrZmYGg8EAu7u7MBqNaLVaODg4+IrsjEYj2Gw2Mc6NRgO1Wg0Gg0G+v9FooNfrifzyM/j//PxSqQSj0YhisQir1YrhcCiXk9FLIpHA/Pw8MpkMGo2GGAEaNB2F6NVut9FsNiXSMpvNEt3RGzYYDPJulMWLLKvVCoPBAJfLBYfDgXq9Lh5xsVgUhR8IBNButyVC9fv9mJubQ6VSAQD0ej0xjHQaRqORyBmjDLPZDIvFMmZcqFRdLheCwSCMRiNSqRQqlQoWFhZw8+ZNMf40fJ1OZ8xD53e63W5RgPv7+2i323A4HPL3NGzUG9zn8yrHyZXP5wEAbrcbgUAA9XpdHDc6FsFgEMlkUnRgu93GlStXcP36dTHORGqq1SqazeaY82axWOBwOETeut2unD8jExpn3vdWq4XBYACHw4H5+Xn0ej1xKrvdrhhnbfAHgwFsNpu8G/+83++LcQkGg+j1eshkMvKdr7J35zYuHo9HlAwfjgrR7XaLl5zL5dBsNuWCVCoVHB0dIZ/PywWj56OVgtFoHAu9O52OCLPL5YLdbhfBBDCmTEajEbrdLqrVKrrdrijnbrcLt9stntVwOJRLwfegkNKjpRBaLJaxX7xgOtp5lbW0tITRaIRSqYRqtQrg9KIwpHe73ajX6yiXyyIcWuEQejCZTPB6vbBYLGJ4eZF5wQFI1AVADNVoNEKlUhGBBU6VHIWYnlU+n0ej0UCr1RJobmlpCU6nE41GAycnJ2i1WgLZGQyGMUPp9Xolwnn+/Dk6nY7s37QX+8WLF+j3+6hUKmNQC9+D+0XDqyEDXqrRaCQRDI0EF5Uf9zgWi4lzQyVBmaHB5ruYzWYMBgM0Gg00Go2xvbBarWPy5nQ6xTjzEtPR4XvxHLVxeR2LBpSf3e125U47nc4xh4aQJ+WdhrLT6Yh8AhClx73g89Nh4/2kYtLOXKPRgNFolH1vt9vIZrMYDAYSgfNn9aL+aLfbcpfosevn19/F6FBHWa+6aD
htNpugKpQHGgB6+iaTCaurqzAYDFhaWhKnm851p9MZgw+5v/xMAGNyYbfbx/bS6XQiFovBYDAIlKsNMg0PALkP1CV
2024-07-30 18:24:53 +08:00
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-31 00:50:02 +08:00
" 26%|██▌ | 52/200 [05:21<15:25, 6.25s/it]"
2024-07-29 00:18:35 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"source": [
"# Training hyper-parameters\n",
"n_epochs = 200\n",
"batch_size = 128\n",
"log_every = 50  # number of epochs between progress reports / sample grids\n",
"\n",
"# nn.value_and_grad(model, fn) wraps fn so that calling it returns\n",
"# (loss, gradients w.r.t. the model's trainable parameters).\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"for epoch in tqdm(range(n_epochs)):\n",
"\n",
"    for real in batch_iterate(batch_size, train_images):\n",
"\n",
"        # Discriminator step: loss and gradients w.r.t. the discriminator's parameters\n",
"        D_loss, D_grads = D_loss_grad(gen, disc, mx.array(real), batch_size, z_dim)\n",
"        # Apply the gradient step to the discriminator's parameters\n",
"        disc_opt.update(disc, D_grads)\n",
"        # MLX is lazy: force evaluation of the new parameters and optimizer state\n",
"        mx.eval(disc.parameters(), disc_opt.state)\n",
"\n",
"        # Generator step (same three stages: loss/grads, parameter update, evaluation)\n",
"        G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
"        gen_opt.update(gen, G_grads)\n",
"        mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
"    # Report once every `log_every` epochs, after the epoch's batches have run.\n",
"    # The losses shown are those of the epoch's final batch. Keeping this inside\n",
"    # the batch loop would print and plot a sample grid for every single batch.\n",
"    if epoch % log_every == 0:\n",
"        print(f\"Epoch: {epoch}, Discriminator Loss: {D_loss}, Generator Loss: {G_loss}\")\n",
"        fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
"        fake = gen(fake_noise)\n",
"        show_images(fake)"
]
2024-07-26 21:07:40 +08:00
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}