mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
539 lines
100 KiB
Plaintext
539 lines
100 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Import Libraries"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 369,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import mnist"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 370,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from collections.abc import Iterator\n",
"\n",
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# GAN Architecture"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Generator 👨🏻\u200d🎨"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 371,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def GenBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Generator building block: Linear -> BatchNorm -> LeakyReLU(0.2).\"\"\"\n",
"    # NOTE(review): 0.8 is BatchNorm's `eps` (the original passed it positionally,\n",
"    # and the printed model shows eps=0.8), not its momentum -- confirm intended.\n",
"    layers = [\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.BatchNorm(out_dim, eps=0.8),\n",
"        nn.LeakyReLU(0.2),\n",
"    ]\n",
"    return nn.Sequential(*layers)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 393,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Generator(nn.Module):\n",
"    \"\"\"MLP generator mapping a noise vector to a flattened image in [-1, 1].\n",
"\n",
"    Args:\n",
"        z_dim: size of the input noise vector.\n",
"        im_dim: size of the flattened output image (784 = 28 * 28).\n",
"        hidden_dim: width of the first hidden block; each later block doubles it.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, z_dim: int = 10, im_dim: int = 784, hidden_dim: int = 256):\n",
"        super().__init__()\n",
"        # z_dim -> h -> 2h -> 4h -> 8h, then a plain Linear projection to im_dim\n",
"        widths = [z_dim] + [hidden_dim * factor for factor in (1, 2, 4, 8)]\n",
"        blocks = [GenBlock(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])]\n",
"        self.gen = nn.Sequential(*blocks, nn.Linear(widths[-1], im_dim))\n",
"\n",
"    def __call__(self, noise):\n",
"        # tanh keeps outputs in [-1, 1], the same range the training images are scaled to\n",
"        return mx.tanh(self.gen(noise))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 394,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Generator(\n",
|
|
" (gen): Sequential(\n",
|
|
" (layers.0): Sequential(\n",
|
|
" (layers.0): Linear(input_dims=100, output_dims=256, bias=True)\n",
|
|
" (layers.1): BatchNorm(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (layers.2): LeakyReLU()\n",
|
|
" )\n",
|
|
" (layers.1): Sequential(\n",
|
|
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
|
|
" (layers.1): BatchNorm(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (layers.2): LeakyReLU()\n",
|
|
" )\n",
|
|
" (layers.2): Sequential(\n",
|
|
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
|
|
" (layers.1): BatchNorm(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (layers.2): LeakyReLU()\n",
|
|
" )\n",
|
|
" (layers.3): Sequential(\n",
|
|
" (layers.0): Linear(input_dims=1024, output_dims=2048, bias=True)\n",
|
|
" (layers.1): BatchNorm(2048, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
" (layers.2): LeakyReLU()\n",
|
|
" )\n",
|
|
" (layers.4): Linear(input_dims=2048, output_dims=784, bias=True)\n",
|
|
" )\n",
|
|
")"
|
|
]
|
|
},
|
|
"execution_count": 394,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Build a generator with a 100-dim noise input just to inspect its layers\n",
"# (a fresh one is created again in the training set-up cell).\n",
"gen = Generator(z_dim=100)\n",
"gen"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 374,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_noise(n_samples: int, z_dim: int) -> mx.array:\n",
"    \"\"\"Sample standard-normal noise of shape (n_samples, z_dim).\n",
"\n",
"    Args:\n",
"        n_samples: number of noise vectors (rows).\n",
"        z_dim: length of each noise vector (columns).\n",
"\n",
"    Returns:\n",
"        An mx.array drawn from mx.random.normal.\n",
"    \"\"\"\n",
"    return mx.random.normal(shape=(n_samples, z_dim))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 375,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWL0lEQVR4nO3cXWzW9d3H8U+hVChQO7oWWgq1lDKFOmdpwWnBhmWEqYiLDoWNwJI9xEGyGN2WbOMAs5EsJtPEZZlb5qYOlFVlM50PgEp5cD4AhVVKiSJQpUBhFKEtZVCv++ybmPug1+d34H1neb+Or/d1QR/48D/55mQymYwAAJA07P/6DwAA+P+DUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEDIzfaFP/zhD+03f+CBB+wm5XMkqa6uzm5KSkrspr293W7uueceu3nppZfsRpJ27txpNw0NDXZz/fXX283evXvtRpJmzJhhN4cPH7abffv22c2sWbPsprCw0G4kqaWlxW6mTJliN83NzXZz7bXX2s3UqVPtRpI+/PBDu7npppvs5uzZs3bT3d1tN5LU0dFhNym/g/fff/+Qr+FJAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAIScTCaTyeaFS5cutd+8rKzMbsrLy+1Gkrq6uuwm5WhayhGvqqoqu0k9rNXa2mo3RUVFdnPu3Dm7GTt2rN1I0sDAgN1MnjzZbmbOnGk369ats5vBwUG7kaS7777bbjo7O+1m2rRpdtPU1GQ3Kb+zklRfX283X/va1+zmr3/9q928//77diNJEydOtJuUQ5HZHCnlSQEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAACE3GxfOHv2bPvNf/KTn9jNz372M7uRpOPHj9tNT0+P3Zw5c8Zuzp49azelpaV2I0kjR460m5ycHLsZNsz//8SpU6fsRpKuuuoqu3n22Wft5uOPP7ab6667zm727t1rN5K0ceNGuykoKLCbLVu22E3Kz8OECRPsRpJKSkrspr+/327y8vLsZu7cuXYjSaNGjbKbY8eOJX3WUHhSAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAACErK+kvvrqq/ab//znP7ebwsJCu5GkO++8025SrlXecMMNdvPkk0/aTSaTsRsp7fJkytcu5ZLmm2++aTdS2s9ETU2N3Zw7d85uUv5sKddYJWnhwoV2M3HiRLtpamqym4qKCrs5evSo3UjSJ598Yje7du2ym5TvbcrnSNL8+fPtZt68eUmfNRSeFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEDIyWR5ee2ZZ56x37yqqspuHn74YbuRpNmzZ9vNvn377CY3N+sbgmHUqFF2MzAwYDdS2hGvixcv2s3UqVPtprOz024k6cSJE3bzxS9+0W52795tN0VFRXZTWVlpN5J04MABu8nPz7eblINzK1assJuWlha7kaRjx47ZTcrfqa6uzm5Sf297e3vt5oknnrCb1tbWIV/DkwIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIWV93O3XqlP3mmzZtsptp06bZjSTV1NTYzenTp+3m8OHDdjNlyhS7+cIXvmA3krR27Vq7WbRokd2cOXPGbi5fvmw3kjR+/Hi76e/vt5srr7zSburr6+3m6NGjdpPa3XbbbXYzZswYu9m/f7/dpP48pPzspXztUv58OTk5diNJ06dPt5sf//jHSZ81FJ4UAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQMj6IF7KEari4mK72bZtm9
1I0owZM+zm5MmTdjNhwgS7aW9vt5uLFy/ajSStXr3abnbv3m03Kd/blMOAkvTGG2/Yzc0332w311xzjd2kHH3s6emxGyntMODAwIDdnDhxwm4+/vjjz+RzJKm2ttZuPve5z9nN8uXL7eaJJ56wG0l6++237Wbv3r12s2TJkiFfw5MCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACFkfxNu1a5f95jU1NXZTUVFhN5L0wgsv2M2PfvQju3nzzTftprS01G4mT55sN5K0ceNGu5k7d67dnD9/3m5SDnhJ0iuvvGI3/f39drNs2TK72bp1q92sXLnSbiRp+PDhdlNdXW03nZ2ddpOfn283b731lt1IUm5u1v9shXnz5tlNb2+v3cycOdNuJOngwYOf2WcNhScFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEDIyWQymWxe+Jvf/MZ+8+7ubru5dOmS3UhpFxfnz59vN++++67d3H777XZz33332Y0k/fKXv7SblMuvPT09dlNbW2s3ktTR0WE3o0aNspuUy6/19fV2M3LkSLuRpP/85z92k3Ld+O6777ablCufKVdIJempp56ym4KCArtZunSp3bz33nt2I0llZWV2U1RUZDf33nvvkK/hSQEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAACErA/ibd++3X7zlINX//znP+1GSjuaNmvWLLtZsGCB3Tz33HN2s2jRIruRpI0bN9rNmDFj7GbatGl289vf/tZuJKmystJudu7caTdf//rX7aa8vNxu2tra7EaSSktL7Wb06NF2c+HCBbtJOZDY2dlpN5I0fvx4u0k5kJjy71dfX5/dSNKePXvsJuX3dseOHUO+hicFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAELI+iPerX/3KfvPi4mK7STnyJEmbNm2ym5QjWVl+uT4l5UhWSUmJ3UjSe++9Zzcp36eUQ3B5eXl2I0nV1dV209XVZTcpP0Pz58+3m/z8fLuRpJaWFrv57ne/azfr16+3m7vuustu1q5dazeSNGfOHLtZvHix3SxZssRubr75ZruR0o4xPvLII3bzt7/9bcjX8KQAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAQtYH8Z588kn7zYcN8zenvb3dbiTp8uXLdjNlyhS7STkE19nZaTePPvqo3UhpR8ZycnLsJuV4XMqROkm6cOGC3SxYsMBuzp07ZzcTJ060m97eXruRpJ6eHrs5efLkZ/I548aNs5vx48fbjSQ9++yzdjNixAi7mTRpkt3U1NTYjZR2JLGtrc1uNmzYMORreFIAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAITcbF946NAh+83r6+vtpq+vz25SPyvlwmVTU5PdVFZW2s1XvvIVu5Gkd955x25SrkF++ctftpudO3fajSRlecj3U8aOHWs3U6dOtZuUi51FRUV2I0mtra12s2bNGrt59dVX7aalpcVuFi1aZDeS1NjYaDcvvPCC3RQXF9tNbm7W/6R+Ssp16JTrwdngSQEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAACEnEyW18ZSDqD19PTYzY033mg3klRRUWE3Z8+etZu8vDy7STnyd/78ebuRpA8++MBu5syZYzcffvih3aQcqZOkuXPn2s0//vEPu5kyZYrdtLW12c0PfvADu5HS/k5dXV12U1BQYDeDg4N2U1ZWZjeSNHLkSLuprq62m46ODrt544037EaSFi9ebDcpBy
Yff/zxIV/DkwIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIWR/EW758uf3m9913n928/vrrdiNJx48ft5ucnBy7eeedd+xm1qxZdnPs2DG7kaQRI0bYzfTp0+2mvLzcblKOs0nSiRMn7CblANrAwIDdpByP2717t91I0sWLF+1m2rRpdnP69Gm72b59u93U1tbajSRVVVXZTXNzs92UlpbaTUNDg91I0vr16+2mvr7ebh566KEhX8OTAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAhZH8Rbs2aN/eYlJSV2093dbTeStH//fru59tpr7aampsZuNm/ebDeVlZV2I6UdBhwcHLSbbdu22c2dd95pN5I0adIku2lqarKbpUuX2s2ZM2fspq+vz24kqa2tzW7mzp1rN1u2bLGbVatW2c2OHTvsRpI+//nP203Kvyv5+fl2k+U/p/9Lyt+ps7PTblavXj3ka3hSAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAACE32xeWlpbabz4wMGA35eXldiOlHZBrbm62mzlz5tjN7bffbjcXLlywG0kaM2aM3RQWFtpNyvepqKjIbqS0A3IpR/Ta29vtZsGCBXaTcthOkg4fPmw33//+9+0m5fv0zDPP2E1eXp7dSNJjjz1mN62trXbz0EMP2U1FRYXdSGnf29TjoUPhSQEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAELK+kppyOfGll16ym/7+fruRpEuXLtnNxYsX7Wb79u12c++999pNytdOkjo6Ouzmtddes5uf/vSndrN37167kaRvfvObdvPyyy/bzZe+9CW7Sfl56OzstBtJGjt2rN28+OKLdnPFFVfYzfnz5+0m9WruqlWr7Obvf/+73ZSVldnNn/70J7uRpIULF9pNY2Nj0mcNhScFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAELI+iHfgwAH7zVOO29XW1tqNJI0YMcJuMpmM3WzatMluWltb7SY3N+tvzafcdddddlNdXW03Dz74oN2kHLaTpHXr1tnNsmXL7ObUqVN2k3II7jvf+Y7dSGlHEk+fPm03dXV1dpNy9HHlypV2I0njxo2zm7feestuUg4X5uXl2Y0klZaW2k13d3fSZw2FJwUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQsr66tnXrVvvNFy9ebDcPP/yw3UjS/fffbzfPP/+83RQXF9tNWVmZ3RQWFtqNJB06dMhujh49ajc33HCD3bS3t9uNlPY137Bhg92kHGPcvHmz3RQUFNiNlHYk8bM6+nj27Fm7STlSJ0mNjY12M2yY///f+vp6u3nkkUfsRkr783V1dSV91lB4UgAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAAAhJ5Plxaxf/OIX9puPGzfObrq7u+0mVcqRv4aGBrv56KOP7KakpMRuJGnixIl2k3KYbPny5Xazc+dOu5GkG2+80W5ycnLspq2tzW7y8/PtJuXnTpKuuuoqu0k5XJhylHLNmjV2c/DgQbuRpOHDh9tNyu9gSjNhwgS7kdIO4qV8/Zqamob+s9jvCgD4r8UoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAg5Gb7wpSjaRUVFXZz22232Y0kNTc32011dbXdDAwM2E1NTY3d7N69224k6f3337ebRYsW2U3KIbgxY8bYjSSdOHHCblKOx40ePdpurrjiCrtJPZq2ePFiu0k5HpfyOR
988IHd9PX12Y0k7dmzx24eeOABu/n1r39tNyk/q5JUVlZmN9ddd13SZw2FJwUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQMj6Suq3v/1t+827u7vt5umnn7YbSfrqV79qN8ePH7ebjz76yG727dtnNyUlJXYjSdOnT7ebmTNn2s3GjRvtZtSoUXYjSV1dXXZz9dVX282hQ4fsJuUa6ze+8Q27kaRt27bZzZYtW+ymoaHBbtavX283qV+HFStW2E17e7vdpPzeLly40G4kacSIEXaTyWSSPmsoPCkAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkPVBvL/85S/2m8+YMcNuioqK7EaSiouL7SblyN+tt95qN3V1dXazbt06u5GkyspKu+no6LCbU6dO2U1eXp7dSFJZWZnd9PX12c2ePXvs5o477rCbzZs3240kHTlyxG5SfvaGDx9uNytXrrSblpYWu5Gk5uZmu6moqLCbxsZGu0k5bCdJg4ODdpPye5ENnhQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAyPog3pw5c+w3z8/Pt5vXXnvNbiTpe9/7nt08+OCDdpNyWKu0tNRuent77UZK+5qnHECbPHmy3aQeC3vllVfsprW11W5SjtulHBNMOdYnSStWrLCbf//733aT8ud78cUX7aaqqspuJOn666+3m2HD/P//Hjp0yG4uXbpkN5J0+fJlu/nkk0+SPmsoPCkAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkPVBvHnz5tlv/txzz9lNUVGR3UjSo48+ajd//vOf7ebo0aN2k3K4qra21m6ktGNmjY2NdnP48GG7uXDhgt1I0re+9S272bFjh93k5mb96xCef/55u1m2bJndSNK6devs5o9//KPdpByKbGhosJuU75EkTZo0yW7Ky8vt5sCBA3Zz00032Y0kHTlyxG5SDysOhScFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEDI+izk66+/br95yiXNd999126ktIuGhYWFdvO73/3Obqqqquwm5aqjJO3fv99upk6dajcpf74VK1bYjSS9/fbbdvOvf/3Lbq6++mq7GRwctJtrrrnGbiRp9erVdrNkyRK72bBhg93ccsstdpP6daisrLSbxx57zG5mz55tN3V1dXYjSb29vXaza9eupM8aCk8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIORkMplMNi9cu3at/ebjxo2zm46ODruRpP7+frs5cuSI3YwePdpuVq5caTdjx461G0l6/PHH7Sbl75RyRG/kyJF2I6UdFKyoqLCb3//+93bz9NNP203KkTpJKigosJuTJ0/aTU1Njd2kHJf8wx/+YDeSVFpaajcpX4d77rnHbrZu3Wo3knTw4EG7Sfk5WrVq1ZCv4UkBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAAhKwP4gEA/vvxpAAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAj/AyUT0oK8p4s5AAAAAElFTkSuQmCC",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Render a 28x28 draw of the noise distribution as if it were an image.\n",
"img = get_noise(28, 28)\n",
"fig, ax = plt.subplots()\n",
"ax.imshow(img, cmap='gray')\n",
"ax.axis('off')\n",
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Discriminator 🕵🏻\u200d♂️"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 376,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def DisBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Discriminator building block: Linear -> LeakyReLU(0.2), no normalisation.\"\"\"\n",
"    linear = nn.Linear(in_dim, out_dim)\n",
"    activation = nn.LeakyReLU(negative_slope=0.2)\n",
"    return nn.Sequential(linear, activation)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 377,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Discriminator(nn.Module):\n",
"    \"\"\"MLP discriminator scoring flattened images.\n",
"\n",
"    Args:\n",
"        im_dim: size of a flattened input image.\n",
"        hidden_dim: width of the last hidden block (the first one is twice as wide).\n",
"\n",
"    The final Sigmoid means each input row gets a single score in (0, 1).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, im_dim: int = 784, hidden_dim: int = 128):\n",
"        super().__init__()\n",
"        self.disc = nn.Sequential(\n",
"            DisBlock(im_dim, 2 * hidden_dim),\n",
"            DisBlock(2 * hidden_dim, hidden_dim),\n",
"            nn.Linear(hidden_dim, 1),\n",
"            nn.Sigmoid(),\n",
"        )\n",
"\n",
"    def __call__(self, noise):\n",
"        # `noise` is really the batch of (real or generated) flattened images to score\n",
"        return self.disc(noise)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 378,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"Discriminator(\n",
|
|
" (disc): Sequential(\n",
|
|
" (layers.0): Sequential(\n",
|
|
" (layers.0): Linear(input_dims=784, output_dims=256, bias=True)\n",
|
|
" (layers.1): LeakyReLU()\n",
|
|
" )\n",
|
|
" (layers.1): Sequential(\n",
|
|
" (layers.0): Linear(input_dims=256, output_dims=128, bias=True)\n",
|
|
" (layers.1): LeakyReLU()\n",
|
|
" )\n",
|
|
" (layers.2): Linear(input_dims=128, output_dims=1, bias=True)\n",
|
|
" (layers.3): Sigmoid()\n",
|
|
" )\n",
|
|
")"
|
|
]
|
|
},
|
|
"execution_count": 378,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Build a discriminator with the default sizes just to inspect its layers\n",
"# (a fresh one is created again in the training set-up cell).\n",
"disc = Discriminator(im_dim=784, hidden_dim=128)\n",
"disc"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Model Training 🏋🏻\u200d♂️"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Losses"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"#### Discriminator Loss"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 379,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
"    \"\"\"Discriminator BCE loss: real images are labelled 1, generated images 0.\n",
"\n",
"    Args:\n",
"        gen: generator mapping noise to fake images.\n",
"        disc: discriminator; it ends in a Sigmoid, so it outputs probabilities.\n",
"        real: batch of real flattened images.\n",
"        num_images: number of fake images to generate.\n",
"        z_dim: length of the noise vectors fed to `gen`.\n",
"\n",
"    Returns:\n",
"        Scalar array: the average of the fake-image and real-image losses.\n",
"    \"\"\"\n",
"    fake_images = gen(get_noise(num_images, z_dim))\n",
"    fake_disc = disc(fake_images)\n",
"    fake_labels = mx.zeros((fake_images.shape[0], 1))\n",
"    # `disc` already applies a sigmoid, so its outputs must be treated as\n",
"    # probabilities (the default with_logits=True would apply a second sigmoid).\n",
"    # An explicit mean reduction gives the scalar loss nn.value_and_grad needs.\n",
"    fake_loss = nn.losses.binary_cross_entropy(\n",
"        fake_disc, fake_labels, with_logits=False, reduction=\"mean\"\n",
"    )\n",
"\n",
"    real_disc = disc(real)\n",
"    real_labels = mx.ones((real.shape[0], 1))\n",
"    real_loss = nn.losses.binary_cross_entropy(\n",
"        real_disc, real_labels, with_logits=False, reduction=\"mean\"\n",
"    )\n",
"\n",
"    return (fake_loss + real_loss) / 2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"#### Generator Loss"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 380,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def gen_loss(gen, disc, num_images, z_dim):\n",
"    \"\"\"Generator BCE loss: generated images are scored against the 'real' label 1.\n",
"\n",
"    Args:\n",
"        gen: generator mapping noise to fake images.\n",
"        disc: discriminator; it ends in a Sigmoid, so it outputs probabilities.\n",
"        num_images: number of fake images to generate.\n",
"        z_dim: length of the noise vectors fed to `gen`.\n",
"\n",
"    Returns:\n",
"        Scalar array: mean BCE between disc(fake) and all-ones labels.\n",
"    \"\"\"\n",
"    fake_images = gen(get_noise(num_images, z_dim))\n",
"    fake_disc = disc(fake_images)\n",
"    # Labelling fakes as real (1) rewards the generator for fooling `disc`.\n",
"    fake_labels = mx.ones((fake_images.shape[0], 1))\n",
"    # `disc` outputs probabilities, not logits; reduce to a scalar for nn.value_and_grad.\n",
"    return nn.losses.binary_cross_entropy(\n",
"        fake_disc, fake_labels, with_logits=False, reduction=\"mean\"\n",
"    )"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 381,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Convert every array returned by mnist.mnist() to numpy, keep only the first (training images)\n",
"train_images, *_ = (np.array(part) for part in mnist.mnist())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 382,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Normalize the images from [0, 1] to [-1, 1], the range of the generator's tanh output.\n",
"# Guarded so that re-running this cell does not rescale the data a second time.\n",
"if train_images.min() >= 0.0:\n",
"    train_images = train_images * 2.0 - 1.0\n",
"assert -1.0 <= train_images.min() and train_images.max() <= 1.0, \"expected pixel values in [-1, 1]\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 383,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<matplotlib.image.AxesImage at 0x157b0bb80>"
|
|
]
|
|
},
|
|
"execution_count": 383,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Sanity check: display the first (normalised) training image.\n",
"fig, ax = plt.subplots()\n",
"ax.imshow(train_images[0].reshape(28, 28), cmap='gray')\n",
"ax.set_title('First training image')\n",
"ax.axis('off')\n",
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 384,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def batch_iterate(batch_size: int, ipt: list[int])-> list[int]:\n",
|
|
" perm = np.random.permutation(len(ipt))\n",
|
|
" for s in range(0, len(ipt), batch_size):\n",
|
|
" ids = perm[s : s + batch_size]\n",
|
|
" yield ipt[ids]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 385,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def show_images(imgs, num_imgs: int = 25):\n",
"    \"\"\"Plot up to `num_imgs` flattened 28x28 images in a square grid.\n",
"\n",
"    Args:\n",
"        imgs: numpy or mlx array whose rows are flattened 28x28 images.\n",
"        num_imgs: maximum number of images to show. The grid is\n",
"            ceil(sqrt(num_imgs)) cells per side (5x5 for the default 25); unused\n",
"            cells stay blank, so batches smaller than the grid no longer fail.\n",
"    \"\"\"\n",
"    n_show = min(num_imgs, imgs.shape[0])\n",
"    if n_show <= 0:\n",
"        return\n",
"    side = int(np.ceil(np.sqrt(num_imgs)))\n",
"    fig, axes = plt.subplots(side, side, figsize=(side, side))\n",
"    # np.asarray also handles side == 1, where subplots returns a single Axes\n",
"    for i, ax in enumerate(np.asarray(axes).flat):\n",
"        ax.axis('off')\n",
"        if i < n_show:\n",
"            ax.imshow(mx.array(imgs[i]).reshape(28, 28), cmap='gray')\n",
"    plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Show the first batch of training images"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 386,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 500x500 with 25 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Take the first shuffled batch of 25 training images and display it.\n",
"for first_batch in batch_iterate(25, train_images):\n",
"    show_images(first_batch)\n",
"    break"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Training Cycle"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 391,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Optimisation hyper-parameters\n",
"lr = 0.002  # NOTE(review): DCGAN-style recipes commonly use 2e-4 with betas (0.5, 0.999); confirm 2e-3 is intended\n",
"z_dim = 100\n",
"seed = 0\n",
"\n",
"# Seed MLX's RNG (used by get_noise and, presumably, layer initialisation) and NumPy's\n",
"# (used for shuffling in batch_iterate) so that runs are reproducible.\n",
"mx.random.seed(seed)\n",
"np.random.seed(seed)\n",
"\n",
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())  # materialise the lazily computed initial weights\n",
"gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])\n",
"\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
"disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 395,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" 4%|▍ | 9/200 [00:59<21:10, 6.65s/it]"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Training configuration\n",
"n_epochs = 200\n",
"display_step = 5000  # print losses and show samples every `display_step` batches\n",
"cur_step = 0\n",
"batch_size = 128\n",
"\n",
"# nn.value_and_grad(model, fn) returns (loss, grads), with grads taken w.r.t. the given\n",
"# model's trainable parameters, so each call below updates only one network.\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"for epoch in tqdm(range(n_epochs)):\n",
"    for real in batch_iterate(batch_size, train_images):\n",
"        # Discriminator step\n",
"        D_loss, D_grads = D_loss_grad(gen, disc, mx.array(real), batch_size, z_dim)\n",
"        disc_opt.update(disc, D_grads)  # apply the gradients to disc's parameters\n",
"        mx.eval(disc.parameters(), disc_opt.state)  # force MLX's lazy computation now\n",
"\n",
"        # Generator step\n",
"        G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
"        gen_opt.update(gen, G_grads)\n",
"        mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
"        if (cur_step + 1) % display_step == 0:\n",
"            print(f\"Step {cur_step + 1} (epoch {epoch}): Generator loss: {G_loss.item():.4f}, discriminator loss: {D_loss.item():.4f}\")\n",
"            fake = gen(get_noise(batch_size, z_dim))\n",
"            show_images(fake)\n",
"            show_images(real)\n",
"        cur_step += 1"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "base",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.10"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|