mlx-examples/gan/playground.ipynb

539 lines
100 KiB
Plaintext
Raw Normal View History

2024-07-26 21:07:40 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Libraries"
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 369,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 370,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 371,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim: int, out_dim: int, eps: float = 1e-5, momentum: float = 0.1):\n",
"    \"\"\"Build one generator block: Linear -> BatchNorm -> LeakyReLU(0.2).\n",
"\n",
"    Args:\n",
"        in_dim: number of input features.\n",
"        out_dim: number of output features (also the BatchNorm width).\n",
"        eps: BatchNorm numerical-stability constant (MLX default 1e-5).\n",
"        momentum: BatchNorm running-statistics momentum (MLX default 0.1).\n",
"\n",
"    Returns:\n",
"        An nn.Sequential implementing the block.\n",
"    \"\"\"\n",
"    # nn.BatchNorm's second positional argument is `eps`, not `momentum`; passing\n",
"    # 0.8 positionally (as in the common PyTorch GAN recipe) sets eps=0.8, which\n",
"    # swamps small batch variances and weakens the normalization. Use keywords.\n",
"    return nn.Sequential(\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.BatchNorm(out_dim, eps=eps, momentum=momentum),\n",
"        nn.LeakyReLU(0.2),\n",
"    )"
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 393,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"    \"\"\"MLP generator mapping a noise vector to a flattened image.\n",
"\n",
"    Four GenBlocks take the features from `z_dim` to hidden_dim, 2x, 4x and 8x\n",
"    hidden_dim; a final Linear layer projects to `im_dim`, and tanh bounds the\n",
"    output to (-1, 1).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, z_dim: int = 10, im_dim: int = 784, hidden_dim: int = 256):\n",
"        super().__init__()\n",
"        # Feature widths along the network: z_dim -> h -> 2h -> 4h -> 8h.\n",
"        widths = [z_dim] + [hidden_dim * 2**k for k in range(4)]\n",
"        blocks = [GenBlock(d_in, d_out) for d_in, d_out in zip(widths[:-1], widths[1:])]\n",
"        self.gen = nn.Sequential(*blocks, nn.Linear(widths[-1], im_dim))\n",
"\n",
"    def __call__(self, noise):\n",
"        # Squash the raw projection into the same range as the normalized data.\n",
"        return mx.tanh(self.gen(noise))"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 394,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-30 07:06:52 +08:00
" (layers.0): Linear(input_dims=100, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-30 07:06:52 +08:00
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
" (layers.1): BatchNorm(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-30 07:06:52 +08:00
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
" (layers.1): BatchNorm(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
" (layers.3): Sequential(\n",
2024-07-30 07:06:52 +08:00
" (layers.0): Linear(input_dims=1024, output_dims=2048, bias=True)\n",
" (layers.1): BatchNorm(2048, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
2024-07-26 21:07:40 +08:00
" )\n",
2024-07-30 07:06:52 +08:00
" (layers.4): Linear(input_dims=2048, output_dims=784, bias=True)\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-07-30 07:06:52 +08:00
"execution_count": 394,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Build a generator fed by 100-dimensional noise and display its layer structure.\n",
"gen = Generator(z_dim=100)\n",
"gen"
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 374,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"def get_noise(n_samples: int, z_dim: int) -> mx.array:\n",
"    \"\"\"Sample standard-normal noise of shape (n_samples, z_dim).\n",
"\n",
"    Args:\n",
"        n_samples: number of noise vectors (rows).\n",
"        z_dim: length of each noise vector (columns).\n",
"\n",
"    Returns:\n",
"        An mx.array of shape (n_samples, z_dim) drawn from N(0, 1).\n",
"    \"\"\"\n",
"    return mx.random.normal(shape=(n_samples, z_dim))"
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 375,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-30 07:06:52 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWL0lEQVR4nO3cXWzW9d3H8U+hVChQO7oWWgq1lDKFOmdpwWnBhmWEqYiLDoWNwJI9xEGyGN2WbOMAs5EsJtPEZZlb5qYOlFVlM50PgEp5cD4AhVVKiSJQpUBhFKEtZVCv++ybmPug1+d34H1neb+Or/d1QR/48D/55mQymYwAAJA07P/6DwAA+P+DUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEDIzfaFP/zhD+03f+CBB+wm5XMkqa6uzm5KSkrspr293W7uueceu3nppZfsRpJ27txpNw0NDXZz/fXX283evXvtRpJmzJhhN4cPH7abffv22c2sWbPsprCw0G4kqaWlxW6mTJliN83NzXZz7bXX2s3UqVPtRpI+/PBDu7npppvs5uzZs3bT3d1tN5LU0dFhNym/g/fff/+Qr+FJAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAIScTCaTyeaFS5cutd+8rKzMbsrLy+1Gkrq6uuwm5WhayhGvqqoqu0k9rNXa2mo3RUVFdnPu3Dm7GTt2rN1I0sDAgN1MnjzZbmbOnGk369ats5vBwUG7kaS7777bbjo7O+1m2rRpdtPU1GQ3Kb+zklRfX283X/va1+zmr3/9q928//77diNJEydOtJuUQ5HZHCnlSQEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAACE3GxfOHv2bPvNf/KTn9jNz372M7uRpOPHj9tNT0+P3Zw5c8Zuzp49azelpaV2I0kjR460m5ycHLsZNsz//8SpU6fsRpKuuuoqu3n22Wft5uOPP7ab6667zm727t1rN5K0ceNGuykoKLCbLVu22E3Kz8OECRPsRpJKSkrspr+/327y8vLsZu7cuXYjSaNGjbKbY8eOJX3WUHhSAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAACErK+kvvrqq/ab//znP7ebwsJCu5GkO++8025SrlXecMMNdvPkk0/aTSaTsRsp7fJkytcu5ZLmm2++aTdS2s9ETU2N3Zw7d85uUv5sKddYJWnhwoV2M3HiRLtpamqym4qKCrs5evSo3UjSJ598Yje7du2ym5TvbcrnSNL8+fPtZt68eUmfNRSeFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEDIyWR5ee2ZZ56x37yqqspuHn74YbuRpNmzZ9vNvn377CY3N+sbgmHUqFF2MzAwYDdS2hGvixcv2s3UqVPtprOz024k6cSJE3bzxS9+0W52795tN0VFRXZTWVlpN5J04MABu8nPz7eblINzK1assJuWlha7kaRjx47ZTcrfqa6uzm5Sf297e3vt5oknnrCb1tbWIV/DkwIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIWV93O3XqlP3mmzZtsptp06bZjSTV1NTYzenTp+3m8OHDdjNlyhS7+cIXvmA3krR27Vq7WbRokd2cOXPGbi5fvmw3kjR+/Hi76e/vt5srr7zSburr6+3m6NGjdpPa3XbbbXYzZswYu9m/f7/dpP48pPzspXztUv58OTk5diNJ06dPt5sf//jHSZ81FJ4UAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQMj6IF7KEari4mK72bZtm9
1I0owZM+zm5MmTdjNhwgS7aW9vt5uLFy/ajSStXr3abnbv3m03Kd/blMOAkvTGG2/Yzc0332w311xzjd2kHH3s6emxGyntMODAwIDdnDhxwm4+/vjjz+RzJKm2ttZuPve5z9nN8uXL7eaJJ56wG0l6++237Wbv3r12s2TJkiFfw5MCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACFkfxNu1a5f95jU1NXZTUVFhN5L0wgsv2M2PfvQju3nzzTftprS01G4mT55sN5K0ceNGu5k7d67dnD9/3m5SDnhJ0iuvvGI3/f39drNs2TK72bp1q92sXLnSbiRp+PDhdlNdXW03nZ2ddpOfn283b731lt1IUm5u1v9shXnz5tlNb2+v3cycOdNuJOngwYOf2WcNhScFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEDIyWQymWxe+Jvf/MZ+8+7ubru5dOmS3UhpFxfnz59vN++++67d3H777XZz33332Y0k/fKXv7SblMuvPT09dlNbW2s3ktTR0WE3o0aNspuUy6/19fV2M3LkSLuRpP/85z92k3Ld+O6777ablCufKVdIJempp56ym4KCArtZunSp3bz33nt2I0llZWV2U1RUZDf33nvvkK/hSQEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAACErA/ibd++3X7zlINX//znP+1GSjuaNmvWLLtZsGCB3Tz33HN2s2jRIruRpI0bN9rNmDFj7GbatGl289vf/tZuJKmystJudu7caTdf//rX7aa8vNxu2tra7EaSSktL7Wb06NF2c+HCBbtJOZDY2dlpN5I0fvx4u0k5kJjy71dfX5/dSNKePXvsJuX3dseOHUO+hicFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAELI+iPerX/3KfvPi4mK7STnyJEmbNm2ym5QjWVl+uT4l5UhWSUmJ3UjSe++9Zzcp36eUQ3B5eXl2I0nV1dV209XVZTcpP0Pz58+3m/z8fLuRpJaWFrv57ne/azfr16+3m7vuustu1q5dazeSNGfOHLtZvHix3SxZssRubr75ZruR0o4xPvLII3bzt7/9bcjX8KQAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAQtYH8Z588kn7zYcN8zenvb3dbiTp8uXLdjNlyhS7STkE19nZaTePPvqo3UhpR8ZycnLsJuV4XMqROkm6cOGC3SxYsMBuzp07ZzcTJ060m97eXruRpJ6eHrs5efLkZ/I548aNs5vx48fbjSQ9++yzdjNixAi7mTRpkt3U1NTYjZR2JLGtrc1uNmzYMORreFIAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAITcbF946NAh+83r6+vtpq+vz25SPyvlwmVTU5PdVFZW2s1XvvIVu5Gkd955x25SrkF++ctftpudO3fajSRlecj3U8aOHWs3U6dOtZuUi51FRUV2I0mtra12s2bNGrt59dVX7aalpcVuFi1aZDeS1NjYaDcvvPCC3RQXF9tNbm7W/6R+Ssp16JTrwdngSQEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAACEnEyW18ZSDqD19PTYzY033mg3klRRUWE3Z8+etZu8vDy7STnyd/78ebuRpA8++MBu5syZYzcffvih3aQcqZOkuXPn2s0//vEPu5kyZYrdtLW12c0PfvADu5HS/k5dXV12U1BQYDeDg4N2U1ZWZjeSNHLkSLuprq62m46ODrt544037EaSFi9ebDcpBy
Yff/zxIV/DkwIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIWR/EW758uf3
2024-07-30 00:44:16 +08:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Preview raw noise: a 28x28 grid of N(0, 1) samples rendered as an image.\n",
"img = get_noise(n_samples=28, z_dim=28)\n",
"noise_fig, noise_ax = plt.subplots()\n",
"noise_ax.imshow(img, cmap='gray')\n",
"noise_ax.axis('off')\n",
"plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 376,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Build one discriminator block: Linear followed by LeakyReLU(0.2).\"\"\"\n",
"    layers = [nn.Linear(in_dim, out_dim), nn.LeakyReLU(negative_slope=0.2)]\n",
"    return nn.Sequential(*layers)"
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 377,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"    \"\"\"MLP discriminator that scores flattened images.\n",
"\n",
"    Two DisBlocks map the features from `im_dim` to 2*hidden_dim and then to\n",
"    hidden_dim; a Linear layer reduces to one value and a Sigmoid turns it into\n",
"    a probability in (0, 1) (i.e. the output is NOT a logit).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, im_dim: int = 784, hidden_dim: int = 128):\n",
"        super().__init__()\n",
"        self.disc = nn.Sequential(\n",
"            DisBlock(im_dim, 2 * hidden_dim),\n",
"            DisBlock(2 * hidden_dim, hidden_dim),\n",
"            nn.Linear(hidden_dim, 1),\n",
"            nn.Sigmoid(),\n",
"        )\n",
"\n",
"    def __call__(self, noise):\n",
"        # `noise` is the batch being scored (real or generated images);\n",
"        # presumably shaped (batch, im_dim) -- TODO confirm with callers.\n",
"        return self.disc(noise)"
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 378,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-30 07:06:52 +08:00
" (layers.0): Linear(input_dims=784, output_dims=256, bias=True)\n",
2024-07-26 21:07:40 +08:00
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.1): Sequential(\n",
" (layers.0): Linear(input_dims=256, output_dims=128, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
2024-07-30 07:06:52 +08:00
" (layers.2): Linear(input_dims=128, output_dims=1, bias=True)\n",
" (layers.3): Sigmoid()\n",
2024-07-26 21:07:40 +08:00
" )\n",
")"
]
},
2024-07-30 07:06:52 +08:00
"execution_count": 378,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Build a discriminator with the default sizes and display its layer structure.\n",
"disc = Discriminator(im_dim=784, hidden_dim=128)\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
2024-07-30 00:44:16 +08:00
"cell_type": "markdown",
2024-07-26 21:07:40 +08:00
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### Losses"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"#### Discriminator Loss"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 379,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
"    \"\"\"Discriminator loss: average BCE over fake (label 0) and real (label 1) images.\n",
"\n",
"    Args:\n",
"        gen: generator used to produce the fake batch.\n",
"        disc: discriminator being trained; its output passes through a final Sigmoid.\n",
"        real: batch of real images.\n",
"        num_images: number of fake images to generate.\n",
"        z_dim: noise dimension fed to the generator.\n",
"\n",
"    Returns:\n",
"        Scalar loss array: (fake_loss + real_loss) / 2.\n",
"    \"\"\"\n",
"    fake_images = gen(get_noise(num_images, z_dim))\n",
"    fake_pred = disc(fake_images)\n",
"    real_pred = disc(real)\n",
"\n",
"    # The discriminator already ends in a Sigmoid, so its outputs are probabilities.\n",
"    # nn.losses.binary_cross_entropy defaults to with_logits=True, which would apply\n",
"    # a second sigmoid and flatten the gradients; declare the inputs as probabilities.\n",
"    fake_loss = nn.losses.binary_cross_entropy(\n",
"        fake_pred, mx.zeros((fake_images.shape[0], 1)), with_logits=False\n",
"    )\n",
"    real_loss = nn.losses.binary_cross_entropy(\n",
"        real_pred, mx.ones((real.shape[0], 1)), with_logits=False\n",
"    )\n",
"\n",
"    return (fake_loss + real_loss) / 2"
]
},
2024-07-30 00:44:16 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generator Loss"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 380,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def gen_loss(gen, disc, num_images, z_dim):\n",
"    \"\"\"Generator loss: BCE that rewards the discriminator labelling fakes as real (1).\n",
"\n",
"    Args:\n",
"        gen: generator being trained.\n",
"        disc: discriminator; its output passes through a final Sigmoid.\n",
"        num_images: number of fake images to generate.\n",
"        z_dim: noise dimension fed to the generator.\n",
"\n",
"    Returns:\n",
"        Scalar loss array.\n",
"    \"\"\"\n",
"    fake_images = gen(get_noise(num_images, z_dim))\n",
"    fake_pred = disc(fake_images)\n",
"\n",
"    # The discriminator outputs Sigmoid probabilities, not logits, so disable the\n",
"    # default with_logits=True to avoid applying a second sigmoid.\n",
"    return nn.losses.binary_cross_entropy(\n",
"        fake_pred, mx.ones((fake_images.shape[0], 1)), with_logits=False\n",
"    )"
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 381,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"# Keep only the training images; labels and the test split are unused here.\n",
"# Binding `_` is avoided on purpose: it shadows IPython's last-output variable.\n",
"train_images, *_unused_splits = mnist.mnist()\n",
"train_images = np.array(train_images)\n",
"del _unused_splits"
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 382,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"# Rescale pixels (assumed loaded in [0, 1]) to [-1, 1], matching the generator's tanh range.\n",
"# Guarded so that re-running this cell does not rescale an already-rescaled array.\n",
"if train_images.min() >= 0.0:\n",
"    train_images = train_images * 2.0 - 1.0\n",
"assert train_images.min() >= -1.0 and train_images.max() <= 1.0, 'expected pixels in [-1, 1]'"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 383,
2024-07-29 06:24:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-07-30 07:06:52 +08:00
"<matplotlib.image.AxesImage at 0x157b0bb80>"
2024-07-29 06:24:50 +08:00
]
},
2024-07-30 07:06:52 +08:00
"execution_count": 383,
2024-07-29 06:24:50 +08:00
"metadata": {},
"output_type": "execute_result"
2024-07-30 00:44:16 +08:00
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaI0lEQVR4nO3df2jU9x3H8dfVH1d1lytBk7vUmGVF202dpWrVYP3R1cxApf4oWMtGZEPa+YOJ/cGsDNNBjdgpRdI6V0amW239Y9a6KdUMTXRkijpdRYtYjDOdCcFM72LUSMxnf4hHz1j1e975vkueD/iCufu+vY/ffuvTby75xueccwIAwMBD1gsAAHRfRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjpab2AW3V0dOjcuXMKBALy+XzWywEAeOScU0tLi/Ly8vTQQ3e+1km7CJ07d075+fnWywAA3Kf6+noNHDjwjvuk3afjAoGA9RIAAElwL3+fpyxCH3zwgQoLC/Xwww9r5MiR2rdv3z3N8Sk4AOga7uXv85REaPPmzVq8eLGWLVumI0eO6JlnnlFJSYnOnj2bipcDAGQoXyruoj1mzBg99dRTWrduXeyx73//+5o+fbrKy8vvOBuNRhUMBpO9JADAAxaJRJSVlXXHfZJ+JXTt2jUdPnxYxcXFcY8XFxertra20/5tbW2KRqNxGwCge0h6hM6fP6/r168rNzc37vHc3Fw1NjZ22r+8vFzBYDC28ZVxANB9pOwLE259Q8o5d9s3qZYuXapIJBLb6uvrU7UkAECaSfr3CfXv3189evTodNXT1NTU6epIkvx+v/x+f7KXAQDIAEm/Eurdu7dGjhypqqqquMerqqpUVFSU7JcDAGSwlNwxYcmSJfrpT3+qUaNGady4cfr973+vs2fP6tVXX03FywEAMlRKIjR79mw1NzfrN7/5jRoaGjRs2DDt2LFDBQUFqXg5AECGSsn3Cd0Pvk8IALoGk+8TAgDgXhEhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmelovAEgnPXr08DwTDAZTsJLkWLhwYUJzffv29Tzz+OOPe55ZsGCB55nf/va3nmfmzJnjeUaSrl696nlm5cqVnmfefvttzzNdBVdCAAAzRAgAYCbpESorK5PP54vbQqFQsl8GANAFpOQ9oaFDh+rvf/977ONEPs8OAOj6UhKhnj17cvUDALirlLwndOrUKeXl5amwsFAvvfSSTp8+/a37trW1KRqNxm0AgO4h6REaM2aMNm7cqJ07d+rDDz9UY2OjioqK1NzcfNv9y8vLFQwGY1t+fn6ylwQASFNJj1BJSYlmzZql4cOH67nnntP27dslSRs2bLjt/kuXLlUkEolt9fX1yV4SACBNpfybVfv166fhw4fr1KlTt33e7/fL7/enehkAgDSU8u8Tamtr05dffqlwOJzqlwIAZJikR+j1119XTU2N6urqdODAAb344ouKRqMqLS1N9ksBADJc0j8d9/XXX2vOnDk6f/68BgwYoLFjx2r//v0qKChI9ksBADJc0iP0ySefJPu3RJoaNGiQ55nevXt7nikqKvI8M378eM8zkvTII494npk1a1ZCr9XVfP31155n1q5d63lmxowZnmdaWlo8z0jSv//9b88zNTU1Cb1Wd8W94wAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQ
DMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAMz7nnLNexDdFo1EFg0HrZXQrTz75ZEJzu3fv9jzDf9vM0NHR4XnmZz/7meeZS5cueZ5JRENDQ0JzFy5c8Dxz8uTJhF6rK4pEIsrKyrrjPlwJAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwExP6wXA3tmzZxOaa25u9jzDXbRvOHDggOeZixcvep6ZPHmy5xlJunbtmueZP/3pTwm9Fro3roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT63//+l9DcG2+84Xnm+eef9zxz5MgRzzNr1671PJOoo0ePep6ZMmWK55nW1lbPM0OHDvU8I0m//OUvE5oDvOJKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAw43POOetFfFM0GlUwGLReBlIkKyvL80xLS4vnmfXr13uekaSf//znnmd+8pOfeJ75+OOPPc8AmSYSidz1/3muhAAAZogQAMCM5wjt3btX06ZNU15ennw+n7Zu3Rr3vHNOZWVlysvLU58+fTRp0iQdP348WesFAHQhniPU2tqqESNGqKKi4rbPr1q1SmvWrFFFRYUOHjyoUCikKVOmJPR5fQBA1+b5J6uWlJSopKTkts855/Tee+9p2bJlmjlzpiRpw4YNys3N1aZNm/TKK6/c32oBAF1KUt8TqqurU2Njo4qLi2OP+f1+TZw4UbW1tbedaWtrUzQajdsAAN1DUiPU2NgoScrNzY17PDc3N/bcrcrLyxUMBmNbfn5+MpcEAEhjKfnqOJ/PF/exc67TYzctXbpUkUgkttXX16diSQCANOT5PaE7CYVCkm5cEYXD4djjTU1Nna6ObvL7/fL7/clcBgAgQyT1SqiwsFChUEhVVVWxx65du6aamhoVFRUl86UAAF2A5yuhS5cu6auvvop9XFdXp6NHjyo7O1uDBg3S4sWLtWLFCg0ePFiDBw/WihUr1LdvX7388stJXTgAIPN5jtChQ4c0efLk2MdLliyRJJWWluqPf/yj3nzzTV25ckXz58/XhQsXNGbMGO3atUuBQCB5qwYAdAncwBRd0rvvvpvQ3M1/VHlRU1Pjeea5557zPNPR0eF5BrDEDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNrqkfv36JTT317/+1fPMxIkTPc+UlJR4ntm1a5fnGcASd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzAFvuGxxx7zPPOvf/3L88zFixc9z+zZs8fzzKFDhzzPSNL777/veSbN/ipBGuAGpgCAtEaEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpsB9mjFjhueZyspKzzOBQMDzTKLeeustzzMbN270PNPQ0OB5BpmDG5gCANIaEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gCBoYNG+Z5Zs2aNZ5nfvSjH3meSdT69es9z7zzzjueZ/773/96noENbmAKAEhrRAgAYMZzhPbu3atp06YpLy9PPp9PW7dujXt+7ty58vl8cdvYsWOTtV4AQBfiOUKtra0aMWKEKioqvnWfqVOnqqGhIbbt2LHjvhYJAOiaenodKCkpUUlJyR338fv9CoVCCS8KANA9pOQ9oerqauXk5GjIkCGaN2+empqavnXftrY2RaPRuA0A0D0kPUIlJSX66KOPtHv3bq1evVoHDx7Us88+q7
a2ttvuX15ermAwGNvy8/OTvSQAQJry/Om4u5k9e3bs18OGDdOoUaNUUFCg7du3a+bMmZ32X7p0qZYsWRL7OBqNEiI
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
2024-07-29 06:24:50 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"# Sanity check: display the first (rescaled) training image.\n",
"plt.imshow(train_images[0].reshape(28, 28), cmap='gray')\n",
"plt.axis('off')\n",
"plt.show()"
2024-07-29 06:24:50 +08:00
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 384,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-30 00:44:16 +08:00
"def batch_iterate(batch_size: int, ipt: np.ndarray):\n",
"    \"\"\"Yield shuffled mini-batches of `ipt` along its first axis.\n",
"\n",
"    A fresh random permutation (np.random) is drawn on every call. When\n",
"    len(ipt) is not a multiple of batch_size, the final batch is smaller.\n",
"\n",
"    Args:\n",
"        batch_size: maximum number of rows per batch.\n",
"        ipt: array supporting integer-array (fancy) indexing, e.g. a NumPy array.\n",
"\n",
"    Yields:\n",
"        Sub-arrays of `ipt` with at most `batch_size` rows.\n",
"    \"\"\"\n",
"    perm = np.random.permutation(len(ipt))\n",
"    for start in range(0, len(ipt), batch_size):\n",
"        yield ipt[perm[start : start + batch_size]]"
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 385,
2024-07-30 00:44:16 +08:00
"metadata": {},
"outputs": [],
"source": [
"def show_images(imgs, num_imgs: int = 25):\n",
"    \"\"\"Plot the first `num_imgs` flattened 28x28 images on a square grid.\n",
"\n",
"    Args:\n",
"        imgs: batch with a leading batch axis; each item must reshape to 28x28.\n",
"        num_imgs: how many images to show, clipped to the batch size. The grid is\n",
"            ceil(sqrt(n)) cells per side, so the default of 25 gives a 5x5 grid.\n",
"    \"\"\"\n",
"    n = min(num_imgs, imgs.shape[0])\n",
"    if n <= 0:\n",
"        return\n",
"    side = int(np.ceil(np.sqrt(n)))\n",
"    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.\n",
"    fig, axes = plt.subplots(side, side, figsize=(side, side), squeeze=False)\n",
"    for i, ax in enumerate(axes.flat):\n",
"        if i < n:\n",
"            ax.imshow(mx.array(imgs[i]).reshape(28, 28), cmap='gray')\n",
"        ax.axis('off')\n",
"    plt.show()"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-30 00:44:16 +08:00
"### Show the first batch of images"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 386,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-30 07:06:52 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADTdUlEQVR4nOz955OeV3rfiX+enHMOnRO6kcEIkhhyyBnOSBoFy2NZtstBsvfV/j2qcrlcu+utqZUsrbWWxjMSR8OciUSkjugcnpxz/r3A7xx2I5Ag2UB3P31/qrpAAt3A3afPfa5zpe+l6vV6PRQUFBQUFPYR9UE/gIKCgoJC/6EYFwUFBQWFfUcxLgoKCgoK+45iXBQUFBQU9h3FuCgoKCgo7DuKcVFQUFBQ2HcU46KgoKCgsO8oxkVBQUFBYd9RjIuCgoKCwr6jfdxPVKlUT/I5jhTfRdRAWb+vUNbv+/FdRTWUNfwKZQ9+Px5n/RTPRUFBQUFh31GMi4KCgoLCvvPYYTEFBYUH0Wg0qFQqAoEAQ0ND6PV6rFYrGo2G5eVlNjY2aLVa1Ov17xzOUlA4iijGRUHhO6JSqdDpdOh0Os6fP8+/+lf/CrfbzejoKEajkf/z//w/+au/+itKpRKJRIJ2u33Qj6yg8NRQjIuCwvdAr9djMplwOBwEg0E8Hg/BYBCDwYDNZkOr1UrvRkFhv1CpVKjValQqFQaDAa1Wi1arRafT0ev1aDQatNttms0mzWbzQJ5RMS4KCt8RjUZDOBwmGAxy4sQJxsfHsdlsGI1Gut3uQT+eQh+j0+mw2+2YTCZmZmbkPhwZGaFWq3Hr1i0ymQxzc3MsLi7S6/WeelhWMS4KCt8RtVqNw+HA7/fj9Xpxu92YTCZ6vR6dTudAXuijwMO8OGWdHh+VSoVWq8VsNmOz2RgaGmJiYoKRkRHOnDlDqVSi0+mwvb1NLBZDrVbT7XYV46KgcNgxGo34fD5sNhsXLlxgamqKkZERer0e2WyWK1eukEgkuHr1Krlcjlqtdqw9GY1Gg9PpxGg0EgqFCIVCGAwGnE4nvV6Pzc1NstksqVSKzc3NY71WD0On0+H3+zGZTHg8HtxuN06nk4mJCaxWK8PDw/Jy43Q6MZlMXLhwgZGREXZ2drhz5w6tVotGo/FUDYxiXBQUviVms5nR0VG8Xi8vvvgi58+fx2az0ev1SKVS/N3f/R137txhc3OTVCpFt9ul0+kc9GMfGDqdDp/Ph9Pp5IUXXuDZZ5/F6XQyPDwMwAcffMDi4iK3b98mFosdWI7gsKLX6xkeHsbn8zE5OcnExATBYJAXX3wRq9WKSqXa8wHg9/tpNpvcuHGDd955h3q9TrPZVIyLgsJhRqfT4XA4cLlc2O12rFYrALlcjmw2Kz+Ex3JcQz42mw2fz4fVamVyclJW0gWDQSwWC2azmV6vRygUot1uU6lUWF9fp1KpkMlkjr2R0ev1mM1maYjD4TBDQ0MEg0Hcbje9Xo92u41er0ej0ez5Wo1Gg1arlUn/g0AxLgoK3xKr1crU1BSRSITh4WGCwSBbW1vMz8+zvr7O4uIiq6urtFqtY+2xTE5O8kd/9Ef4fD7Onz+P1+vFbDZjMpnodDpUq1UAXnjhBS5evMj09DSBQICdnR3eeustdnZ2Dvg7OFg8Hg9TU1OEQiH+5E/+hPHxcaxWKxaLhWazSTqdlsbZbrcf9OM+gGJcFBQeExF2EJU6Ir6t0+lotVpkMhkymQzlcplarXbQj3tgiPJrl8vF8PCwbDD1eDx0Oh1arRbtdls2lvp8Pux2O8FgkHA4LG/jxx2DwYDH48Hv9xMOh4lGo+h0OrRaLaVSiXq9TqfToVarYTAYZHmyKCRpNpu02+0D85wV46Kg8JhYrVZsNhvDw8OcOXOGcDiMTqcjnU4zPz/Pb37zG1KpFLlc7qAf9cDQ6XScPXuW4eFhLly4wLlz5zAajSQSCTY3N5mbm2Nubo5arUYul0
Ov1/O7v/u7ssrJZrNhtVrRao/v0aTX69FqtQwPD/Paa68RDAbx+/3o9Xqq1SrVapVSqUQymaRer7O6ukq73cZqteJyueh2uxQKBWq1GgsLC1QqlQMxMsf3J6ig8C0xmUy43W78fj8jIyOEQiEqlQrFYpHNzU2uXbtGoVCgXC4f9KMeGBqNhrGxMZ5//nlmZmaYmJig1WrJCrr333+f3/72t1SrVXK5HGazmUgkQiAQoFqtYjabsVgsqNXHU/ZQlBkbDAYCgQDnzp3D5/PhcrnQ6XQ0m03y+TylUol8Pk+lUmFjY4NMJoPX62VgYIB2u00sFqNcLrO5uUmj0ZCl8U+TvjEuJpMJu92OxWJhYmICm80mFzSZTHL9+vVjFaoQG1S4yg9DpVLJsI5er8doNMrNvftrzGYzXq8XjUZDtVql1WqxtbXF8vLygbrdTxOVSsXAwAAXLlxgcnJSdt+LlzybzVKtVo9t2bHBYMDtdmOz2RgfH2dychK73U4sFqNYLHL16lU2NjZYWVmhUqnQbDblOjUaDbluBoMBo9F4LI2LRqNBp9MxOTnJwMAAp06dwuPxYDAYSKVSdDod7t69y/LyMuVymXg8Tr1eJ51Oy/+PxWJ0u11yuZz8s4MqKukb4+JyuZiYmGBwcJA///M/Z3x8nFarRbPZ5MMPP2Rtbe1YGRej0Yjb7ZayEA+rGNFoNAQCARwOh2wGFM1ZOp1Ofl4oFOL555/HYDAQj8cpFov8+te/5v/6v/4varXasTAwarWa8+fP8+/+3b/D7XYTCAQASKfTUqBSGJh+X4uHYbFYmJmZwe/388orr/Dqq68Si8WYm5tjc3OTv/7rv2Zubk6+kyIv0O1293gxTqeTfD5/7MJiarUanU6H1WrljTfe4LXXXiMcDjMwMECz2eTq1avS8/vwww+p1+sUCgXa7Tbdbpdut4tarZZGWRiUgywqOTQ/QbEwBoMBnU5Hp9ORpYg6nU6W1IlfxYEpbtki4RWJRPD7/fh8PlnbbbPZ+uYmJLSERKmhXq9/qHdit9ulsbjfExFoNBr8fj82mw2Hw4HP50Oj0WCxWPa83IFAAJ/Ph8FgoN1uo9Pp5M29X9b1UYj11uv1OBwO3G43FouFdrtNq9Uim82SSCTki34cDQt8JUfidruxWq2YTCZarRbb29tsb2+TzWYplUoP/dpms0mtVkOv10svWhQFHIdSbrHHRC9QIBCQOZZcLkelUmF7e5t4PE4ikZBl2uVy+VB7yYfCuGi1WqxWK0ajkenpaaLRKOl0mtXVVbRarayLN5lMGI1GTCYTfr9fVlOYzWbZd2C1WolEImg0GllJ8bQ7U58kBoOB8fFxPB6PvNnslnkXuFwuIpGIfFnFi7obYczFi7w7LCY8nV6vh9FoxGKxoFKpZNOWzWbDbDbT6XTodDp9q/hrNBqZmJiQPRoej4dWq8Xq6ir5fJ7f/va3XL58mXQ6TavVOujHPTCsVisnT55kcHAQr9cLwNLSEv/tv/03UqkUsVjsoV/X6XRIJpOsrKwwOjrK+Pg4LpdLflQqlb6OOIhL4sjICH/yJ39CNBrl7NmzDA0Ncfv2bf6f/+f/IZvNcufOHdLpNPl8nmKxKL2Vw8yhMC4ajQaTySSTexMTE5hMJorFIhqNhmg0KvMpomJnYGAAk8lEOBzG6XRiMBgwm80ybin0dA66HG8/EAe9KIP1er0Eg0EZ2zYajTidzj3ehsfjYWBgAK1W+0AM+5vWYncIbffndrtddDodBoNhjxJrPyv+ajQa3G43wWBQSpi0Wi3y+TzpdJrNzU2Wlpb25BCOI6IL3+/3YzQapRTOnTt3vrZ6rtfrUa1WyefztFotmfsTH41G4yl+F08fcV45nU5Onz4tC0XsdjuVSoWbN2+SSCRYWlqiUCgc9ON+Kw7UuNhsNmw2G6FQiEuXLuHxeBgdHSUQCFAsFrlw4QIqlQq32y0PNL1ej16vx263y1+NRqOUmxa9CJ1Oh42NDW
ZnZ5mfnz+Sm1TkRNxuN16vl7GxMWw2m7xJezwevF7vQw2I2WyWHsv9h/9uYcX7Da8wYCL8uPtrhfuuUqn2vPz9bFw
2024-07-26 21:07:40 +08:00
"text/plain": [
2024-07-29 06:24:50 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-26 21:07:40 +08:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
2024-07-30 00:44:16 +08:00
"# Draw one shuffled batch of 25 training images and plot it.\n",
"show_images(next(batch_iterate(25, train_images)))"
]
},
2024-07-30 07:06:52 +08:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training Cycle"
]
},
2024-07-26 21:07:40 +08:00
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 391,
2024-07-27 06:09:51 +08:00
"metadata": {},
2024-07-29 06:24:50 +08:00
"outputs": [],
2024-07-27 06:09:51 +08:00
"source": [
2024-07-30 07:06:52 +08:00
"lr = 0.002\n",
"z_dim = 100\n",
"seed = 0\n",
"\n",
"# Seed both RNGs before building the models: MLX's global RNG drives get_noise\n",
"# (and presumably weight initialisation); NumPy drives batch_iterate's shuffling.\n",
"mx.random.seed(seed)\n",
"np.random.seed(seed)\n",
"\n",
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
"gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])\n",
"\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
"disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])"
2024-07-29 06:30:08 +08:00
]
},
2024-07-27 06:09:51 +08:00
{
"cell_type": "code",
2024-07-30 07:06:52 +08:00
"execution_count": 395,
2024-07-27 06:09:51 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-30 07:06:52 +08:00
" 4%|▍ | 9/200 [00:59<21:10, 6.65s/it]"
2024-07-29 00:18:35 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"source": [
2024-07-30 00:44:16 +08:00
"# Training parameters\n",
"n_epochs = 200\n",
"display_step = 5000\n",
"cur_step = 0\n",
"\n",
"batch_size = 128\n",
"\n",
"# Each wrapped loss returns (value, grads w.r.t. the given model's trainable params).\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"for epoch in tqdm(range(n_epochs)):\n",
"\n",
"    for real in batch_iterate(batch_size, train_images):\n",
"        real_batch = mx.array(real)\n",
"        # The last batch of an epoch can be smaller than batch_size; generate the\n",
"        # same number of fakes so the real and fake terms stay balanced.\n",
"        n_real = real_batch.shape[0]\n",
"\n",
"        # Discriminator step: loss + gradients, apply the Adam update, then force\n",
"        # evaluation of the lazily computed parameters and optimizer state.\n",
"        D_loss, D_grads = D_loss_grad(gen, disc, real_batch, n_real, z_dim)\n",
"        disc_opt.update(disc, D_grads)\n",
"        mx.eval(disc.parameters(), disc_opt.state)\n",
"\n",
"        # Generator step, same pattern.\n",
"        G_loss, G_grads = G_loss_grad(gen, disc, n_real, z_dim)\n",
"        gen_opt.update(gen, G_grads)\n",
"        mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
"        if (cur_step + 1) % display_step == 0:\n",
"            print(f'Epoch {epoch}, step {cur_step + 1}: generator loss: {G_loss.item():.4f}, discriminator loss: {D_loss.item():.4f}')\n",
"            fake = gen(get_noise(batch_size, z_dim))\n",
"            show_images(fake)\n",
"            show_images(real_batch)\n",
"        cur_step += 1"
2024-07-27 05:19:08 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}