2024-07-26 21:07:40 +08:00
|
|
|
{
|
|
|
|
"cells": [
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
"# Import Libraries"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 573,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import mnist"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 574,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import mlx.core as mx\n",
|
|
|
|
"import mlx.nn as nn\n",
|
|
|
|
"import mlx.optimizers as optim\n",
|
|
|
|
"\n",
|
|
|
|
"from tqdm import tqdm\n",
|
|
|
|
"import numpy as np\n",
|
|
|
|
"import matplotlib.pyplot as plt"
|
|
|
|
]
|
|
|
|
},
|
2024-07-30 07:17:12 +08:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 575,
|
2024-07-30 07:17:12 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"# mx.set_default_device(mx.gpu)"
|
|
|
|
]
|
|
|
|
},
|
2024-07-26 21:07:40 +08:00
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"# GAN Architecture"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
"## Generator 👨🏻‍🎨"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 576,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
"def GenBlock(in_dim: int, out_dim: int) -> nn.Sequential:\n",
"    \"\"\"Generator building block: Linear(in_dim -> out_dim) -> BatchNorm -> LeakyReLU(0.2).\n",
"\n",
"    BatchNorm uses its default eps. Note that ``nn.BatchNorm(out_dim, 0.8)`` would\n",
"    set eps=0.8 (the second positional parameter is eps, not momentum), which\n",
"    swamps the batch variance in the normalisation denominator.\n",
"    \"\"\"\n",
"    return nn.Sequential(\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.BatchNorm(out_dim),\n",
"        nn.LeakyReLU(0.2),\n",
"    )"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 577,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
"class Generator(nn.Module):\n",
"    \"\"\"MLP generator mapping a noise vector of size ``z_dim`` to a flat image of size ``im_dim``.\n",
"\n",
"    Four GenBlocks widen the features from ``hidden_dim`` to ``hidden_dim * 8``,\n",
"    a final Linear projects to ``im_dim``, and tanh squashes the output into [-1, 1].\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, z_dim: int = 10, im_dim: int = 784, hidden_dim: int = 256):\n",
"        super().__init__()\n",
"        # Layer widths: z_dim -> h -> 2h -> 4h -> 8h\n",
"        widths = [z_dim] + [hidden_dim * mult for mult in (1, 2, 4, 8)]\n",
"        blocks = [GenBlock(d_in, d_out) for d_in, d_out in zip(widths[:-1], widths[1:])]\n",
"        self.gen = nn.Sequential(*blocks, nn.Linear(widths[-1], im_dim))\n",
"\n",
"    def __call__(self, noise):\n",
"        return mx.tanh(self.gen(noise))"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 578,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
|
|
|
"Generator(\n",
|
|
|
|
" (gen): Sequential(\n",
|
|
|
|
" (layers.0): Sequential(\n",
|
2024-07-30 07:06:52 +08:00
|
|
|
" (layers.0): Linear(input_dims=100, output_dims=256, bias=True)\n",
|
|
|
|
" (layers.1): BatchNorm(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
|
|
" (layers.2): LeakyReLU()\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" )\n",
|
|
|
|
" (layers.1): Sequential(\n",
|
2024-07-30 07:06:52 +08:00
|
|
|
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
|
|
|
|
" (layers.1): BatchNorm(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
|
|
" (layers.2): LeakyReLU()\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" )\n",
|
|
|
|
" (layers.2): Sequential(\n",
|
2024-07-30 07:06:52 +08:00
|
|
|
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
|
|
|
|
" (layers.1): BatchNorm(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
|
|
" (layers.2): LeakyReLU()\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" )\n",
|
|
|
|
" (layers.3): Sequential(\n",
|
2024-07-30 07:06:52 +08:00
|
|
|
" (layers.0): Linear(input_dims=1024, output_dims=2048, bias=True)\n",
|
|
|
|
" (layers.1): BatchNorm(2048, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
|
|
" (layers.2): LeakyReLU()\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" )\n",
|
2024-07-30 07:06:52 +08:00
|
|
|
" (layers.4): Linear(input_dims=2048, output_dims=784, bias=True)\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" )\n",
|
|
|
|
")"
|
|
|
|
]
|
|
|
|
},
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 578,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
"gen = Generator(z_dim=100)\n",
"gen  # display the module tree"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 579,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
"def get_noise(n_samples: int, z_dim: int) -> mx.array:\n",
"    \"\"\"Return an (n_samples, z_dim) array of i.i.d. standard-normal noise for the generator.\"\"\"\n",
"    return mx.random.normal(shape=(n_samples, z_dim))"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 580,
|
2024-07-30 00:44:16 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
2024-07-30 07:37:09 +08:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWPUlEQVR4nO3ca2zW9d3H8U8VBqUW5NDSzohFOpBCy6GcJgcngm7iWGEZBBPNItlWt+yQxejiRjIzD8nAxI0tSzYNE+dp2YA4poIgWhyUldUhWg4ddIUiLeVYoBxaet3Pvske9fr8Huy+c+f9eny9/5eDls/+T745mUwmIwAAJF33v/0fAAD4v4NRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQOiT7Qefe+45++H79++3m7a2NruRpPvuu89u/va3v9nNHXfcYTdnz561m507d9qNJBUXF9vNJ598YjfTp0+3mylTptiNJH344Yd209LSYjcXLlywm9LSUrtJVVJSYjdvvvmm3Tz22GN284Mf/MBuHnjgAbuRpC1bttjN7bffbjfvvvuu3aT8HUnSvn377GbmzJl28+STT/b6Gd4UAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQMjJZDKZbD5499132w+fNGmS3cyYMcNuJGnTpk12k3KobtCgQXaTl5dnNynH2SSps7PTbm677Ta7yfLH5j/MmjXLbiTphRdesJuUY2vr1q2zm5SjZCmHIqW0n9e9e/fazb333ms3Q4cOtZvCwkK7kdJ+N2pra+3m7bfftptvf/vbdiOl/T796Ec/spuurq5eP8ObAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAh9sv1gdXW1/fCdO3fazTvvvGM3ktTa2mo3S5YssZuGhga7ue46f3ubmprsRko7ZnblypX/SpNylEySRowYYTcpf0/Hjx+3m8OHD9vNsWPH7EaSvva1r9nN+fPn7ebMmTN2k3IQb+PGjXYjpf0bsWHDBru5du2a3eTn59uNJPXpk/U/xWHp0qVJ39Ub3hQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAACHr03yvvfaa/fABAwbYzfjx4+1Gkj766CO7aWtrs5v29na7GTNmjN2kSrleunv3brvZtWuX3fz0pz+1G0maOHGi3ezZs8duli1bZjcpP0M1NTV2I0l5eXl2U1JSYjeVlZV209LSYjfjxo2zG0kqLy+3m5UrV9pNZ2en3Vy6dMluJKmwsNBuDhw4kPRdveFNAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAISsD+L16ZP1R8OkSZPs5umnn7YbSVqxYoXdnDhxwm7OnTtnNymH1lIO70nSli1b7GbJkiV2M2TIELvp27ev3UjS73//e7spKyuzmzVr1thNV1eX3Tz66KN2I6UdOxw5cqTdzJgxw242bdpkNwcPHrQbSdqxY4fdNDU12c3YsWPt5p///KfdSFJra6vdfP/730/6rt7wpgAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAABCTiaTyWTzwXfffdd+eENDg92kHIaSpPLycrt588037WbKlCl2M2jQILsZOHCg3aR+V8rBuZ6eHrs5c+aM3UhSUVGR3fTv399uRo8ebTdvvfWW3eTl5dmNlPbfl3LIsrS01G6Ki4vtJuXvVZJefvnl/8p3ffzxx3bT2dlpN5J0yy232M0HH3xgN9u3b+/1M7wpAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgJD1tayVK1faD583b57d9O3b12
4k6ejRo3aTm5trN42NjXZTVVVlN21tbXYjSb/+9a/tZunSpXZz+fJlu6mvr7cbKe3IWElJid2k/Ox961vfspvNmzfbjZR2hPDkyZN2M23aNLv5xz/+YTeXLl2yGyntd6O6utpuCgsL7eajjz6yGyntZ+KJJ55I+q7e8KYAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAQtYH8WbMmGE/vLW11W5ycnLsRko70LZr1y67+fnPf2433d3ddnPt2jW7kaSJEyfaTcqRv5QjeqkH0E6cOGE3/fr1s5u8vDy7eeaZZ+xm/vz5diNJZWVldrNx40a7OXDggN2kOH36dFI3atQou/nlL39pNxcvXrSbW265xW4kaezYsXZz5cqVpO/qDW8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAICQk8lkMtl88Jvf/Kb98N/97nd289prr9mNlHbx9OTJk3Zz00032U3K5ddt27bZjSQ9/PDDdrN161a7KSwstJuKigq7kaSioiK7OXLkiN3k5ubaTX19vd0MGTLEbiTp/PnzdpPy3zd79my7Sfk7euWVV+xGkrL8J+s/jB8/3m5SLhXff//9diOl/d6m2Lt3b6+f4U0BABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAAhD7ZfrCkpMR++JYtW+wm5TibJE2aNMluUo7OHTp0yG6WLVtmN8XFxXYjSbfeeqvdpByPW7Vqld2MHDnSbiRp+vTpdnPmzBm7aWhosJuUQ5G7d++2G0lasGCB3bS1tdlNZ2en3aT8eS9atMhuJGnTpk12k/J7cfHiRbtJ+fdBkh5//HG7uXTpUtJ39YY3BQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABCyPohXUFBgP7y9vd1uxo4dazeSNHDgQLs5deqU3VRWVtrN0aNH7aalpcVuJKmxsdFuvvCFL9jNN77xDbuZMGGC3UjSH/7wB7uZM2eO3Rw4cMBuTpw4YTd5eXl2I0m1tbV2U1paajcVFRV2k8lk7Gb//v12I0mzZs2ym66uLrtJOQx43XVp/z/7ySeftJuysjK7eeihh3r9DG8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIORksrxkVVxcbD98wYIFdnP99dfbjSTt2bPHbk6ePGk3999/v928+OKLdrNmzRq7kdL+/FK+6+rVq3YzefJku5Gk8+fP282xY8fsZty4cXZTVVVlNytXrrQbSSoqKrKbV1991W7uueceuzl9+rTdDB482G4k6cc//rHdbNiwwW6OHz9uN6nHDlOOh6YcKX3kkUd6/QxvCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkPWV1Jdeesl+eF1dnd20tLTYjSTl5+fbzdSpU+3mhRdesJs777zTboYNG2Y3kjRlyhS7uXjxot3U1NTYTaqSkhK7SbmA+/bbb9vNo48+ajeffvqp3Uhpf0+5ubl209raajeLFy+2m9raWruR0i4Bf/DBB3Zz9OhRuykvL7cbSZozZ47d1NfX282qVat6/QxvCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACD0yfqDfbL+aBg9erTd3HTTTXYjSf3797ebfv362c2SJUvspr293W5SjsBJaQcFjxw5YjfLly+3m+rqaruRpNmzZ9tNc3Oz3XzpS1+ym3379tlNR0eH3UjSXXfdZTcvv/yy3fT09NjNpUuX7Cb1MOChQ4fsZtSoUXbzuc99zm5uuOEGu5GkLVu22M3kyZOTvqs3vC
kAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAkJPJZDLZfPCHP/yh/fAJEyb
|
2024-07-30 00:44:16 +08:00
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
"# Sanity check: visualise a 28x28 patch of generator noise as a grayscale image.\n",
"img = get_noise(n_samples=28, z_dim=28)\n",
"plt.imshow(img, cmap='gray')\n",
"plt.axis('off')\n",
"plt.show()"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
"## Discriminator 🕵🏻‍♂️"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 581,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
"def DisBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Discriminator building block: Linear(in_dim -> out_dim) followed by LeakyReLU(0.2).\"\"\"\n",
"    layers = [\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.LeakyReLU(negative_slope=0.2),\n",
"    ]\n",
"    return nn.Sequential(*layers)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 582,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
"class Discriminator(nn.Module):\n",
"    \"\"\"MLP discriminator scoring flattened images of size ``im_dim``.\n",
"\n",
"    Returns one raw logit per image, shape ``(batch, 1)``. No sigmoid/log is applied:\n",
"    the losses feed these scores to ``nn.losses.binary_cross_entropy``, which\n",
"    expects logits.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, im_dim: int = 784, hidden_dim: int = 256):\n",
"        super().__init__()\n",
"        self.disc = nn.Sequential(\n",
"            DisBlock(im_dim, hidden_dim * 4),\n",
"            DisBlock(hidden_dim * 4, hidden_dim * 2),\n",
"            DisBlock(hidden_dim * 2, hidden_dim),\n",
"            nn.Dropout(0.3),\n",
"            nn.Linear(hidden_dim, 1),\n",
"        )\n",
"\n",
"    def __call__(self, noise):\n",
"        # `noise` is a batch of flattened images (parameter name kept for compatibility).\n",
"        # Return the raw logits: a softmax over the size-1 output axis is always 1,\n",
"        # so log(softmax(.)) would be identically 0 and carry no gradient.\n",
"        return self.disc(noise)"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 583,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
|
|
|
"Discriminator(\n",
|
|
|
|
" (disc): Sequential(\n",
|
|
|
|
" (layers.0): Sequential(\n",
|
2024-07-30 07:37:09 +08:00
|
|
|
" (layers.0): Linear(input_dims=784, output_dims=1024, bias=True)\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" (layers.1): LeakyReLU()\n",
|
|
|
|
" )\n",
|
|
|
|
" (layers.1): Sequential(\n",
|
2024-07-30 07:37:09 +08:00
|
|
|
" (layers.0): Linear(input_dims=1024, output_dims=512, bias=True)\n",
|
|
|
|
" (layers.1): LeakyReLU()\n",
|
|
|
|
" )\n",
|
|
|
|
" (layers.2): Sequential(\n",
|
|
|
|
" (layers.0): Linear(input_dims=512, output_dims=256, bias=True)\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" (layers.1): LeakyReLU()\n",
|
|
|
|
" )\n",
|
2024-07-30 07:37:09 +08:00
|
|
|
" (layers.3): Dropout(p=0.30000000000000004)\n",
|
|
|
|
" (layers.4): Linear(input_dims=256, output_dims=1, bias=True)\n",
|
2024-07-26 21:07:40 +08:00
|
|
|
" )\n",
|
|
|
|
")"
|
|
|
|
]
|
|
|
|
},
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 583,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
"disc = Discriminator(im_dim=784, hidden_dim=256)\n",
"disc  # display the module tree"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
"# Model Training 🏋🏻‍♂️"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
2024-07-30 00:44:16 +08:00
|
|
|
"cell_type": "markdown",
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
|
"### Losses"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
|
"#### Discriminator Loss"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 584,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2024-07-26 21:36:29 +08:00
|
|
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
"    \"\"\"Discriminator loss: mean of the BCE on generated images (target 0) and on real images (target 1).\n",
"\n",
"    ``num_images`` fakes are sampled from ``gen``; ``real`` may contain a different\n",
"    number of images. The scores from ``disc`` go straight into\n",
"    ``nn.losses.binary_cross_entropy`` (assumed to treat them as logits).\n",
"    \"\"\"\n",
"    fakes = gen(get_noise(num_images, z_dim))\n",
"    loss_on_fake = nn.losses.binary_cross_entropy(\n",
"        disc(fakes), mx.zeros((fakes.shape[0], 1))\n",
"    )\n",
"    loss_on_real = nn.losses.binary_cross_entropy(\n",
"        disc(real), mx.ones((real.shape[0], 1))\n",
"    )\n",
"    return (loss_on_fake + loss_on_real) / 2"
|
|
|
|
]
|
|
|
|
},
|
2024-07-30 00:44:16 +08:00
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"#### Generator Loss"
|
|
|
|
]
|
|
|
|
},
|
2024-07-26 21:07:40 +08:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 585,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2024-07-26 21:36:29 +08:00
|
|
"def gen_loss(gen, disc, num_images, z_dim):\n",
"    \"\"\"Generator loss: BCE between the discriminator's scores on fresh fakes and an all-ones target.\n",
"\n",
"    The loss is low when ``disc`` labels the generated samples as real.\n",
"    \"\"\"\n",
"    fakes = gen(get_noise(num_images, z_dim))\n",
"    wanted = mx.ones((fakes.shape[0], 1))\n",
"    return nn.losses.binary_cross_entropy(disc(fakes), wanted)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 586,
|
2024-07-30 00:44:16 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
"# Only the training images are needed; the remaining splits are discarded.\n",
"train_images, *_ = mnist.mnist()\n",
"train_images = np.array(train_images)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 587,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
"# Rescale pixels from [0, 1] to [-1, 1] to match the generator's tanh output range.\n",
"# The range check stops an accidental re-run from rescaling the data a second time.\n",
"assert train_images.min() >= 0.0 and train_images.max() <= 1.0, 'expected pixels in [0, 1]; was this cell already run?'\n",
"train_images = train_images * 2.0 - 1.0"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 588,
|
2024-07-29 06:24:50 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
2024-07-30 07:37:09 +08:00
|
|
|
"<matplotlib.image.AxesImage at 0x156eb0df0>"
|
2024-07-29 06:24:50 +08:00
|
|
|
]
|
|
|
|
},
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 588,
|
2024-07-29 06:24:50 +08:00
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
2024-07-30 00:44:16 +08:00
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaI0lEQVR4nO3df2jU9x3H8dfVH1d1lytBk7vUmGVF202dpWrVYP3R1cxApf4oWMtGZEPa+YOJ/cGsDNNBjdgpRdI6V0amW239Y9a6KdUMTXRkijpdRYtYjDOdCcFM72LUSMxnf4hHz1j1e975vkueD/iCufu+vY/ffuvTby75xueccwIAwMBD1gsAAHRfRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjpab2AW3V0dOjcuXMKBALy+XzWywEAeOScU0tLi/Ly8vTQQ3e+1km7CJ07d075+fnWywAA3Kf6+noNHDjwjvuk3afjAoGA9RIAAElwL3+fpyxCH3zwgQoLC/Xwww9r5MiR2rdv3z3N8Sk4AOga7uXv85REaPPmzVq8eLGWLVumI0eO6JlnnlFJSYnOnj2bipcDAGQoXyruoj1mzBg99dRTWrduXeyx73//+5o+fbrKy8vvOBuNRhUMBpO9JADAAxaJRJSVlXXHfZJ+JXTt2jUdPnxYxcXFcY8XFxertra20/5tbW2KRqNxGwCge0h6hM6fP6/r168rNzc37vHc3Fw1NjZ22r+8vFzBYDC28ZVxANB9pOwLE259Q8o5d9s3qZYuXapIJBLb6uvrU7UkAECaSfr3CfXv3189evTodNXT1NTU6epIkvx+v/x+f7KXAQDIAEm/Eurdu7dGjhypqqqquMerqqpUVFSU7JcDAGSwlNwxYcmSJfrpT3+qUaNGady4cfr973+vs2fP6tVXX03FywEAMlRKIjR79mw1NzfrN7/5jRoaGjRs2DDt2LFDBQUFqXg5AECGSsn3Cd0Pvk8IALoGk+8TAgDgXhEhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmelovAEgnPXr08DwTDAZTsJLkWLhwYUJzffv29Tzz+OOPe55ZsGCB55nf/va3nmfmzJnjeUaSrl696nlm5cqVnmfefvttzzNdBVdCAAAzRAgAYCbpESorK5PP54vbQqFQsl8GANAFpOQ9oaFDh+rvf/977ONEPs8OAOj6UhKhnj17cvUDALirlLwndOrUKeXl5amwsFAvvfSSTp8+/a37trW1KRqNxm0AgO4h6REaM2aMNm7cqJ07d+rDDz9UY2OjioqK1NzcfNv9y8vLFQwGY1t+fn6ylwQASFNJj1BJSYlmzZql4cOH67nnntP27dslSRs2bLjt/kuXLlUkEolt9fX1yV4SACBNpfybVfv166fhw4fr1KlTt33e7/fL7/enehkAgDSU8u8Tamtr05dffqlwOJzqlwIAZJikR+j1119XTU2N6urqdODAAb344ouKRqMqLS1N9ksBADJc0j8d9/XXX2vOnDk6f/68BgwYoLFjx2r//v0qKChI9ksBADJc0iP0ySefJPu3RJoaNGiQ55nevXt7nikqKvI8M378eM8zkvTII494npk1a1ZCr9XVfP31155n1q5d63lmxowZnmdaWlo8z0jSv//9b88zNTU1Cb1Wd8W94wAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQ
DMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAMz7nnLNexDdFo1EFg0HrZXQrTz75ZEJzu3fv9jzDf9vM0NHR4XnmZz/7meeZS5cueZ5JRENDQ0JzFy5c8Dxz8uTJhF6rK4pEIsrKyrrjPlwJAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwExP6wXA3tmzZxOaa25u9jzDXbRvOHDggOeZixcvep6ZPHmy5xlJunbtmueZP/3pTwm9Fro3roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT63//+l9DcG2+84Xnm+eef9zxz5MgRzzNr1671PJOoo0ePep6ZMmWK55nW1lbPM0OHDvU8I0m//OUvE5oDvOJKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAw43POOetFfFM0GlUwGLReBlIkKyvL80xLS4vnmfXr13uekaSf//znnmd+8pOfeJ75+OOPPc8AmSYSidz1/3muhAAAZogQAMCM5wjt3btX06ZNU15ennw+n7Zu3Rr3vHNOZWVlysvLU58+fTRp0iQdP348WesFAHQhniPU2tqqESNGqKKi4rbPr1q1SmvWrFFFRYUOHjyoUCikKVOmJPR5fQBA1+b5J6uWlJSopKTkts855/Tee+9p2bJlmjlzpiRpw4YNys3N1aZNm/TKK6/c32oBAF1KUt8TqqurU2Njo4qLi2OP+f1+TZw4UbW1tbedaWtrUzQajdsAAN1DUiPU2NgoScrNzY17PDc3N/bcrcrLyxUMBmNbfn5+MpcEAEhjKfnqOJ/PF/exc67TYzctXbpUkUgkttXX16diSQCANOT5PaE7CYVCkm5cEYXD4djjTU1Nna6ObvL7/fL7/clcBgAgQyT1SqiwsFChUEhVVVWxx65du6aamhoVFRUl86UAAF2A5yuhS5cu6auvvop9XFdXp6NHjyo7O1uDBg3S4sWLtWLFCg0ePFiDBw/WihUr1LdvX7388stJXTgAIPN5jtChQ4c0efLk2MdLliyRJJWWluqPf/yj3nzzTV25ckXz58/XhQsXNGbMGO3atUuBQCB5qwYAdAncwBRd0rvvvpvQ3M1/VHlRU1Pjeea5557zPNPR0eF5BrDEDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNrqkfv36JTT317/+1fPMxIkTPc+UlJR4ntm1a5fnGcASd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzAFvuGxxx7zPPOvf/3L88zFixc9z+zZs8fzzKFDhzzPSNL777/veSbN/ipBGuAGpgCAtEaEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpsB9mjFjhueZyspKzzOBQMDzTKLeeustzzMbN270PNPQ0OB5BpmDG5gCANIaEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gCBoYNG+Z5Zs2aNZ5nfvSjH3meSdT69es9z7zzzjueZ/773/96noENbmAKAEhrRAgAYMZzhPbu3atp06YpLy9PPp9PW7dujXt+7ty58vl8cdvYsWOTtV4AQBfiOUKtra0aMWKEKioqvnWfqVOnqqGhIbbt2LHjvhYJAOiaenodKCkpUUlJyR338fv9CoVCCS8KANA9pOQ9oerqauXk5GjIkCGaN2+empqavnXftrY2RaPRuA0A0D0kPUIlJSX66KOPtHv3bq1evVoHDx7Us88+q7
a2ttvuX15ermAwGNvy8/OTvSQAQJry/Om4u5k9e3bs18OGDdOoUaNUUFCg7du3a+bMmZ32X7p0qZYsWRL7OBqNEiI
|
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
2024-07-29 06:24:50 +08:00
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
"# Sanity check: show the first (already normalised) training digit.\n",
"plt.imshow(np.reshape(train_images[0], (28, 28)), cmap='gray')"
|
2024-07-29 06:24:50 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 589,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
"def batch_iterate(batch_size: int, ipt: np.ndarray):\n",
"    \"\"\"Yield shuffled mini-batches of ``ipt`` taken along its first axis.\n",
"\n",
"    A fresh permutation (from NumPy's global RNG) is drawn when iteration starts,\n",
"    so every row is yielded exactly once per pass. The last batch is smaller when\n",
"    ``len(ipt)`` is not a multiple of ``batch_size``.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``batch_size`` is not positive (raised on first iteration).\n",
"    \"\"\"\n",
"    if batch_size <= 0:\n",
"        raise ValueError(f'batch_size must be positive, got {batch_size}')\n",
"    perm = np.random.permutation(len(ipt))\n",
"    for start in range(0, len(ipt), batch_size):\n",
"        yield ipt[perm[start : start + batch_size]]"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 590,
|
2024-07-30 00:44:16 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
"def show_images(imgs, num_imgs: int = 25):\n",
"    \"\"\"Plot up to ``num_imgs`` flattened 28x28 images in a square grayscale grid.\n",
"\n",
"    ``imgs`` is a NumPy or MLX array whose first axis indexes images. Only the\n",
"    first ``min(num_imgs, len(imgs))`` images are drawn; leftover grid cells stay\n",
"    blank. Nothing is plotted for an empty batch.\n",
"    \"\"\"\n",
"    n = min(num_imgs, imgs.shape[0])\n",
"    if n <= 0:\n",
"        return\n",
"    side = int(np.ceil(np.sqrt(n)))  # 25 images -> 5x5 grid, as before\n",
"    fig, axes = plt.subplots(side, side, figsize=(side, side), squeeze=False)\n",
"    for i, ax in enumerate(axes.flat):\n",
"        if i < n:\n",
"            ax.imshow(np.array(imgs[i]).reshape(28, 28), cmap='gray')\n",
"        ax.axis('off')\n",
"    plt.show()"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
"### Show the first batch of images"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 591,
|
2024-07-26 21:07:40 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
2024-07-30 07:37:09 +08:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADhpklEQVR4nOy9149cWXblvW7cG/aG9yYjvc+kLZLFcl0O1dVOagjd063BzGBmgJkHvc7L/BF6nIcBBA0gYQZSC2pJ3Wq1767qKlZXFU3RpiHTm/Dee/M98NuHkckkmSTTRGaeH0BUgQwGI26ee/c5e6+9ttBut9vgcDgcDmcPURz2B+BwOBzO8YMHFw6Hw+HsOTy4cDgcDmfP4cGFw+FwOHsODy4cDofD2XN4cOFwOBzOnsODC4fD4XD2HB5cOBwOh7Pn8ODC4XA4nD1H2u0LBUHYz89xpHgRUwN+/R7Br9/L8aKmGvwaPoKvwZdjN9ePn1w4HA6Hs+fw4MLhcDicPWfXabGDQhAECIIAo9EIu90OtVoNk8kElUoFrVYLjUaDRqOBSqWCarWK5eVlxGIxtFotNJvNw/74HA6Hw0EXBhdRFCGKInp6evDmm2/CYrFgenoaVqsVLpcLLpcLpVIJkUgEqVQKf/3Xf40rV66gWq2iXC4f9sfncDgcDrosuAiCwE4ndrsdXq8XFosFVqsVZrMZNpsNdrsdlUoFAKBSqWCz2WA2m5HP51GpVF642HncEEURSqUSgiBAkiQIggC1Wg21Wg3gUXGSrhed/JrNJiqVCprNJur1OhqNxqF9Bw6Hc3TpquAiSRImJycxODiI06dP48MPP4QkSUgkEshkMtBqtVCr1RBFEXa7HbIs47XXXoNWq8X8/Dz++Mc/ol6vH/bX6AosFgv6+vqg1WrhcDggyzKmpqYwNTUFQRCgUDwst7VaLQBAPp9n1/nGjRuIxWLY3NxEMBg8zK/B4XCOKF0VXERRhNPpxNDQEIaHhzEyMoJms4l0Oo1yuYxSqYRSqQSNRgODwQBJktDT04NisYhkMglRFHlw+f/RarVwOp0wGAzo7e2FyWTCa6+9hjfffBMKhWLLyaXdbiOZTCIYDCIWi7FrmU6nD/lbcDico0pXBBelUgmTyQSj0YipqSm8+uqr0Ov1WFtbQyaTwW9/+1sEAgGYzWaYzWb09PTgnXfegV6vh9frhU6nQzabxVdffYV8Po9MJoNarXbYX+vAUCgUUKlUUCqV8Pl8sFqtGB4exsWLF6HX62G1WqHVauHz+QDsrFHXarWw2+1otVowm83IZDIshcbhcDjPS1cEF7VaDbfbDbvdjvPnz+Ptt99GMBjE/Pw8Njc38bOf/Qz379+HSqWCWq3G+fPn0d/fj97eXvj9foyPjyOVSqG/vx+JRAKlUulEBRdRFKHT6aDT6XD69GmMjo5ienoa7733HrRaLSRJgkKhYKkw4PEAQ7UuhUIBm82GbDYLjUYDQRB4HYvD4Tw3XRFcaOetUqnQbrdRrVaRTqexurqKUCiEfD6PWq2GdruNZrOJYrGITCYDo9EIo9EISZKg1WphsVhQr9chSV3xtQ4MSZJgNpthMBjg9XrR19cHh8MBjUYDlUoFURS3vL7ZbKJarQIA+/NqtYpSqYRUKoVYLIZoNIpiscgDC4fzElBrBW3S6NdBIUkSRFGEVquF1WqFUqlktet4PI5gMMiEPHv9ubriKSxJEoxGIwwGAwqFAgKBAL766iv84z/+I9LpNKLRKNrtNlMvJZNJzM/Po1AowGQyweFwwGazYWpqCmazGcvLy8hkMof9tQ4MvV6PiYkJOJ1OvPvuu7h06RI0Gg07iWynUqkgEomg1WrB6XTCaDQikUhgYWEBGxsb+Pjjj7G4uIhisXgI34bDOR5QbVOtVkOSJDQaDdRqNbRaLSak2W9kWYbRaMTAwAA++OAD2O12jI6Owul04ic/+Qn+6q/+Cvl8HoVCYc+VoV0RXD
oplUpIp9NIJpMIh8PI5XJsl01Rv1arIZPJQK/Xs/SXSqWC0WhELpd7bKd+XKFUl1arhc1mg9PphNPphMPh2PH19Xod9XodhUIB6XQarVYLBoMBOp0OxWIR8Xgc8XgciUQCyWTywG4AzvGEVIm0cwce7aQ7aTQa7IF7XKTvlI0RRRGyLEOj0aBcLjPJ/0GcYARBgEqlgizLsFqt6Ovrg9vtxsTEBDweD27dugVZltFoNFAqlfb83++K4FIsFrG4uIjNzU1ks1l8+eWXWF9fRzqdRrVafewhl06n8cUXX8DlcmF6ehqTk5PQaDSw2WwoFApQKpWH9E0OFq/Xi97eXvT19eGb3/wmXC4Xenp6Hntdo9FAvV7H559/js8//xylUgmJRAKSJOGtt97C6OgoZmdnceXKFSQSCaRSKbRaLZ4S47wUZrMZPp8PKpUKOp0OSqWS3a+UJqrX65ifn0coFML6+jpmZmaOdIChQGq1WnHu3DlYrVZMTEygt7cX9+7dwy9/+Utks1mk02nWr7efn6W3txfT09MYHh7GqVOnYLVaYTAY0G634Xa78dprryESieDq1atIpVJ7+u93RXCpVqsIBoMQBAGRSARqtZrVAHZ6wOXzeczOziIajTK5rFKphMFggCzLJ+LkIggCbDYbxsbGmDLM5XJBpVJteR3dwPV6HXfv3sU//uM/olKpoFAoQKPRQK/XQxRFzM3N4caNG8jn88jn8zywcF4avV4Pv98PnU4Hq9UKnU6Hb3zjG/j6178OhUKBdruNYrGI3//+95idnYVCocD8/PyRDy4KhQJ6vR7T09Pw+/145513cObMGfz617/GzZs3ATx8hh3EZ3E4HJiYmMDg4CD6+/thNBoBPHwuWK1WTE1NwWAw4N69e8czuBDtdpstrEaj8cQHHD0wa7UaqtUq6vX6ifEVkySJNZBOTk7i7Nmz8Hq9LKiS6KFcLiOXy6FcLrMT4YMHD5DNZlGv11kX/v3791Gv17G2toZsNotyuXxiriVnf5FlGf39/TCZTPD5fDAYDHC73QAepbglSYLH40Gj0UAmk0F/fz/y+TySySRLhx8ltFot9Ho93G43RkZG4Pf7YTabARycZX+nG4fL5UJ/fz9cLtdjQqdisch62/ZDXdtVwQV4eIohZdiToIdnqVRCuVxGuVx+5t85LqjVakxMTKCnpwfvvvsuPvjgA9ZUSk2kjUYD8Xgci4uLiMfj+N3vfofNzU1sbGwgGo2ylJcgCPjDH/6Azz77jBUb2+02r7Vw9gSXy4XLly/D6XRiYmICFouFnazpXiVXjpGREajVakSjUUSjUdy4ceNIBhej0Yje3l5MTEzgrbfeQn9//4H3i0mSBJPJBL1ej9HRUVy6dAk6ne6xz5FMJjE7O8vaN/b8c+z5O+4BuwkStEOvVConwleMvML0ej1cLhe8Xi/sdjsMBgOUSiUUCgVarRby+Twz9gwEAojH44hEIojFYsjn81tSDu12m5t9cvYcpVIJSZKg1+tZc7Rer4der3/stbTLViqVMBqNcDqdAAC32/3MloJyuczk8t2yIVKr1TCbzTCZTJBlGTqd7sA/A6XlqD1Bp9OxnrVO6vU6isUiSqXSvly/rgwuu6XRaGB5eRnXrl1jwea4Qjsit9uNb33rW5icnITT6YRWq2Vpwmq1ii+//BLz8/NYXV3FV199hUKhgGQyyUYUcDj7iSiK8Pl8TPLa19cHi8Xy1N07Kcp6e3vx/vvvo1gs4ty5c2w3/aQH3507d/CHP/wBlUqla9K5Pp8Ply5dQl9f36EEFgDQ6XQ4e/Ys/H4/hoeHYTQaIYriY20J5XIZ0WgUuVxuX2yzjnRwabfbyGQyCIfD0Gq1kGX5sD/SvqFWq2G1WuF0OjEwMICxsTEm66QmKBJG3L9/H0tLS7h79+6JciroFrbvEI/ziXo7NIvJ5XLBbrez9MyzFJz09/r6+lCtVmG1Wre0IOxENpuFVqtFq9Xad+XVbhAEAb
Isw+PxwG6373jyOoi6i1KphNPphM/nY6nInf7der3Oygr85HKCsdlsOH/+PLxeL2w2G1u49XodqVQK9+7dQzwex7V
|
2024-07-26 21:07:40 +08:00
|
|
|
"text/plain": [
|
2024-07-29 06:24:50 +08:00
|
|
|
"<Figure size 500x500 with 25 Axes>"
|
2024-07-26 21:07:40 +08:00
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
"# Preview one shuffled batch of 25 training digits.\n",
"first_batch = next(batch_iterate(25, train_images), None)\n",
"if first_batch is not None:\n",
"    show_images(first_batch)"
|
|
|
|
]
|
|
|
|
},
|
2024-07-30 07:06:52 +08:00
|
|
|
{
|
|
|
|
"cell_type": "markdown",
|
|
|
|
"metadata": {},
|
|
|
|
"source": [
|
|
|
|
"### Training Cycle"
|
|
|
|
]
|
|
|
|
},
|
2024-07-26 21:07:40 +08:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 592,
|
2024-07-27 06:09:51 +08:00
|
|
|
"metadata": {},
|
2024-07-29 06:24:50 +08:00
|
|
|
"outputs": [],
|
2024-07-27 06:09:51 +08:00
|
|
|
"source": [
|
2024-07-30 07:37:09 +08:00
|
|
"# Adam with lr = 2e-4 and beta1 = 0.5 is the standard GAN setting (Radford et al., DCGAN).\n",
"# Learning rates orders of magnitude smaller (e.g. 2e-8) leave the weights essentially unchanged.\n",
"lr = 2e-4\n",
"z_dim = 64\n",
"\n",
"# Seed both RNGs used below: MLX (weight init, noise, dropout) and NumPy (batch shuffling).\n",
"mx.random.seed(42)\n",
"np.random.seed(42)\n",
"\n",
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
"gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.9])\n",
"\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
"disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.9])"
|
2024-07-29 06:30:08 +08:00
|
|
|
]
|
|
|
|
},
|
2024-07-27 06:09:51 +08:00
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2024-07-30 07:37:09 +08:00
|
|
|
"execution_count": 593,
|
2024-07-27 06:09:51 +08:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"name": "stderr",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
2024-07-30 07:17:12 +08:00
|
|
|
" 0%| | 0/200 [00:00<?, ?it/s]"
|
2024-07-29 00:18:35 +08:00
|
|
|
]
|
2024-07-26 21:07:40 +08:00
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
2024-07-30 00:44:16 +08:00
|
|
"# Training hyperparameters\n",
"n_epochs = 200\n",
"display_step = 5000\n",
"cur_step = 0\n",
"\n",
"batch_size = 64\n",
"\n",
"# nn.value_and_grad(model, fn) differentiates fn w.r.t. model's trainable parameters only,\n",
"# so D_loss_grad yields discriminator gradients and G_loss_grad generator gradients.\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"for epoch in tqdm(range(n_epochs)):\n",
"    for real in batch_iterate(batch_size, train_images):\n",
"        real = mx.array(real)\n",
"\n",
"        # Discriminator step\n",
"        D_loss, D_grads = D_loss_grad(gen, disc, real, batch_size, z_dim)\n",
"        disc_opt.update(disc, D_grads)\n",
"        mx.eval(disc.parameters(), disc_opt.state)\n",
"\n",
"        # Generator step\n",
"        G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
"        gen_opt.update(gen, G_grads)\n",
"        mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
"        # Report and visualise only every `display_step` steps (not every batch).\n",
"        if (cur_step + 1) % display_step == 0:\n",
"            print(f'Epoch {epoch}, step {cur_step}: Generator loss: {G_loss.item():.4f}, discriminator loss: {D_loss.item():.4f}')\n",
"            fake = gen(get_noise(batch_size, z_dim))\n",
"            show_images(fake)\n",
"            show_images(real)\n",
"        cur_step += 1"
|
2024-07-27 05:19:08 +08:00
|
|
|
]
|
2024-07-26 21:07:40 +08:00
|
|
|
}
|
|
|
|
],
|
|
|
|
"metadata": {
|
|
|
|
"kernelspec": {
|
|
|
|
"display_name": "base",
|
|
|
|
"language": "python",
|
|
|
|
"name": "python3"
|
|
|
|
},
|
|
|
|
"language_info": {
|
|
|
|
"codemirror_mode": {
|
|
|
|
"name": "ipython",
|
|
|
|
"version": 3
|
|
|
|
},
|
|
|
|
"file_extension": ".py",
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
"name": "python",
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
"version": "3.10.10"
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"nbformat": 4,
|
|
|
|
"nbformat_minor": 2
|
|
|
|
}
|