mlx-examples/gan/playground.ipynb

540 lines
112 KiB
Plaintext
Raw Normal View History

2024-07-26 21:07:40 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Libraries"
]
},
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 244,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 245,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 246,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Build one generator stage: Linear -> BatchNorm -> ReLU.\n",
"\n",
"    Args:\n",
"        in_dim: number of input features of the linear layer.\n",
"        out_dim: number of output features (also the BatchNorm width).\n",
"    \"\"\"\n",
"    stage = [\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.BatchNorm(out_dim),\n",
"        nn.ReLU(),\n",
"    ]\n",
"    return nn.Sequential(*stage)"
]
},
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 247,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"    \"\"\"Fully connected generator mapping noise vectors to flattened images.\n",
"\n",
"    Args:\n",
"        z_dim: size of each input noise vector.\n",
"        im_dim: number of values per generated image (784 = 28 * 28).\n",
"        hidden_dim: width of the first hidden stage; the following stages\n",
"            use 2x, 4x and 8x this width.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, z_dim: int = 10, im_dim: int = 784, hidden_dim: int = 128):\n",
"        super().__init__()\n",
"        # Four widening Linear/BatchNorm/ReLU stages, then the output projection.\n",
"        self.gen = nn.Sequential(\n",
"            GenBlock(z_dim, hidden_dim),\n",
"            GenBlock(hidden_dim, hidden_dim * 2),\n",
"            GenBlock(hidden_dim * 2, hidden_dim * 4),\n",
"            GenBlock(hidden_dim * 4, hidden_dim * 8),\n",
"            nn.Linear(hidden_dim * 8, im_dim),\n",
"            # Tanh bounds the output to [-1, 1], the range this notebook\n",
"            # rescales the training images to. Sigmoid only covers [0, 1],\n",
"            # which makes generated images trivially separable from real ones.\n",
"            nn.Tanh(),\n",
"        )\n",
"\n",
"    def __call__(self, noise):\n",
"        \"\"\"Generate images from a batch of noise vectors (see get_noise).\"\"\"\n",
"        return self.gen(noise)"
]
},
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 248,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
" (layers.0): Linear(input_dims=100, output_dims=128, bias=True)\n",
" (layers.1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.1): Sequential(\n",
" (layers.0): Linear(input_dims=128, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.2): Sequential(\n",
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
" (layers.1): BatchNorm(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.3): Sequential(\n",
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
" (layers.1): BatchNorm(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.4): Linear(input_dims=1024, output_dims=784, bias=True)\n",
" (layers.5): Sigmoid()\n",
" )\n",
")"
]
},
2024-07-30 00:44:16 +08:00
"execution_count": 248,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gen = Generator(100)\n",
"gen"
]
},
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 249,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# make 2D noise with shape n_samples x z_dim\n",
"def get_noise(n_samples: int, z_dim: int) -> mx.array:\n",
"    \"\"\"Draw noise from mx.random.normal with shape (n_samples, z_dim).\n",
"\n",
"    Args:\n",
"        n_samples: number of noise vectors (rows).\n",
"        z_dim: length of each noise vector (columns).\n",
"    \"\"\"\n",
"    return mx.random.normal(shape=(n_samples, z_dim))"
]
},
{
"cell_type": "code",
"execution_count": 250,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWK0lEQVR4nO3ca2zW9d3H8Q+09ACFcqjQFm05g9QyDnMHhCEoyNg6wE2dhMWwB8tGNLplMdkeLI3JsixhajC6mQ0PmwZiRZgQNyYHcZO2ih0tuBbWTqCAnAoFBpRhue5n3+TOHvT6/JJ7950779fj631dru3FZ/8n336ZTCYjAAAk9f/f/g8AAPzfwSgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAg5Gb7wjfeeMN+8w8++MBuLl26ZDeSdOTIkaTONXv2bLtpaGiwmy9+8Yt2I0kbN260m5kzZ9pNU1OT3eTmZv3n9t+k/E0sWbLEbo4dO2Y3ixcvtpt//etfdpPa/frXv7abtWvX2s2ePXvs5tSpU3YjSfn5+XYzZMgQu+no6LCbadOm2Y2U9rvNycmxm+9///t9voYnBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABD6ZTKZTDYvrK2ttd981KhRdtPa2mo3kjRp0iS7OXfunN2cPHnSbgoKCuwmy1/Lvxk8eLDdpBy327p1q92sWrXKbiTpxIkTdlNVVWU38+bNs5u6ujq7KS8vtxtJysvLs5uU72DKAcLhw4fbzY4dO+xGksaOHWs3X//61+3mueees5uU45JS2r8RW7ZssZtsvrc8KQAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAICQm+0Lp06dar/5xYsX7aaiosJuJKmtrc1uUg6TpRzeSzl29Ytf/MJuJGn16tV28+mnn9rNt7/9bbu5/fbb7UZKO0J47do1uzl8+LDd7Nu3z256e3vtRpK+8pWv2E3K395vf/tbu3nooYfsZuLEiXYjpf1vamlpsZsvfOELdvPhhx/ajSRNnz7dbsaPH5/0WX3hSQEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAELK+kppyZfD48eN2k3LdUpImT55sN2fPnrWbS5cu2U3KtcV7773XbiTppZdespvHHnvMbnbt2mU3ra2tdiNJXV1ddjNz5ky7SbkOmqKkpCSpKyoqspvTp0/bzfz58+2mu7vbbr761a/ajSTV1tbazYwZM+yms7PTbsaMGWM3ktTR0WE3KRd6s8GTAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAhZH8Srqamx33z16tV2853vfMduJGnbtm12k5OTYzezZs2ym+HDh9tNyuE9Ke1nnnJwbvHixXbzj3/8w24kqb293W5SDuJt3brVbr73ve/ZzTvvvGM3kvSzn/3MbqZOnWo3KQfxiouL7WbNmjV2I6X9zA8ePGg3mUzGbsrLy+1GkvLz8+1m2LBhSZ/VF54UAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQMj6IF7K8apx48bZzfXr1+1GkkpLS+1m6NChdjNkyBC7yc3N+sccKisr7UaSdu7caTff+ta37GbPnj12k3I0TZIKCwvtZvfu3XaT8jd05coVu6mqqrIbSaqurrablKOP3d3ddvO73/3OblK+F1LaYcWWlha7OXPmjN2kfC8k6dSpU3bz1FNPJX1WX3hSAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAACHri1Qph+pmzJhhNwcOHLAbSdq7d6/d/OQnP7Gb3/zmN3aTcg
hu0aJFdiNJ7e3tdrN161a7aW5utpvnn3/ebqS0/00px8zmzJljN729vXbT0dFhN1LaIbgvf/nLdpPy3/fwww/bTWNjo91I0vnz5+2moKDAbqZPn243c+fOtRtJamhosJvUv6O+8KQAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAQr9MJpPJ5oUvvPCC/eZvv/223dx///12I0nPPPOM3UyYMMFuenp67ObcuXN2M3ToULuRpNGjR9tNTk6O3RQVFdlNWVmZ3UjS7Nmz7WbDhg12U1tbazfLly+3m5kzZ9qNJO3evdtuqqur7ebYsWN2k3L8squry26ktJ/fhQsX7ObIkSN2k/q9vXjxot188skndvPiiy/2+RqeFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAITfbF27evNl+8ylTpthNyjVWSbrzzjvt5vz583ZTWlpqNwMHDrSb1Iuin/nMZ+ymvr7ebn71q1/ZTWVlpd1IUm9vr92kXLj8xje+YTeFhYV2k3K5VJLy8/PtpqWlxW5GjRplN0ePHrWbLA80/5uU3+3IkSPtpri42G7+8Ic/2I2UdmX2S1/6UtJn9YUnBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABCyPoh34cIF+81LSkrspqamxm4kaf/+/XZTXl5uN11dXXbT3d1tNxcvXrQbSWpubrabSZMm2c3atWvtZtOmTXYjSQcPHrSb999/326WLFliN5MnT7abxsZGu5Gkjz/+2G5SDivefPPNdrNmzRq7WbFihd2k6ujosJtdu3bZTcrfkCRt377dbioqKpI+qy88KQAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDQL5PJZLJ54dy5c+03P336tN2sWrXKbiSpra3NbiZMmGA3+/bts5vc3KzvDoZFixbZjST985//tJsHH3zQbl588UW7OXTokN1IUnt7u92UlpbazcqVK+3mvffes5thw4bZjSR1dnbazcKFC+2mrq7OboqLi+1m1KhRdiOlHYscMmSI3bS0tNhNlv+c/psBAwbYzZkzZ+wmm8N7PCkAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkPWlturqavvNOzo67Gb37t12I0k//elP7WbTpk12k3JoLaVpbW21G0kqLy+3m0cffdRu8vLy7Gb69Ol2I0knT560m7vuustumpqa7Gb06NF28+abb9qNJA0aNMhuenp67GbEiBF2M3DgQLvZv3+/3UjSsmXL7ObPf/6z3UycONFupkyZYjeS9PHHH9vN448/nvRZfeFJAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQsr6SOmDAAPvN77nnHrupqKiwGynt4mJjY6PdrFq1ym4uX75sN9euXbMbKe33tGjRIrs5fPiw3VRWVtqNJNXU1NhNyoXegoICu+nu7rabq1ev2o0kLV++3G5SLrIuXbrUbnbt2mU39fX1diNJd9xxh92cO3fObt5++227KSwstBsp7Weeckl5xowZfb6GJwUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQ/kcP4s2dO9du3nrrLbuRpFmzZtlNTk6O3QwaNMhutmzZYjff/e537UaSnn32WbuZNGmS3fT29trNnj177EZKOzp3yy232E3K4cIU1dXVSd2PfvQju3niiSfs5rXXXrObhQsX2s3AgQPtRpLOnz9vNyn/fi1btsxuUg7vSVJ5ebndnDx5Mumz+sKTAgAgMAoAgMAoAAACowAACIwCAC
AwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAj9MplMJpsXbty40X7zAwcO2E1eXp7dSNKOHTvs5rbbbrObgoI
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"img = get_noise(28,28)\n",
"plt.imshow(img, cmap='gray')\n",
"plt.axis('off')\n",
"plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 251,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Build one discriminator stage: Linear -> LeakyReLU (slope 0.2).\"\"\"\n",
"    linear = nn.Linear(in_dim, out_dim)\n",
"    activation = nn.LeakyReLU(negative_slope=0.2)\n",
"    return nn.Sequential(linear, activation)"
]
},
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 252,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"    \"\"\"Fully connected discriminator producing one raw score per input row.\n",
"\n",
"    Args:\n",
"        im_dim: number of values per flattened input image.\n",
"        hidden_dim: width of the last hidden stage; earlier stages use\n",
"            4x and 2x this width.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, im_dim: int = 784, hidden_dim: int = 128):\n",
"        super().__init__()\n",
"        stages = [\n",
"            DisBlock(im_dim, hidden_dim * 4),\n",
"            DisBlock(hidden_dim * 4, hidden_dim * 2),\n",
"            DisBlock(hidden_dim * 2, hidden_dim),\n",
"            # Final projection to a single value; no activation is applied,\n",
"            # the losses in this notebook use with_logits=True.\n",
"            nn.Linear(hidden_dim, 1),\n",
"        ]\n",
"        self.disc = nn.Sequential(*stages)\n",
"\n",
"    def __call__(self, noise):\n",
"        # Despite its name, `noise` is the batch to score; the loss\n",
"        # functions pass real and generated images here.\n",
"        return self.disc(noise)"
]
},
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 253,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
" (layers.0): Linear(input_dims=784, output_dims=512, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.1): Sequential(\n",
" (layers.0): Linear(input_dims=512, output_dims=256, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.2): Sequential(\n",
" (layers.0): Linear(input_dims=256, output_dims=128, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.3): Linear(input_dims=128, output_dims=1, bias=True)\n",
" )\n",
")"
]
},
2024-07-30 00:44:16 +08:00
"execution_count": 253,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"disc = Discriminator()\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
2024-07-30 00:44:16 +08:00
"cell_type": "markdown",
2024-07-26 21:07:40 +08:00
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### Losses"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"#### Discriminator Loss"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 254,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
"    \"\"\"Discriminator loss: average of the BCE on fake and on real images.\n",
"\n",
"    Args:\n",
"        gen: generator used to produce the fake batch.\n",
"        disc: discriminator being scored.\n",
"        real: batch of real images fed to `disc`.\n",
"        num_images: number of fake images to generate.\n",
"        z_dim: length of each noise vector.\n",
"\n",
"    Returns:\n",
"        (fake_loss + real_loss) / 2, where fakes are labelled 0 and reals 1.\n",
"    \"\"\"\n",
"    # get_noise already returns an MLX array, so no extra mx.array() copy.\n",
"    noise = get_noise(num_images, z_dim)\n",
"    fake_images = gen(noise)\n",
"\n",
"    # Fake batch: the discriminator should output logits that mean \"0\".\n",
"    fake_logits = disc(fake_images)\n",
"    fake_labels = mx.zeros((fake_images.shape[0], 1))\n",
"    fake_loss = nn.losses.binary_cross_entropy(fake_logits, fake_labels, with_logits=True)\n",
"\n",
"    # Real batch: the discriminator should output logits that mean \"1\".\n",
"    real_logits = disc(real)\n",
"    real_labels = mx.ones((real.shape[0], 1))\n",
"    real_loss = nn.losses.binary_cross_entropy(real_logits, real_labels, with_logits=True)\n",
"\n",
"    return (fake_loss + real_loss) / 2.0"
]
},
2024-07-30 00:44:16 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generator Loss"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 255,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def gen_loss(gen, disc, num_images, z_dim):\n",
"    \"\"\"Generator loss: BCE of the discriminator's scores on fakes vs. label 1.\n",
"\n",
"    Args:\n",
"        gen: generator producing the fake batch.\n",
"        disc: discriminator scoring the fake batch.\n",
"        num_images: number of fake images to generate.\n",
"        z_dim: length of each noise vector.\n",
"    \"\"\"\n",
"    # get_noise already returns an MLX array, so no extra mx.array() copy.\n",
"    noise = get_noise(num_images, z_dim)\n",
"    fake_images = gen(noise)\n",
"    fake_logits = disc(fake_images)\n",
"\n",
"    # Labels are all ones: the generator is rewarded when fakes are scored as real.\n",
"    target_labels = mx.ones((fake_images.shape[0], 1))\n",
"\n",
"    return nn.losses.binary_cross_entropy(fake_logits, target_labels, with_logits=True)"
]
},
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 256,
"metadata": {},
"outputs": [],
"source": [
"# Get only the training images\n",
"train_images,*_ = map(np.array, mnist.mnist())"
]
},
{
"cell_type": "code",
"execution_count": 257,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# Normalize the images to fall between -1,1\n",
"train_images = train_images * 2.0 - 1.0"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 258,
2024-07-29 06:24:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-07-30 00:44:16 +08:00
"<matplotlib.image.AxesImage at 0x302feb1f0>"
2024-07-29 06:24:50 +08:00
]
},
2024-07-30 00:44:16 +08:00
"execution_count": 258,
2024-07-29 06:24:50 +08:00
"metadata": {},
"output_type": "execute_result"
2024-07-30 00:44:16 +08:00
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaI0lEQVR4nO3df2jU9x3H8dfVH1d1lytBk7vUmGVF202dpWrVYP3R1cxApf4oWMtGZEPa+YOJ/cGsDNNBjdgpRdI6V0amW239Y9a6KdUMTXRkijpdRYtYjDOdCcFM72LUSMxnf4hHz1j1e975vkueD/iCufu+vY/ffuvTby75xueccwIAwMBD1gsAAHRfRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjpab2AW3V0dOjcuXMKBALy+XzWywEAeOScU0tLi/Ly8vTQQ3e+1km7CJ07d075+fnWywAA3Kf6+noNHDjwjvuk3afjAoGA9RIAAElwL3+fpyxCH3zwgQoLC/Xwww9r5MiR2rdv3z3N8Sk4AOga7uXv85REaPPmzVq8eLGWLVumI0eO6JlnnlFJSYnOnj2bipcDAGQoXyruoj1mzBg99dRTWrduXeyx73//+5o+fbrKy8vvOBuNRhUMBpO9JADAAxaJRJSVlXXHfZJ+JXTt2jUdPnxYxcXFcY8XFxertra20/5tbW2KRqNxGwCge0h6hM6fP6/r168rNzc37vHc3Fw1NjZ22r+8vFzBYDC28ZVxANB9pOwLE259Q8o5d9s3qZYuXapIJBLb6uvrU7UkAECaSfr3CfXv3189evTodNXT1NTU6epIkvx+v/x+f7KXAQDIAEm/Eurdu7dGjhypqqqquMerqqpUVFSU7JcDAGSwlNwxYcmSJfrpT3+qUaNGady4cfr973+vs2fP6tVXX03FywEAMlRKIjR79mw1NzfrN7/5jRoaGjRs2DDt2LFDBQUFqXg5AECGSsn3Cd0Pvk8IALoGk+8TAgDgXhEhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmelovAEgnPXr08DwTDAZTsJLkWLhwYUJzffv29Tzz+OOPe55ZsGCB55nf/va3nmfmzJnjeUaSrl696nlm5cqVnmfefvttzzNdBVdCAAAzRAgAYCbpESorK5PP54vbQqFQsl8GANAFpOQ9oaFDh+rvf/977ONEPs8OAOj6UhKhnj17cvUDALirlLwndOrUKeXl5amwsFAvvfSSTp8+/a37trW1KRqNxm0AgO4h6REaM2aMNm7cqJ07d+rDDz9UY2OjioqK1NzcfNv9y8vLFQwGY1t+fn6ylwQASFNJj1BJSYlmzZql4cOH67nnntP27dslSRs2bLjt/kuXLlUkEolt9fX1yV4SACBNpfybVfv166fhw4fr1KlTt33e7/fL7/enehkAgDSU8u8Tamtr05dffqlwOJzqlwIAZJikR+j1119XTU2N6urqdODAAb344ouKRqMqLS1N9ksBADJc0j8d9/XXX2vOnDk6f/68BgwYoLFjx2r//v0qKChI9ksBADJc0iP0ySefJPu3RJoaNGiQ55nevXt7nikqKvI8M378eM8zkvTII494npk1a1ZCr9XVfP31155n1q5d63lmxowZnmdaWlo8z0jSv//9b88zNTU1Cb1Wd8W94wAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQ
DMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAMz7nnLNexDdFo1EFg0HrZXQrTz75ZEJzu3fv9jzDf9vM0NHR4XnmZz/7meeZS5cueZ5JRENDQ0JzFy5c8Dxz8uTJhF6rK4pEIsrKyrrjPlwJAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwExP6wXA3tmzZxOaa25u9jzDXbRvOHDggOeZixcvep6ZPHmy5xlJunbtmueZP/3pTwm9Fro3roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT63//+l9DcG2+84Xnm+eef9zxz5MgRzzNr1671PJOoo0ePep6ZMmWK55nW1lbPM0OHDvU8I0m//OUvE5oDvOJKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAw43POOetFfFM0GlUwGLReBlIkKyvL80xLS4vnmfXr13uekaSf//znnmd+8pOfeJ75+OOPPc8AmSYSidz1/3muhAAAZogQAMCM5wjt3btX06ZNU15ennw+n7Zu3Rr3vHNOZWVlysvLU58+fTRp0iQdP348WesFAHQhniPU2tqqESNGqKKi4rbPr1q1SmvWrFFFRYUOHjyoUCikKVOmJPR5fQBA1+b5J6uWlJSopKTkts855/Tee+9p2bJlmjlzpiRpw4YNys3N1aZNm/TKK6/c32oBAF1KUt8TqqurU2Njo4qLi2OP+f1+TZw4UbW1tbedaWtrUzQajdsAAN1DUiPU2NgoScrNzY17PDc3N/bcrcrLyxUMBmNbfn5+MpcEAEhjKfnqOJ/PF/exc67TYzctXbpUkUgkttXX16diSQCANOT5PaE7CYVCkm5cEYXD4djjTU1Nna6ObvL7/fL7/clcBgAgQyT1SqiwsFChUEhVVVWxx65du6aamhoVFRUl86UAAF2A5yuhS5cu6auvvop9XFdXp6NHjyo7O1uDBg3S4sWLtWLFCg0ePFiDBw/WihUr1LdvX7388stJXTgAIPN5jtChQ4c0efLk2MdLliyRJJWWluqPf/yj3nzzTV25ckXz58/XhQsXNGbMGO3atUuBQCB5qwYAdAncwBRd0rvvvpvQ3M1/VHlRU1Pjeea5557zPNPR0eF5BrDEDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNrqkfv36JTT317/+1fPMxIkTPc+UlJR4ntm1a5fnGcASd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzAFvuGxxx7zPPOvf/3L88zFixc9z+zZs8fzzKFDhzzPSNL777/veSbN/ipBGuAGpgCAtEaEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpsB9mjFjhueZyspKzzOBQMDzTKLeeustzzMbN270PNPQ0OB5BpmDG5gCANIaEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gCBoYNG+Z5Zs2aNZ5nfvSjH3meSdT69es9z7zzzjueZ/773/96noENbmAKAEhrRAgAYMZzhPbu3atp06YpLy9PPp9PW7dujXt+7ty58vl8cdvYsWOTtV4AQBfiOUKtra0aMWKEKioqvnWfqVOnqqGhIbbt2LHjvhYJAOiaenodKCkpUUlJyR338fv9CoVCCS8KANA9pOQ9oerqauXk5GjIkCGaN2+empqavnXftrY2RaPRuA0A0D0kPUIlJSX66KOPtHv3bq1evVoHDx7Us88+q7
a2ttvuX15ermAwGNvy8/OTvSQAQJry/Om4u5k9e3bs18OGDdOoUaNUUFCg7du3a+bMmZ32X7p0qZYsWRL7OBqNEiI
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
2024-07-29 06:24:50 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"plt.imshow(train_images[0].reshape(28,28),cmap='gray')"
2024-07-29 06:24:50 +08:00
]
},
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 259,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"def batch_iterate(batch_size: int, ipt: np.ndarray):\n",
"    \"\"\"Yield shuffled mini-batches of `ipt`, visiting every row exactly once.\n",
"\n",
"    A new random permutation of the row indices is drawn on each call. The\n",
"    last batch is shorter when len(ipt) is not a multiple of batch_size.\n",
"\n",
"    Args:\n",
"        batch_size: maximum number of rows per yielded batch.\n",
"        ipt: data to batch; must support indexing with an integer index\n",
"            array (e.g. a NumPy array).\n",
"\n",
"    Yields:\n",
"        ipt[ids] for consecutive slices `ids` of the permutation.\n",
"    \"\"\"\n",
"    perm = np.random.permutation(len(ipt))\n",
"    for start in range(0, len(ipt), batch_size):\n",
"        ids = perm[start : start + batch_size]\n",
"        yield ipt[ids]"
]
},
{
"cell_type": "code",
"execution_count": 260,
"metadata": {},
"outputs": [],
"source": [
"def show_images(imgs, num_imgs: int = 25):\n",
"    \"\"\"Plot up to `num_imgs` images from `imgs` on a grid at most 5 columns wide.\n",
"\n",
"    Args:\n",
"        imgs: batch of images; each row is reshaped to 28 x 28 for display.\n",
"        num_imgs: maximum number of images to draw (default 25 -> 5 x 5 grid).\n",
"\n",
"    Does nothing when there is nothing to draw.\n",
"    \"\"\"\n",
"    # Never index past the end of the batch (the fixed 5x5 grid used to).\n",
"    n_show = min(num_imgs, imgs.shape[0])\n",
"    if n_show <= 0:\n",
"        return\n",
"\n",
"    n_cols = min(5, n_show)\n",
"    n_rows = (n_show + n_cols - 1) // n_cols  # ceiling division\n",
"    # squeeze=False keeps `axes` 2-D even for a single row or column.\n",
"    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), squeeze=False)\n",
"\n",
"    for i, ax in enumerate(axes.flat):\n",
"        ax.axis('off')\n",
"        if i < n_show:\n",
"            img = mx.array(imgs[i]).reshape(28, 28)\n",
"            ax.imshow(img, cmap='gray')\n",
"    plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### Show the first batch of images"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 261,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-30 00:44:16 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAD3kklEQVR4nOy9V5OcV5rf+Uvvvc/KyvIFU/CGIIBm03RT3cORZlrqmZGPidjVboQ+xH4B3eyVdLGxitDEhDSSttUT49Q9zSabFvSwBVu+MrMqvfd2L6BzmAVHkCygDN5fRAVIVCIr31Pve57zuP+jGgwGAxQUFBQUFLYR9U5/AAUFBQWF/YdiXBQUFBQUth3FuCgoKCgobDuKcVFQUFBQ2HYU46KgoKCgsO0oxkVBQUFBYdtRjIuCgoKCwrajGBcFBQUFhW1HMS4KCgoKCtuO9mlfqFKpnuXn2FN8F1EDZf2+Rlm/78d3FdVQ1vBrlHvw+/E066d4LgoKCgoK245iXBQUFBQUth3FuCgoKCgobDuKcVFQUFBQ2HYU46KgoKCgsO0oxkVBQUFBYdt56lLk3cSTSgKV2WcKCs+O7SzHVZ7V/c2eMC4ajQa1Ws3IyAjhcBiTyYTP50Ov12O32zEajWQyGdbX16nVaiwtLVGpVOj1evR6vZ3++AoKexaVSoXb7cZms+F0OolEIhiNRnw+HxaL5Tu/by6X48aNG1SrVfL5PLVajU6nQ7vd3sZPr7CT7HrjolKp0Gg0aLVapqamePnll/H5fMzNzWG324lEIrjdbubn53n//fdJJBKUy2U6nQ7NZlMxLgoK3wO1Wo3P5yMajTI5Ocn58+dxu90cO3YMn88nX6dSqaRXMxgMvtEruXPnDn/+53/O5uYmd+7cYXNzk0ajQafTUTyafcKuNS7iZjUYDIRCIWw2G5OTk4yNjeFyufB4PFgsFoxGIxqNBpvNRiQSQavVMjExgVarJZ1Ok06nd/pSFPYwarUao9GIVqvFZrNhs9nQarXyvnsSrVaLZrNJs9mkVCrJA0+3231On/77IzwXETXw+/04nU4sFgsGg0G+bjAY0Ov1GAwGaDQaVCrVFiOhVqu3hNQcDgeTk5NYrVYGgwEOh4NCoUAqlaLdblOpVPbUOik8jGrwlMeE5y19oNFoMBgMeL1efv7znzM5Ocnc3ByHDh2SD7darabT6dDtduVpqVAo8PHHH5NIJPjtb3/LBx98sO0nIUU64vuxl9bPYDAwPj6O0+nk1KlTnD59GqfTyfj4+JbN9UEGgwGJRIJkMsna2hrvv/8+hUKBtbU1CoXC9/pMz1P+xWg08q//9b/mzTffJBAIcPDgQQwGgzS4gna7TalUotvtYjab0ev19Pt9BoMBKpXqIWPcarUolUo0m03W19fJ5/MsLCzwxRdfkMvluHr16vdepyexl+7B3cjTrN+u9Vw0Gg16vR6z2Uw4HGZ8fJxwOIzX62UwGNDtdun1ejQaDZrNJkajEavVSrfbxeVy0Wq1MJvNO30ZzwW1Wo1er0elUskTYq/Xo91uMxgM6Pf73+l9NRrNlg1BnE6/6/vtJYbvP4/Hg9vtJhqNMjMzg9PpZHJyEpPJtGUDFWuvVqsZDAbYbDasVitqtZo7d+6gVqtJp9OUy2X573Y7KpUKq9WK1+vF7XZjt9ulURn+/N1ul3K5LO85ca+0223pyeh0OtRqNWq1GoPBQCAQoNfroVarcTqdNBoNVldXGQwG6PX6nbrkXYlarZZ/ivtMPJv9fl/eT8PPpvg97NR9tmuNi9vtZnp6mtHRUY4cOcKBAwcwGAwUi0WSySSfffYZpVKJUqlEvV7HYrHgcDhotVqsrq5SLpdZW1vbEw/w9yUUCvHmm2/i9Xrx+/3Y7XZu3brFe++9R7lcJpVK0Wg0vtV7qlQqZmZmmJ6eln/Xbre5ceMGm5ub230Ju44DBw7w6quv4na7OXDgAC6Xi1AoRCgUotfrkUwm6X
Q6ZDIZyuUyDocDn8+H0WgkGAxiNBrxeDzS+w4Gg+RyOX75y19y/fp1isUimUxm39yfsViM//f//X/J5/OcP3+eAwcOkMlkWFlZQafTSaMciUSIRCLy3wnDYjKZaLVa1Ot14vE4N2/elLnTFz08ptVq8Xq98t4KBALY7Xai0SgajYZYLEaxWJT3VLfblV/i73fkc+/IT30KrFYrU1NTjI6OMjExwdjYmDQm8Xicd955h3Q6TS6Xo16vY7Vacblc9Pt9CoUCnU7nhcm3uFwuXn/9dSYnJ5meniYYDPKb3/yGlZUVUqkUhULhWxsXtVpNJBLh9OnTMhxQq9WIx+MvhHGJRCL83u/9Hn6/nwMHDuB0OuX3isUiGxsbVCoVFhcXSaVShEIh2u02drsdt9uN0WiUOZpQKMTBgwcpFovcvXuXbDZLv98nm83uG+OSyWT4n//zfxKLxaQnvbS0xOeff47RaKTRaBAKhTCZTIyMjMh7SnhGFouFdrtNo9FAp9PhcDhkaE0xLlqcTidOp5MDBw4wOzuL3+/n9OnT6HQ6rl+/TiwWY2Njg4WFBZnbE9V3inF5AKvVSjgcJhAI0O/3qVQq3L17lzt37rC+vk4sFqNQKFCr1Wi1WrTbbdrtNv1+n3q9TrfbpdVq7fRlPBf0ej1utxufz4fJZAIgGAzyyiuvkEgkSKfTlEqlp3ovjUYjy7snJiaYm5uTLnmlUuH69etks1nq9Tq1Wm1fbI46nQ6bzYZerycYDOJyuTh16pQsJKlWq9IjXltbo1qtsrGxQb1el2sbj8dZX1/HaDQyPz+PxWJhcnKSaDSKwWCQhQBHjhxBrVaztLSExWKhVquRSCS+tfF/XgwGAxqNBsViEY1GQzqdRqfTPfS6QqFAt9ul3+/TbDapVCryJK1Wq5mfnycej1MoFIjFYrhcLg4cOCDzMyKP6vF46Ha7nDp1CpfLxcrKCmtra/R6PZlbfVEwm804HA4cDgfnz58nHA4TDAYJhUI4HA6cTidqtRqLxYLZbMbr9crqWIvFglqt5tq1a+j1ehqNBtls9rlWz+5a4+JyuTh69Cgej4der0cmk+G9997jf/yP/0G5XCaRSEhjAo8uhXwRcgMAJpOJ8fFxpqamZBz20KFDjI+Ps7i4yJdffsnq6upTvZder2dkZASPx8P58+f5vd/7PWlcxMm7UqmQTCZlfHyvYzKZiEajuFwu3njjDY4ePcrIyAiHDx+m1+uxurpKoVDgv//3/84vfvELme8bzkGp1WqZWxD5mn/2z/4Z/+Sf/BPcbjdTU1NYLBbeeust3nzzTb766iveffddEokEv/71r3e1cSkUCsTjcarVKr1e75FVcvF4XJYRi/sjHo+ztLREu91meXkZrVaL3W7Hbrdz5MgR/o//4/+QeVSr1YrdbsdqtRIIBFCpVGQyGf72b/+WYrEoDdZ+uN++CbGPud1uDh48yOjoKP/b//a/cfDgQfk9rVaLyWSi2+3idrupVCo4nU6i0ShWq5UjR45gsVj49a9/zbvvvsvGxgZffPEF9Xr9uV3HrjMuwkjo9XpZatzpdOj3+5RKJXlqbrVaL7y7bLFYsNlseL1eTCbTluodnU4nT+TDf/9NqNVq+aBbLBYsFou8obvdLsFgkGg0SrvdZn19fV8YcJ1Oh9frxev1EggEZM5EbGrxeJxcLsfGxoY8iZvNZmlERB+WVquVnnOn06FQKMjGXrPZjNlsxm6343A4ZDl9tVr9Vr+f581gMNiStxsMBvKwMUw6nZbGRTRGigS/OASq1Wp6vR7NZpNUKsXm5qasJDMYDKjVaunBuN1uVCqVLHuG+2HZ/XC/fRN6vR6dTidzVOFwGIfDgdlsplarUavVUKvV8j7L5/Pk83l0Op0shLBYLLjdbpxOJw6Hg2Kx+Mjf27Nk193VJpMJvV6Pw+GQsetMJiPj/SJh9aI3R6pUKl5++WX+wT/4B/LUvR0YDAZGR0eJRCJ4PJ4t3zObzfyjf/
SPuHjxIr/4xS+4efPmvjDwLpeLV199lbGxMY4ePcrY2Bg3btzgv/23/0Yul+PevXsUCgU2Nzfp9/v4/X5efvllXC4
2024-07-26 21:07:40 +08:00
"text/plain": [
2024-07-29 06:24:50 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-26 21:07:40 +08:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
2024-07-30 00:44:16 +08:00
"X = batch_iterate(25, train_images)\n",
"for x in X: \n",
" show_images(x)\n",
2024-07-26 21:07:40 +08:00
" break"
]
},
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 262,
2024-07-27 06:09:51 +08:00
"metadata": {},
2024-07-29 06:24:50 +08:00
"outputs": [],
2024-07-27 06:09:51 +08:00
"source": [
"z_dim = 64\n",
"# The learning rate must exist before the optimizers are built; without this\n",
"# line the cell raises NameError on a fresh kernel (Restart & Run All).\n",
"lr = 0.00001\n",
"\n",
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
"gen_opt = optim.Adam(learning_rate=lr)\n",
"\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
"disc_opt = optim.Adam(learning_rate=lr)"
2024-07-29 06:30:08 +08:00
]
},
{
2024-07-30 00:44:16 +08:00
"cell_type": "markdown",
2024-07-29 06:30:08 +08:00
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### Training Cycle"
2024-07-29 06:30:08 +08:00
]
},
2024-07-27 06:09:51 +08:00
{
"cell_type": "code",
2024-07-30 00:44:16 +08:00
"execution_count": 263,
2024-07-27 06:09:51 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-30 00:44:16 +08:00
" 0%| | 0/200 [00:00<?, ?it/s]"
2024-07-29 00:18:35 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"# Set your parameters\n",
"n_epochs = 200\n",
"z_dim = 64\n",
"display_step = 5000\n",
"cur_step = 0\n",
"\n",
"# NOTE: gen_opt and disc_opt were constructed in an earlier cell, so changing\n",
"# lr here has no effect on them unless that cell is re-run afterwards.\n",
"lr = 0.00001\n",
"\n",
"batch_size = 128\n",
"\n",
"# Each helper returns (loss, grads); grads are taken with respect to the\n",
"# trainable parameters of the module passed as the first argument.\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"for epoch in tqdm(range(n_epochs)):\n",
"\n",
"    for real in batch_iterate(batch_size, train_images):\n",
"\n",
"        # Train the discriminator on one real batch and one fake batch\n",
"        D_loss, D_grads = D_loss_grad(gen, disc, mx.array(real), batch_size, z_dim)\n",
"        disc_opt.update(disc, D_grads)\n",
"        # Evaluate the updated parameters and optimizer state\n",
"        mx.eval(disc.parameters(), disc_opt.state)\n",
"\n",
"        # Train the generator against the freshly updated discriminator\n",
"        G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
"        gen_opt.update(gen, G_grads)\n",
"        # Evaluate the updated parameters and optimizer state\n",
"        mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
"        # Every display_step batches, report losses and show samples\n",
"        if (cur_step + 1) % display_step == 0:\n",
"            print(f\"Epoch {epoch}, step {cur_step + 1}: Generator loss: {G_loss}, discriminator loss: {D_loss}\")\n",
"            fake_noise = get_noise(batch_size, z_dim)\n",
"            fake = gen(fake_noise)\n",
"            show_images(fake)\n",
"            show_images(real)\n",
"        # Counts batches (not epochs); this is what display_step is compared to\n",
"        cur_step += 1"
2024-07-27 05:19:08 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}