mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 09:51:19 +08:00
548 lines
108 KiB
Plaintext
548 lines
108 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Import Libraries"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 427,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import mnist"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 428,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import mlx.core as mx\n",
|
|
"import mlx.nn as nn\n",
|
|
"import mlx.optimizers as optim\n",
|
|
"\n",
|
|
"from tqdm import tqdm\n",
|
|
"import numpy as np\n",
|
|
"import matplotlib.pyplot as plt"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 429,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# mx.set_default_device(mx.gpu)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# GAN Architecture"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Generator 👨🏻‍🎨"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 430,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def GenBlock(in_dim: int, out_dim: int):\n",
|
|
"    \"\"\"Build one generator stage: Linear -> BatchNorm -> LeakyReLU(0.2).\"\"\"\n",
|
|
"    stage = [\n",
|
|
"        nn.Linear(in_dim, out_dim),\n",
|
|
"        # 0.8 is BatchNorm's second positional argument, i.e. eps (the Generator repr\n",
|
|
"        # in this notebook prints eps=0.8, momentum=0.1).\n",
|
|
"        # NOTE(review): confirm that eps, not momentum, was the intended setting.\n",
|
|
"        nn.BatchNorm(out_dim, eps=0.8),\n",
|
|
"        nn.LeakyReLU(negative_slope=0.2),\n",
|
|
"    ]\n",
|
|
"    return nn.Sequential(*stage)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 431,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Generator(nn.Module):\n",
|
|
"    \"\"\"Fully connected generator: latent noise -> flattened image with values in [-1, 1].\"\"\"\n",
|
|
"\n",
|
|
"    def __init__(self, z_dim: int = 10, im_dim: int = 784, hidden_dim: int = 256):\n",
|
|
"        super().__init__()\n",
|
|
"        # Layer widths: z_dim -> hidden_dim -> 2x -> 4x -> 8x, then a projection to im_dim.\n",
|
|
"        widths = [z_dim] + [hidden_dim * factor for factor in (1, 2, 4, 8)]\n",
|
|
"        blocks = [GenBlock(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]\n",
|
|
"        self.gen = nn.Sequential(\n",
|
|
"            *blocks,\n",
|
|
"            nn.Linear(widths[-1], im_dim),\n",
|
|
"        )\n",
|
|
"\n",
|
|
"    def __call__(self, noise):\n",
|
|
"        # tanh bounds every output value to [-1, 1].\n",
|
|
"        return mx.tanh(self.gen(noise))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 432,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Generator(\n",
|
|
" (gen): Sequential(\n",
|
|
" (layers.0): Sequential(\n",
|
|
" (layers.0): Linear(input_dims=100, output_dims=256, bias=True)\n",
|
|
" (layers.1): BatchNorm(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (layers.2): LeakyReLU()\n",
|
|
" )\n",
|
|
" (layers.1): Sequential(\n",
|
|
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
|
|
" (layers.1): BatchNorm(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (layers.2): LeakyReLU()\n",
|
|
" )\n",
|
|
" (layers.2): Sequential(\n",
|
|
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
|
|
" (layers.1): BatchNorm(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (layers.2): LeakyReLU()\n",
|
|
" )\n",
|
|
" (layers.3): Sequential(\n",
|
|
" (layers.0): Linear(input_dims=1024, output_dims=2048, bias=True)\n",
|
|
" (layers.1): BatchNorm(2048, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (layers.2): LeakyReLU()\n",
|
|
" )\n",
|
|
" (layers.4): Linear(input_dims=2048, output_dims=784, bias=True)\n",
|
|
" )\n",
|
|
")"
|
|
]
|
|
},
|
|
"execution_count": 432,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"gen = Generator(100)\n",
|
|
"gen"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 433,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_noise(n_samples: int, z_dim: int) -> mx.array:\n",
|
|
"    \"\"\"Sample standard-normal noise.\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"        n_samples: number of noise vectors (rows).\n",
|
|
"        z_dim: length of each noise vector (columns).\n",
|
|
"\n",
|
|
"    Returns:\n",
|
|
"        An mx.array of shape (n_samples, z_dim).\n",
|
|
"    \"\"\"\n",
|
|
"    return mx.random.normal(shape=(n_samples, z_dim))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 434,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWKUlEQVR4nO3ca2zW9d3H8U9LhdJyKNVaoEABoTBkIEJtIDpAHAw2BDQ4j8MlbDFL3IzotmiyLCzOLTPRbImTaba4LFlMJuKGYAQhjoNAK1pAKKeCnAq01GIplNpy3c++ubmf9Pr8Etly5/16fL2vlp4+/J98czKZTEYAAEjK/U9/AgCA/x6MAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAEJeti986aWX7Df/+OOP7SYnJ8duJGno0KF2c9ttt9lNdXW13UyePNluVqxYYTeSNHfuXLvZsWOH3Zw9e9ZuKisr7UaSTp8+bTednZ1209bWdk2aefPm2Y0kvf/++3azceNGu1m8eLHdDBgwwG7y8rL+83OVgwcP2s306dPtpqOjw25qa2vtRpLGjRtnN7m5/v/pn3766e7f135XAMD/W4wCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAABC1hep3nnnHfvNq6qq7CblCJUkrV+/3m5Sjpk1NjbazZo1a+xm4cKFdiNJe/futZuioiK7GT9+vN2kfG6SdPnyZbspLi62mz59+tjNxYsX7aa5udluJKlfv352M2nSJLspLCy0m6amJrtpb2+3G0kqKyuzm88//9xujh8/bjc9e/a0G0latWqV3fzwhz9M+ljd4UkBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAAhJxMJpPJ5oVPPPGE/eZXrlyxm5TjbJJ05MgRu3n00UftprW11W7++te/2k3qYa0FCxbYzccff2w3ly5dspuZM2fajSStXbvWbsrLy+2mvr7ebrq6uuxm0KBBdiNJx44ds5uUA4733HOP3bz77rt2k/q7vmXLFru599577Wbp0qV284c//MFuJCk/P99ufve739lNNn/ueVIAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAIS8bF944403fpWfR6itrU3qOjo67ObDDz+0m6NHj9rNyJEj7SYnJ8dupLSrnU1NTXaTcgG3rq7ObiSppKTEbs6cOWM3p0+ftpu5c+fazfnz5+1GklpaWuxm/PjxdpNycfjmm2+2m5TLpZLUo0cPu7l8+bLdvPbaa3aT8ndIkoqLi+1m+vTpSR+rOzwpAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgJCTyWQy2bzwF7/4hf3mhYWFdtO3b1+7kaTGxka7+eCDD+zm888/t5tBgwbZzbBhw+xGkq677jq7KS8vt5vJkyfbzZNPPmk3ktTe3m43P/vZz+zm/ffft5uamhq7+cEPfmA3kvTZZ5/ZTZa/3le5dOmS3ZSVldlNytdOkkaPHm03KUf0Ug4kDh8+3G4kqaCgwG4uXLhgN7/+9a+7fQ1PCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACDkZfvClGNcU6dOtZuUw3uSdO+999rNfffdZzerVq2ym5TDgPn5+XYjSYcOHbKbW2+91W7+9a9/2c2iRYvsRpKam5vtJuXYYcrX/LHHHrOb3bt3242UdoTwueees5uUf9P27dvtpmfPnnYjSb169bKbrq4uu7l48aLdXMuDnimHAbPBkwIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIWR/EO3XqlP3mmzdvtpuZM2
fajST16dPHbg4cOGA3J0+etJuFCxfazZkzZ+xGkhYvXmw3+/bts5uPPvrIblK+3pK0YMECu1m3bp3dPP7443YzYsQIu6mvr7cbKe3o3EsvvWQ3W7dutZuJEyfazdtvv203ktTa2mo3KX9XUv7mZTIZu5GkioqKa/axusOTAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAg5mSyvKi1atMh+8xkzZtjN/v377UaSmpub7ebmm2+2m4KCArtZtmyZ3aQcgZOk/Px8u3n55Zft5q677rKbWbNm2Y0kdXR02M3tt99uN3/729/spri42G5SfoYkqb293W4+/fRTu0k5kDhhwgS7mTp1qt1I0tChQ+2mZ8+edjNnzhy7STlAKEldXV12k/K7/sILL3T7Gp4UAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAAAhL9sX3nDDDfabHzx40G5yc9N2KuUKYt++fe1m165ddpNyhfSmm26yG0lauXKl3fzmN7+xm/LycrspLCy0GyntEumOHTvsZvDgwXaT8vNaUlJiN5K0d+9eu8nLy/pXPLzxxht289Of/tRuUq6dStLRo0ft5uTJk3Zz4cIFu6mtrbUbSZo9e7bdbN68OeljdYcnBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABCyvpbV0tJiv3lXV5fd3HnnnXYjSevXr7ebnJwcu/nyyy/tJuVzu/vuu+1GSjtcePz4cbtJOYh3/vx5u5Gk1atX283SpUvtZtq0aXbz97//3W7WrFljN5L0k5/8xG5Gjx5tNzU1NXYzceJEu0k9DNjW1mY3N9544zVpZsyYYTeSNHz4cLtpbGxM+ljd4UkBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAAhJxMJpPJ5oWPPfaY/ebz5s2zmxdffNFupLTjdln+068yduxYu1m4cKHdLFmyxG4kqbOz026ef/55uzlz5ozd1NfX240kjRs3zm5OnTplNxs3brSbMWPG2M211KtXL7tpbW21m/z8fLtJOQInpX1v33zzTbv54x//aDd1dXV2I0n/+Mc/7Cblb96+ffu6fQ1PCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACDkZfvCbA4p/V8HDhywm1mzZtmNJLW1tdnNpk2b7Gbbtm12U1lZaTcpxwSltH/T/v377Wbnzp12k3JoTZKWLl1qN9u3b7eblMOFLS0tdpNypE6Shg0bZjcNDQ1206NHD7tJOR5XUVFhN5L00EMP2c2AAQPsZuvWrXYze/Zsu5HSDlmm/C3KBk8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAICQk8lkMtm88PXXX7ff/JVXXrGbyZMn240k7d69227KysrsprS01G6am5vtJuVCoyRNmjTJbsaPH283R44csZvLly/bjSTl5vr/dykqKrKb0aNH203K5deUC7OStGPHDrt58MEH7WbUqFF2U1BQYDcrV660G0lavHix3XzxxRd2k3ItdtWqVXYjSc8++6zdrF271m6WL1/e7Wt4UgAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAADhKz2Il3Jw7re//a3dSNLXv/71pM41cOBAu9m1a5fd3H///XYjSe+8847dnDt3zm7WrVtnN3fddZfdSFJVVZXd1NXV2U11dbXdzJkzx24uXLhgN5JUWFhoNxs2bLCbjo4Ou0k5xJiXl2c3ktTZ2Wk3+fn5djN48GC72bNnj91Iad/bMWPG2M0zzzzT7Wt4UgAABEYBAB
AYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAAAh64tUZ86csd/8jTfesJu//OUvdiNJTz75pN0UFBTYTXt7u9307t3bbp577jm7kaSSkhK7qa2ttZtly5bZzQsvvGA3kvS1r33NblJ+XisqKuymoaHBbvbt22c3knTLLbfYzbBhw+zmk08+sZvbbrvNblasWGE3kvT000/bzVtvvWU3J06csJvS0lK7kaSamhq7SfneZoMnBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABByMplMJpsXPvDAA/ab9+rVy25mzJhhN5K0d+9eu9mwYYPdfPe737Wb3bt3283WrVvtRpLmz59vN8ePH7ebsWPH2k1eXtb3F69y4MABu8nyx/oqKT9DlZWVdtPS0mI3ktTZ2Wk38+bNs5tLly7ZzZYtW+wm9eehqKjomjQpxy9XrlxpN5K0YMECuzl27JjdZHMYkCcFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAELI+iPfzn//cfvPt27fbzfTp0+1GknJz/X1LOciVctyud+/edtPR0WE3ktTc3Gw3119/vd2cO3fOblIO76V+rJRDcFeuXLGblENrpaWldiNJNTU1djN69Gi7SfkZ2rZtm92MGzfObiRp3759dpNy0PPEiRN28+1vf9tuJOnVV1+1m6lTp9rNsmXLun0NTwoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgJD1ldQ777zTfvPx48fbTY8ePexGkqqqquxm69atdtPZ2Wk369ats5umpia7kaRvfetbdjNr1iy7qa6utpvDhw/bjSSNHDnSburr6+2mvb3dbn75y1/azeOPP243knT33XfbzY4dO+wm5WruF198YTdDhgyxG0k6deqU3fTr189uUn4HZ86caTdS2tfiz3/+s91s2rSp29fwpAAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAABCXrYvLCkpsd+8T58+dpNycE5KO/zVv39/uzly5Ijd/PjHP7ablStX2o2Udrhwz549djNgwAC7ycnJsRtJGjt2rN2kHLe7ePGi3fzpT3+ym/z8fLuR0o7OlZWV2c2JEyfsZuLEiXbTq1cvu5GkgQMH2k3K9zbLW6FX6ejosBtJys31/38+Z86cpI/VHZ4UAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQMj6IN78+fPtN085Unfq1Cm7kaTPPvvMblKOx6V8fgcPHrSbUaNG2U3qx6qqqrKbtWvX2k3KgTFJqq6utpuUY4wPPPCA3bz99tt2s3z5cruR0o4kdnV12c0dd9xhN4cOHbKboUOH2o0ktbS02E15ebndFBUV2c22bdvsRko7BJryu54NnhQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAyPogXkNDg/3m+fn5djN37ly7kaT33nvPbi5evGg3jY2NdtO/f3+7STnoJklHjx61m9dee81uHn30UbtZvXq13UjSlClT7CbluN2aNWvsJuXg3O9//3u7kaQRI0bYTUlJid2kHJesqKiwm5RjgpI0bdo0u2ltbbWbDz74wG5SlZaW2s3ly5e/gs+EJwUAwP/CKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIGR9EK+5udl+8+LiYrtJOTgnSZlMxm569+5tN9/5znfs5uTJk3ZTWVlpN6nuu+8+uykqKrKbZ5991m4kacuWLXazc+dOu2lra7OblpYWu3n44YftRpJeeeUVu5kzZ47dpByKTD
lI+cwzz9iNJH300Ud2c/r0abtZtGiR3bz55pt2I0lnz561m5S/r9ngSQEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAELK+krp69Wr7zSdNmmQ3gwcPthtJKi0ttZsrV67YTco11ry8rL/MoampyW4k6ciRI3ZTUlJiNxs2bLCbqqoqu5GkKVOm2E1dXZ3dDBkyxG5SLqu++uqrdiNJ/fr1s5u1a9fazU033WQ3S5YssZuUy6BS2td86tSpdrNnzx67yc1N+3/2wIED7Sb1onR3eFIAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAIetLbXfccYf95vn5+XZz+PBhu5Gkb3zjG3aTcgju+eeft5uUY33Nzc12I6UdySoqKrKb8vJyu2ltbbUbSSosLLSbY8eO2U3KgbYf/ehHdvPvf//bbiRp5MiRdrNp0ya7GTp0qN0sX778mjSSdPLkSbspKyuzm+nTp9vN8OHD7UZKO7S5YsUKu/nVr37V7Wt4UgAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAAAh64N4FRUV9pvv2rXLbiZNmmQ3ktTW1mY3ffv2tZtp06bZTcrX7r333rMbSZo/f77dXLlyxW7++c9/2k3KUTJJ6urqspspU6bYzbvvvms3b731lt1MnDjRbiRp8+bNdpPys1dZWWk3/fv3t5va2lq7kaTz58/bTcqhupdfftluOjs77UaSJkyYYDe333570sfqDk8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIORkMplMNi/8/ve/b795yuGqvXv32o0k1dfX281TTz1lNwUFBXZzyy232M26devsRko7HjdixAi7aW1ttZu8vKzvL16lsbHRbtrb2+3m3LlzdvPII4/Yze7du+1Gks6ePWs3X375pd2k/N62tLTYTcqxPkkaNWqU3Rw8eNBuCgsL7aahocFuJOmb3/ym3bz44ot2U11d3e1reFIAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAISsz1amXBksKyuzm+uvv95uJGn27Nl2U1xcbDf79++3m5TLqj179rQbSWpqarKblEua69evt5vJkyfbjST169fPblK+T7feeqvdpFyzLS8vtxtJGjJkiN0cPnzYbq677jq7SfkepVw7laS6ujq7OX78uN1MmDDBbsaMGWM3kvThhx8mdV8FnhQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAyPogXsrxqpTjUClH9CRpxIgRdvP666/bzfe+9z27STlKtmTJEruRpJqaGrtJOXY4YMAAuxk0aJDdSNLGjRvtZtq0aXZTW1trN/PmzbObnTt32o2UdrjwhhtusJvm5ma7efDBB+0m5fdCkg4dOmQ3VVVVdrNp0ya7ueeee+xGSvuZaGhoSPpY3eFJAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAIScTCaT+U9/EgCA/w48KQAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAML/AJEF88QMTNfdAAAAAElFTkSuQmCC",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"img = get_noise(28,28)\n",
|
|
"plt.imshow(img, cmap='gray')\n",
|
|
"plt.axis('off')\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Discriminator 🕵🏻‍♂️"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 435,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def DisBlock(in_dim: int, out_dim: int):\n",
|
|
"    \"\"\"Build one discriminator stage: Linear -> LeakyReLU(0.2).\"\"\"\n",
|
|
"    linear = nn.Linear(in_dim, out_dim)\n",
|
|
"    activation = nn.LeakyReLU(negative_slope=0.2)\n",
|
|
"    return nn.Sequential(linear, activation)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 436,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Discriminator(nn.Module):\n",
|
|
"    \"\"\"Fully connected discriminator: flattened image -> one sigmoid output per sample.\"\"\"\n",
|
|
"\n",
|
|
"    def __init__(self, im_dim: int = 784, hidden_dim: int = 128):\n",
|
|
"        super().__init__()\n",
|
|
"        stages = [\n",
|
|
"            DisBlock(im_dim, hidden_dim * 2),\n",
|
|
"            DisBlock(hidden_dim * 2, hidden_dim),\n",
|
|
"            nn.Linear(hidden_dim, 1),\n",
|
|
"            # Sigmoid squashes the single output into (0, 1).\n",
|
|
"            nn.Sigmoid(),\n",
|
|
"        ]\n",
|
|
"        self.disc = nn.Sequential(*stages)\n",
|
|
"\n",
|
|
"    def __call__(self, noise):\n",
|
|
"        # Despite its name, `noise` is the batch to classify: the loss functions in\n",
|
|
"        # this notebook pass real or generated images here.\n",
|
|
"        return self.disc(noise)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 437,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Discriminator(\n",
|
|
" (disc): Sequential(\n",
|
|
" (layers.0): Sequential(\n",
|
|
" (layers.0): Linear(input_dims=784, output_dims=256, bias=True)\n",
|
|
" (layers.1): LeakyReLU()\n",
|
|
" )\n",
|
|
" (layers.1): Sequential(\n",
|
|
" (layers.0): Linear(input_dims=256, output_dims=128, bias=True)\n",
|
|
" (layers.1): LeakyReLU()\n",
|
|
" )\n",
|
|
" (layers.2): Linear(input_dims=128, output_dims=1, bias=True)\n",
|
|
" (layers.3): Sigmoid()\n",
|
|
" )\n",
|
|
")"
|
|
]
|
|
},
|
|
"execution_count": 437,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"disc = Discriminator()\n",
|
|
"disc"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Model Training 🏋🏻‍♂️"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Losses"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"#### Discriminator Loss"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 438,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
|
|
"    \"\"\"Discriminator loss: average of the BCE on generated and on real images.\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"        gen: generator used to produce the fake batch.\n",
|
|
"        disc: discriminator being scored.\n",
|
|
"        real: batch of real images.\n",
|
|
"        num_images: number of fake images to generate.\n",
|
|
"        z_dim: size of each noise vector fed to the generator.\n",
|
|
"\n",
|
|
"    Returns:\n",
|
|
"        (fake_loss + real_loss) / 2.\n",
|
|
"    \"\"\"\n",
|
|
"    # get_noise already returns an mx.array, so no extra mx.array(...) copy is needed.\n",
|
|
"    fake_images = gen(get_noise(num_images, z_dim))\n",
|
|
"\n",
|
|
"    fake_pred = disc(fake_images)\n",
|
|
"    fake_labels = mx.zeros((fake_images.shape[0], 1))\n",
|
|
"\n",
|
|
"    real_pred = disc(real)\n",
|
|
"    real_labels = mx.ones((real.shape[0], 1))\n",
|
|
"\n",
|
|
"    # The Discriminator ends in a Sigmoid, so its outputs are probabilities.\n",
|
|
"    # binary_cross_entropy treats inputs as logits unless with_logits=False, which\n",
|
|
"    # would apply a second sigmoid on top of the model's own.\n",
|
|
"    fake_loss = nn.losses.binary_cross_entropy(fake_pred, fake_labels, with_logits=False)\n",
|
|
"    real_loss = nn.losses.binary_cross_entropy(real_pred, real_labels, with_logits=False)\n",
|
|
"\n",
|
|
"    return (fake_loss + real_loss) / 2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"#### Generator Loss"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 439,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def gen_loss(gen, disc, num_images, z_dim):\n",
|
|
"    \"\"\"Generator loss: BCE between the discriminator's score on fakes and all-ones labels.\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"        gen: generator being scored.\n",
|
|
"        disc: discriminator that judges the generated batch.\n",
|
|
"        num_images: number of fake images to generate.\n",
|
|
"        z_dim: size of each noise vector fed to the generator.\n",
|
|
"    \"\"\"\n",
|
|
"    # get_noise already returns an mx.array, so no extra mx.array(...) copy is needed.\n",
|
|
"    fake_images = gen(get_noise(num_images, z_dim))\n",
|
|
"    fake_pred = disc(fake_images)\n",
|
|
"\n",
|
|
"    # Labels are ones: the generator is rewarded when fakes are scored as real.\n",
|
|
"    target_labels = mx.ones((fake_images.shape[0], 1))\n",
|
|
"\n",
|
|
"    # The Discriminator ends in a Sigmoid, so its outputs are probabilities.\n",
|
|
"    # binary_cross_entropy treats inputs as logits unless with_logits=False, which\n",
|
|
"    # would apply a second sigmoid on top of the model's own.\n",
|
|
"    return nn.losses.binary_cross_entropy(fake_pred, target_labels, with_logits=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 440,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Get only the training images\n",
|
|
"train_images,*_ = map(np.array, mnist.mnist())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 441,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Normalize the images to fall between -1,1\n",
|
|
"train_images = train_images * 2.0 - 1.0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 442,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<matplotlib.image.AxesImage at 0x156d411b0>"
|
|
]
|
|
},
|
|
"execution_count": 442,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"plt.imshow(train_images[0].reshape(28,28),cmap='gray')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 443,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def batch_iterate(batch_size: int, ipt: np.ndarray):\n",
|
|
"    \"\"\"Yield shuffled mini-batches of `ipt`.\n",
|
|
"\n",
|
|
"    A fresh random permutation of the row indices is drawn on every call; the\n",
|
|
"    last batch is shorter when len(ipt) is not a multiple of batch_size.\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"        batch_size: maximum number of rows per batch.\n",
|
|
"        ipt: array indexed along its first axis with an integer index array.\n",
|
|
"\n",
|
|
"    Yields:\n",
|
|
"        One batch of rows of `ipt` at a time.\n",
|
|
"    \"\"\"\n",
|
|
"    order = np.random.permutation(len(ipt))\n",
|
|
"    for start in range(0, len(ipt), batch_size):\n",
|
|
"        yield ipt[order[start : start + batch_size]]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 444,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def show_images(imgs, num_imgs: int = 25):\n",
|
|
"    \"\"\"Plot up to `num_imgs` images, each reshaped to 28x28, on a square grid.\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"        imgs: batch of images indexed along the first axis; each entry must\n",
|
|
"            reshape to (28, 28).\n",
|
|
"        num_imgs: maximum number of images to draw (default 25 -> 5x5 grid).\n",
|
|
"    \"\"\"\n",
|
|
"    count = min(num_imgs, imgs.shape[0])\n",
|
|
"    if count <= 0:\n",
|
|
"        return\n",
|
|
"\n",
|
|
"    # Smallest square grid that holds `count` images (5x5 for the default 25).\n",
|
|
"    side = int(np.ceil(np.sqrt(count)))\n",
|
|
"    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.\n",
|
|
"    _, axes = plt.subplots(side, side, figsize=(side, side), squeeze=False)\n",
|
|
"\n",
|
|
"    for i, ax in enumerate(axes.flat):\n",
|
|
"        ax.axis('off')\n",
|
|
"        if i < count:  # leave surplus grid cells blank instead of indexing past the batch\n",
|
|
"            img = mx.array(imgs[i]).reshape(28, 28)\n",
|
|
"            ax.imshow(img, cmap='gray')\n",
|
|
"    plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Show the first batch of images"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 445,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 500x500 with 25 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"X = batch_iterate(25, train_images)\n",
|
|
"for x in X: \n",
|
|
" show_images(x)\n",
|
|
" break"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Training Cycle"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 446,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"lr = 2e-4\n",
|
|
"z_dim = 64\n",
|
|
"\n",
|
|
"gen = Generator(z_dim)\n",
|
|
"mx.eval(gen.parameters())\n",
|
|
"gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999]) #,betas=[0.5, 0.9]\n",
|
|
"\n",
|
|
"disc = Discriminator()\n",
|
|
"mx.eval(disc.parameters())\n",
|
|
"disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 447,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" 0%| | 0/200 [00:00<?, ?it/s]"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Training configuration\n",
|
|
"n_epochs = 200\n",
|
|
"display_step = 5000  # report losses and sample images every `display_step` batches\n",
|
|
"cur_step = 0  # number of batches processed so far\n",
|
|
"\n",
|
|
"batch_size = 128\n",
|
|
"\n",
|
|
"# Each wrapper returns (loss, grads); grads are taken w.r.t. the model passed first.\n",
|
|
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
|
|
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
|
|
"\n",
|
|
"\n",
|
|
"for epoch in tqdm(range(n_epochs)):\n",
|
|
"\n",
|
|
"    for real in batch_iterate(batch_size, train_images):\n",
|
|
"\n",
|
|
"        # Discriminator step: compute loss and gradients, then apply them\n",
|
|
"        D_loss, D_grads = D_loss_grad(gen, disc, mx.array(real), batch_size, z_dim)\n",
|
|
"        disc_opt.update(disc, D_grads)\n",
|
|
"\n",
|
|
"        # Evaluate the updated parameters and optimizer state (mlx computes lazily)\n",
|
|
"        mx.eval(disc.parameters(), disc_opt.state)\n",
|
|
"\n",
|
|
"        # Generator step: compute loss and gradients, then apply them\n",
|
|
"        G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
|
|
"        gen_opt.update(gen, G_grads)\n",
|
|
"\n",
|
|
"        # Evaluate the updated parameters and optimizer state (mlx computes lazily)\n",
|
|
"        mx.eval(gen.parameters(), gen_opt.state)\n",
|
|
"\n",
|
|
"        if (cur_step + 1) % display_step == 0:\n",
|
|
"            # Report both counters: the trigger is step-based, the outer loop is epoch-based.\n",
|
|
"            print(f\"Epoch {epoch}, step {cur_step + 1}: Generator loss: {G_loss.item():.4f}, discriminator loss: {D_loss.item():.4f}\")\n",
|
|
"            # get_noise already returns an mx.array, so no extra mx.array(...) copy is needed.\n",
|
|
"            fake = gen(get_noise(batch_size, z_dim))\n",
|
|
"            show_images(fake)\n",
|
|
"            show_images(real)\n",
|
|
"        cur_step += 1"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "base",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.10"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|