mlx-examples/gan/playground.ipynb
2024-07-30 02:37:09 +03:00

554 lines
106 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Libraries"
]
},
{
"cell_type": "code",
"execution_count": 573,
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
"execution_count": 574,
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 575,
"metadata": {},
"outputs": [],
"source": [
"# mx.set_default_device(mx.gpu)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
"execution_count": 576,
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Generator building block: Linear -> BatchNorm -> LeakyReLU(0.2).\n",
"\n",
"    NOTE(review): the second positional argument of `nn.BatchNorm` is `eps`,\n",
"    so this block uses eps=0.8 (also visible in the printed model repr).\n",
"    That is unusually large; 0.8 may have been meant as momentum -- confirm.\n",
"    \"\"\"\n",
"    layers = [\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.BatchNorm(out_dim, eps=0.8),\n",
"        nn.LeakyReLU(negative_slope=0.2),\n",
"    ]\n",
"    return nn.Sequential(*layers)"
]
},
{
"cell_type": "code",
"execution_count": 577,
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"    \"\"\"MLP generator mapping noise vectors of width `z_dim` to flat images.\n",
"\n",
"    Four GenBlocks widen the features z_dim -> hidden_dim -> 2x -> 4x -> 8x,\n",
"    a final Linear projects to `im_dim`, and tanh squashes outputs to [-1, 1].\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, z_dim: int = 10, im_dim: int = 784, hidden_dim: int = 256):\n",
"        super().__init__()\n",
"        # Layer widths: z_dim, then hidden_dim scaled by 1, 2, 4 and 8.\n",
"        widths = [z_dim] + [hidden_dim * scale for scale in (1, 2, 4, 8)]\n",
"        blocks = [GenBlock(d_in, d_out) for d_in, d_out in zip(widths, widths[1:])]\n",
"        self.gen = nn.Sequential(*blocks, nn.Linear(widths[-1], im_dim))\n",
"\n",
"    def __call__(self, noise):\n",
"        logits = self.gen(noise)\n",
"        return mx.tanh(logits)"
]
},
{
"cell_type": "code",
"execution_count": 578,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
" (layers.0): Linear(input_dims=100, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
" )\n",
" (layers.1): Sequential(\n",
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
" (layers.1): BatchNorm(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
" )\n",
" (layers.2): Sequential(\n",
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
" (layers.1): BatchNorm(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
" )\n",
" (layers.3): Sequential(\n",
" (layers.0): Linear(input_dims=1024, output_dims=2048, bias=True)\n",
" (layers.1): BatchNorm(2048, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
" )\n",
" (layers.4): Linear(input_dims=2048, output_dims=784, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 578,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Instantiate with a 100-dimensional noise input and display the layer summary.\n",
"gen = Generator(z_dim=100)\n",
"gen"
]
},
{
"cell_type": "code",
"execution_count": 579,
"metadata": {},
"outputs": [],
"source": [
"def get_noise(n_samples: int, z_dim: int) -> mx.array:\n",
"    \"\"\"Sample a batch of standard-normal noise vectors.\n",
"\n",
"    Args:\n",
"        n_samples: number of noise vectors (rows).\n",
"        z_dim: length of each noise vector (columns).\n",
"\n",
"    Returns:\n",
"        An `mx.array` of shape `(n_samples, z_dim)` drawn from N(0, 1).\n",
"    \"\"\"\n",
"    return mx.random.normal(shape=(n_samples, z_dim))"
]
},
{
"cell_type": "code",
"execution_count": 580,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWPUlEQVR4nO3ca2zW9d3H8U8VBqUW5NDSzohFOpBCy6GcJgcngm7iWGEZBBPNItlWt+yQxejiRjIzD8nAxI0tSzYNE+dp2YA4poIgWhyUldUhWg4ddIUiLeVYoBxaet3Pvske9fr8Huy+c+f9eny9/5eDls/+T745mUwmIwAAJF33v/0fAAD4v4NRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQOiT7Qefe+45++H79++3m7a2NruRpPvuu89u/va3v9nNHXfcYTdnz561m507d9qNJBUXF9vNJ598YjfTp0+3mylTptiNJH344Yd209LSYjcXLlywm9LSUrtJVVJSYjdvvvmm3Tz22GN284Mf/MBuHnjgAbuRpC1bttjN7bffbjfvvvuu3aT8HUnSvn377GbmzJl28+STT/b6Gd4UAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQMjJZDKZbD5499132w+fNGmS3cyYMcNuJGnTpk12k3KobtCgQXaTl5dnNynH2SSps7PTbm677Ta7yfLH5j/MmjXLbiTphRdesJuUY2vr1q2zm5SjZCmHIqW0n9e9e/fazb333ms3Q4cOtZvCwkK7kdJ+N2pra+3m7bfftptvf/vbdiOl/T796Ec/spuurq5eP8ObAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAh9sv1gdXW1/fCdO3fazTvvvGM3ktTa2mo3S5YssZuGhga7ue46f3ubmprsRko7ZnblypX/SpNylEySRowYYTcpf0/Hjx+3m8OHD9vNsWPH7EaSvva1r9nN+fPn7ebMmTN2k3IQb+PGjXYjpf0bsWHDBru5du2a3eTn59uNJPXpk/U/xWHp0qVJ39Ub3hQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAACHr03yvvfaa/fABAwbYzfjx4+1Gkj766CO7aWtrs5v29na7GTNmjN2kSrleunv3brvZtWuX3fz0pz+1G0maOHGi3ezZs8duli1bZjcpP0M1NTV2I0l5eXl2U1JSYjeVlZV209LSYjfjxo2zG0kqLy+3m5UrV9pNZ2en3Vy6dMluJKmwsNBuDhw4kPRdveFNAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAISsD+L16ZP1R8OkSZPs5umnn7YbSVqxYoXdnDhxwm7OnTtnNymH1lIO70nSli1b7GbJkiV2M2TIELvp27ev3UjS73//e7spKyuzmzVr1thNV1eX3Tz66KN2I6UdOxw5cqTdzJgxw242bdpkNwcPHrQbSdqxY4fdNDU12c3YsWPt5p///KfdSFJra6vdfP/730/6rt7wpgAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAABCTiaTyWTzwXfffdd+eENDg92kHIaSpPLycrt588037WbKlCl2M2jQILsZOHCg3aR+V8rBuZ6eHrs5c+aM3UhSUVGR3fTv399uRo8ebTdvvfWW3eTl5dmNlPbfl3LIsrS01G6Ki4vtJuXvVZJefvnl/8p3ffzxx3bT2dlpN5J0yy232M0HH3xgN9u3b+/1M7wpAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgJD1tayVK1faD583b57d9O3b12
4k6ejRo3aTm5trN42NjXZTVVVlN21tbXYjSb/+9a/tZunSpXZz+fJlu6mvr7cbKe3IWElJid2k/Ox961vfspvNmzfbjZR2hPDkyZN2M23aNLv5xz/+YTeXLl2yGyntd6O6utpuCgsL7eajjz6yGyntZ+KJJ55I+q7e8KYAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAQtYH8WbMmGE/vLW11W5ycnLsRko70LZr1y67+fnPf2433d3ddnPt2jW7kaSJEyfaTcqRv5QjeqkH0E6cOGE3/fr1s5u8vDy7eeaZZ+xm/vz5diNJZWVldrNx40a7OXDggN2kOH36dFI3atQou/nlL39pNxcvXrSbW265xW4kaezYsXZz5cqVpO/qDW8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAICQk8lkMtl88Jvf/Kb98N/97nd289prr9mNlHbx9OTJk3Zz00032U3K5ddt27bZjSQ9/PDDdrN161a7KSwstJuKigq7kaSioiK7OXLkiN3k5ubaTX19vd0MGTLEbiTp/PnzdpPy3zd79my7Sfk7euWVV+xGkrL8J+s/jB8/3m5SLhXff//9diOl/d6m2Lt3b6+f4U0BABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAAhD7ZfrCkpMR++JYtW+wm5TibJE2aNMluUo7OHTp0yG6WLVtmN8XFxXYjSbfeeqvdpByPW7Vqld2MHDnSbiRp+vTpdnPmzBm7aWhosJuUQ5G7d++2G0lasGCB3bS1tdlNZ2en3aT8eS9atMhuJGnTpk12k/J7cfHiRbtJ+fdBkh5//HG7uXTpUtJ39YY3BQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABCyPohXUFBgP7y9vd1uxo4dazeSNHDgQLs5deqU3VRWVtrN0aNH7aalpcVuJKmxsdFuvvCFL9jNN77xDbuZMGGC3UjSH/7wB7uZM2eO3Rw4cMBuTpw4YTd5eXl2I0m1tbV2U1paajcVFRV2k8lk7Gb//v12I0mzZs2ym66uLrtJOQx43XVp/z/7ySeftJuysjK7eeihh3r9DG8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIORksrxkVVxcbD98wYIFdnP99dfbjSTt2bPHbk6ePGk3999/v928+OKLdrNmzRq7kdL+/FK+6+rVq3YzefJku5Gk8+fP282xY8fsZty4cXZTVVVlNytXrrQbSSoqKrKbV1991W7uueceuzl9+rTdDB482G4k6cc//rHdbNiwwW6OHz9uN6nHDlOOh6YcKX3kkUd6/QxvCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkPWV1Jdeesl+eF1dnd20tLTYjSTl5+fbzdSpU+3mhRdesJs777zTboYNG2Y3kjRlyhS7uXjxot3U1NTYTaqSkhK7SbmA+/bbb9vNo48+ajeffvqp3Uhpf0+5ubl209raajeLFy+2m9raWruR0i4Bf/DBB3Zz9OhRuykvL7cbSZozZ47d1NfX282qVat6/QxvCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACD0yfqDfbL+aBg9erTd3HTTTXYjSf3797ebfv362c2SJUvspr293W5SjsBJaQcFjxw5YjfLly+3m+rqaruRpNmzZ9tNc3Oz3XzpS1+ym3379tlNR0eH3UjSXXfdZTcvv/yy3fT09NjNpUuX7Cb1MOChQ4fsZtSoUXbzuc99zm5uuOEGu5GkLVu22M3kyZOTvqs3vC
kAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkJPJZDLZfPCHP/yh/fAJEybYTVFRkd1I0j333GM3zzzzjN28+OKLdlNQUGA3S5cutRtJ2r9/v92kHEAbNmyY3YwfP95uJOngwYN2M3ToULupq6uzm5Q/h+7ubruRpI8//thuRowYYTcpR91Sjkvu3bvXbqS045cpxw7PnDljN+vXr7cbSaqqqrKbq1ev2s2KFSt6/QxvCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACBkfRBv7ty59sMfeOABu0k58iRJf/7zn+1m2rRpdpNyaK25udluOjo67EaSFi1aZDdPP/203ZSWltpNytE0SZozZ47dpPz5pRy3e//99+0m5edOkhoaGuwm5bjdhg0b7OaOO+6wm0GDBtmNJM2YMcNu9u3bZzeNjY12U1ZWZjeStH37dru577777ObBBx/s9TO8KQAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDQJ9sPTpw40X74gQMH7ObcuXN2I0mjR4+2m5SjafPmzbObG2+80W6Ki4vtRpKamprs5oknnrCblKNphYWFdiNJPT09dvPGG2/YzeLFi+2moqLCblKON0pSbm6u3dx88812U1lZaTfbtm2zm9tuu81uJCknJ8duBgwYYDfvvfee3Zw+fdpupLQDk93d3Unf1RveFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAIesrqWfPnvUf3ifrx4cRI0bYTep3ffazn7WbX/3qV3Yzbdo0u7l69ardSFJ9fb3d7Ny5025aW1vtJuXCrJR2MTalGTJkiN2kXKqcO3eu3UhSQUGB3UyaNMluduzYYTfLli2zm927d9uNJLW1tdlNTU2N3Tz11FN2c+3aNbuRpDVr1thNykXWhx56qNfP8KYAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAQtZX5G699Vb74TfccIPdDB061G4k6dixY3bT2NhoNwsWLLCbVatW2c2Xv/xlu5GkqVOn2s369evtpqSkxG7y8vLsRpLWrVtnN11dXXaTk5NjN7/5zW/sJuWYoCQ98sgjdvOTn/zEblIOA6b8/jU3N9uNJE2YMMFuLl++bDcDBgywm08//dRupLQ/v3HjxiV9V294UwAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAAAh64N4V65csR/+97//3W7uvPNOu5HSDlGdOXPGblIOa124cMFuUg4QSlJ+fr7dfOUrX7Gb1atX282f/vQnu5Gk6upqu0n5c9i1a5fdDB8+3G5mz55tN1Laz9HkyZPtpr293W5mzpxpN4cOHbIbSTp16pTd9O/f327ef/99u0k5qihJFRUVdnP27Nmk7+oNbwoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgZH0Qr6Ojw374/Pnz7Wbr1q12I6UdkJs6darddHZ22s3y5cvtpqSkxG4k6eDBg3Zz/fXX283ChQvt5sYbb7QbSWpoaLCb8vJyu0k54Dhw4EC7STlsJ0n9+vWzm3HjxtlNbW2t3aQcnPvXv/5lN1La73pTU5PdpPz7sHbtWruRpMGDB9tNys94NnhTAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAACHrg3gpR6iOHDliN0VFRXYjSWPGjLGb119/3W4KCwvtpqyszG5aWlrsRpIOHz5sNylHCCdPnmw3KcfjJOnee++1m3PnztnNzJkz7aa7u9tu/vrXv9qNJH
3yySd2s2LFCrupqKiwm0OHDtlNVVWV3aR+19e//nW7aWxstJvUQ5azZs2ym/b29qTv6g1vCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkPWV1Llz59oPr6urs5vx48fbjSRt2LDBbkaNGmU3y5cvt5stW7bYzRtvvGE3kvTss8/azbRp0+ymp6fHbvLz8+1Gkj788EO7OXXqlN3U19fbzenTp+1m9erVdiNJ69ats5urV6/aTU1Njd1Mnz7dblIug0ppl5Tfe+89u0m5VJxywVWSRo4caTddXV1J39Ub3hQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAyPog3uOPP24/POVIViaTsRtJ+upXv2o3v/jFL+xmz549dvPggw/azfPPP283kvTUU0/ZzZAhQ+ymoqLCbl599VW7kaTLly/bzcKFC+1m+PDhdjNlyhS7OXjwoN1IUmVlpd1cuHDBbg4fPmw35eXldvP666/bjSRNmDDBbubPn283bW1tdvPvf//bbqS045zHjh1L+q7e8KYAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAQtYH8QoKCuyH5+Xl2U1HR4fdSNLmzZvtZuvWrXZTV1dnNz09PXazdu1au5Gk3/72t3azadMmu5kzZ47dVFdX240kdXd3282JEyf+K01NTY3dpP7dzps3z26+973v2c2wYcPsZvv27XYzYsQIu5HSDva1t7fbTcq/RYsWLbIbSWpoaLCbixcvJn1Xb3hTAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAACHrg3h9+/a1H75t2za7+ctf/mI3kvTKK6/Yzd133203zc3NdtPU1GQ3Fy5csBtJGjJkiN089thjdlNbW2s3ly9fthtJKikpsZsxY8bYzZEjR+xm4cKFdvPss8/ajSTt37/fbjo7O+3mnXfesZvz58/bTcrvhSQ9/PDDdpNyKPL222+3m8GDB9uNJC1YsMBuNm7cmPRdveFNAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQsr6SWlBQYD/8O9/5jt0cP37cbiRp7dq1dnPzzTfbTco1w5T/tsrKSrtJ9cc//tFuUi5VPvfcc3YjSdu3b7eblGuxKVdcv/vd79pNTU2N3UjS+vXr7SblZ2/x4sV209HRYTdLly61G0k6ffq03RQWFtpNyuXX3Nxcu5Gkuro6u8nPz0/6rt7wpgAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAABCTiaTyWTzwc2bN9sP37Nnj900NzfbjSR98YtftJsdO3bYTcqhtZTDWrNmzbIbKe1o2mc+8xm7OXnypN2kHCWTpMmTJ9tNyv+mlJ+HlONsAwYMsBtJyvJX9T8MHz7cbkpKSuwm5WDmSy+9ZDeSVFpaajcTJ060m7feestuPv/5z9uNJK1evdpufvazn9lNVVVVr5/hTQEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAACErA/iAQD+/+NNAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEP4H4NnwTGy59PcAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Visualize a 28x28 array of raw N(0, 1) noise as a grayscale image.\n",
"fig, ax = plt.subplots()\n",
"ax.imshow(get_noise(28, 28), cmap='gray')\n",
"ax.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
"execution_count": 581,
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Discriminator building block: Linear followed by LeakyReLU(0.2).\"\"\"\n",
"    linear = nn.Linear(in_dim, out_dim)\n",
"    activation = nn.LeakyReLU(negative_slope=0.2)\n",
"    return nn.Sequential(linear, activation)"
]
},
{
"cell_type": "code",
"execution_count": 582,
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"    \"\"\"MLP discriminator that scores flattened images as real or fake.\n",
"\n",
"    Three DisBlocks narrow the features hidden_dim*4 -> *2 -> *1, followed by\n",
"    dropout and a single-unit Linear. The output is one raw logit per sample,\n",
"    which is what `nn.losses.binary_cross_entropy` expects by default\n",
"    (`with_logits=True`).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, im_dim: int = 784, hidden_dim: int = 256):\n",
"        super().__init__()\n",
"        self.disc = nn.Sequential(\n",
"            DisBlock(im_dim, hidden_dim * 4),\n",
"            DisBlock(hidden_dim * 4, hidden_dim * 2),\n",
"            DisBlock(hidden_dim * 2, hidden_dim),\n",
"            nn.Dropout(0.3),\n",
"            nn.Linear(hidden_dim, 1),\n",
"        )\n",
"\n",
"    def __call__(self, noise):\n",
"        \"\"\"Return real-vs-fake logits with a trailing dimension of 1.\n",
"\n",
"        `noise` is a batch of flattened images whose last dimension is\n",
"        `im_dim` (the parameter name is kept for backward compatibility).\n",
"        \"\"\"\n",
"        # Return raw logits. Wrapping them in log(softmax(...)) normalizes the\n",
"        # scores so every output is <= 0 (and, with no axis given, couples the\n",
"        # samples in a batch), which breaks the BCE-with-logits losses.\n",
"        return self.disc(noise)"
]
},
{
"cell_type": "code",
"execution_count": 583,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
" (layers.0): Linear(input_dims=784, output_dims=1024, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.1): Sequential(\n",
" (layers.0): Linear(input_dims=1024, output_dims=512, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.2): Sequential(\n",
" (layers.0): Linear(input_dims=512, output_dims=256, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.3): Dropout(p=0.30000000000000004)\n",
" (layers.4): Linear(input_dims=256, output_dims=1, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 583,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Instantiate with the default sizes and display the layer summary.\n",
"disc = Discriminator(im_dim=784, hidden_dim=256)\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Losses"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Discriminator Loss"
]
},
{
"cell_type": "code",
"execution_count": 584,
"metadata": {},
"outputs": [],
"source": [
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
"    \"\"\"Discriminator loss: average of the BCE on fakes and on reals.\n",
"\n",
"    `num_images` generated samples are labelled 0 and the `real` batch is\n",
"    labelled 1; the two binary-cross-entropy terms are averaged.\n",
"    \"\"\"\n",
"    fake_images = gen(mx.array(get_noise(num_images, z_dim)))\n",
"    fake_scores = disc(fake_images)\n",
"    fake_term = nn.losses.binary_cross_entropy(\n",
"        fake_scores, mx.zeros((fake_images.shape[0], 1))\n",
"    )\n",
"\n",
"    real_scores = disc(real)\n",
"    real_term = nn.losses.binary_cross_entropy(\n",
"        real_scores, mx.ones((real.shape[0], 1))\n",
"    )\n",
"\n",
"    return (fake_term + real_term) / 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generator Loss"
]
},
{
"cell_type": "code",
"execution_count": 585,
"metadata": {},
"outputs": [],
"source": [
"def gen_loss(gen, disc, num_images, z_dim):\n",
"    \"\"\"Generator loss: BCE between the discriminator's scores on fakes and label 1.\n",
"\n",
"    Minimizing it pushes the generator towards samples the discriminator\n",
"    classifies as real.\n",
"    \"\"\"\n",
"    fake_scores = disc(gen(mx.array(get_noise(num_images, z_dim))))\n",
"    wanted = mx.ones((fake_scores.shape[0], 1))\n",
"    return nn.losses.binary_cross_entropy(fake_scores, wanted)"
]
},
{
"cell_type": "code",
"execution_count": 586,
"metadata": {},
"outputs": [],
"source": [
"# Keep only the training images; the other returned splits are ignored.\n",
"raw_train_images, *_ = mnist.mnist()\n",
"train_images = np.array(raw_train_images)"
]
},
{
"cell_type": "code",
"execution_count": 587,
"metadata": {},
"outputs": [],
"source": [
"# Rescale pixels from [0, 1] to [-1, 1] to match the generator's tanh range.\n",
"# The check makes re-running this cell fail loudly instead of rescaling twice.\n",
"assert train_images.min() >= 0.0 and train_images.max() <= 1.0, (\n",
"    \"expected pixels in [0, 1]; has this cell already been run?\"\n",
")\n",
"train_images = train_images * 2.0 - 1.0"
]
},
{
"cell_type": "code",
"execution_count": 588,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x156eb0df0>"
]
},
"execution_count": 588,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Sanity check: display the first (rescaled) training image.\n",
"plt.imshow(np.reshape(train_images[0], (28, 28)), cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": 589,
"metadata": {},
"outputs": [],
"source": [
"def batch_iterate(batch_size: int, ipt: np.ndarray):\n",
"    \"\"\"Yield shuffled mini-batches covering `ipt` once.\n",
"\n",
"    Args:\n",
"        batch_size: rows per batch; the final batch may be smaller.\n",
"        ipt: array indexed along its first axis. NumPy fancy indexing is\n",
"            used, so a plain Python list is not supported.\n",
"\n",
"    Yields:\n",
"        Consecutive slices `ipt[ids]` of a fresh random row permutation.\n",
"    \"\"\"\n",
"    perm = np.random.permutation(len(ipt))\n",
"    for start in range(0, len(ipt), batch_size):\n",
"        yield ipt[perm[start : start + batch_size]]"
]
},
{
"cell_type": "code",
"execution_count": 590,
"metadata": {},
"outputs": [],
"source": [
"def show_images(imgs, num_imgs: int = 25):\n",
"    \"\"\"Plot up to `num_imgs` images from `imgs` in a near-square grid.\n",
"\n",
"    Args:\n",
"        imgs: batch indexed along its first axis (NumPy or MLX array); each\n",
"            entry must be reshapeable to (28, 28).\n",
"        num_imgs: maximum number of images to draw (25 gives a 5x5 grid).\n",
"    \"\"\"\n",
"    n = min(num_imgs, imgs.shape[0])\n",
"    if n <= 0:\n",
"        return\n",
"    cols = int(np.ceil(np.sqrt(n)))\n",
"    rows = -(-n // cols)  # ceil(n / cols)\n",
"    # squeeze=False keeps `axes` 2-D even for a 1x1 grid.\n",
"    fig, axes = plt.subplots(rows, cols, figsize=(5, 5), squeeze=False)\n",
"    for i, ax in enumerate(axes.flat):\n",
"        if i < n:\n",
"            ax.imshow(mx.array(imgs[i]).reshape(28, 28), cmap='gray')\n",
"        ax.axis('off')\n",
"    plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Show the first batch of images"
]
},
{
"cell_type": "code",
"execution_count": 591,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 500x500 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Draw one shuffled batch of 25 real images and display it.\n",
"first_batch = next(batch_iterate(25, train_images), None)\n",
"if first_batch is not None:\n",
"    show_images(first_batch)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training Cycle"
]
},
{
"cell_type": "code",
"execution_count": 592,
"metadata": {},
"outputs": [],
"source": [
"# Adam with betas (0.5, 0.9) and lr=2e-4 is a common GAN starting point;\n",
"# a much smaller lr (e.g. 2e-8) leaves the weights essentially unchanged.\n",
"lr = 2e-4\n",
"z_dim = 64\n",
"seed = 42\n",
"\n",
"# Seed MLX (weight init, noise, dropout) and NumPy (batch shuffling).\n",
"mx.random.seed(seed)\n",
"np.random.seed(seed)\n",
"\n",
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
"gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.9])\n",
"\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
"disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.9])"
]
},
{
"cell_type": "code",
"execution_count": 593,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/200 [00:00<?, ?it/s]"
]
}
],
"source": [
"# Training hyper-parameters\n",
"n_epochs = 200\n",
"display_step = 5000  # show sample grids every `display_step` optimizer steps\n",
"cur_step = 0\n",
"\n",
"batch_size = 64\n",
"\n",
"# Each returns (loss, grads), with grads taken w.r.t. the trainable\n",
"# parameters of the first model only; the other model is held fixed.\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"for epoch in tqdm(range(n_epochs)):\n",
"\n",
"    for real in batch_iterate(batch_size, train_images):\n",
"\n",
"        # Discriminator step: one real batch vs. one freshly generated batch.\n",
"        D_loss, D_grads = D_loss_grad(gen, disc, mx.array(real), batch_size, z_dim)\n",
"        disc_opt.update(disc, D_grads)\n",
"        # MLX is lazy: force evaluation of the new parameters and optimizer state.\n",
"        mx.eval(disc.parameters(), disc_opt.state)\n",
"\n",
"        # Generator step against the just-updated discriminator.\n",
"        G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
"        gen_opt.update(gen, G_grads)\n",
"        mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
"        if (cur_step + 1) % display_step == 0:\n",
"            print(f\"Step {cur_step + 1}: Generator loss: {G_loss.item()}, discriminator loss: {D_loss.item()}\")\n",
"            fake = gen(get_noise(batch_size, z_dim))\n",
"            show_images(fake)\n",
"            show_images(real)\n",
"        cur_step += 1\n",
"\n",
"    # One summary line per epoch instead of a print on every step.\n",
"    tqdm.write(f\"Epoch {epoch}: D={D_loss.item():.4f} G={G_loss.item():.4f}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}