Files
mlx-examples/gan/playground.ipynb

548 lines
108 KiB
Plaintext
Raw Normal View History

2024-07-26 16:07:40 +03:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Libraries"
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 427,
2024-07-26 16:07:40 +03:00
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 428,
2024-07-26 16:07:40 +03:00
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
2024-07-30 02:17:12 +03:00
{
"cell_type": "code",
"execution_count": 429,
"metadata": {},
"outputs": [],
"source": [
"# mx.set_default_device(mx.gpu)"
]
},
2024-07-26 16:07:40 +03:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 430,
2024-07-26 16:07:40 +03:00
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Build one generator stage: Linear -> BatchNorm -> LeakyReLU(0.2).\n",
"\n",
"    Args:\n",
"        in_dim: Width of the incoming feature vector.\n",
"        out_dim: Width of the produced feature vector.\n",
"\n",
"    Returns:\n",
"        An ``nn.Sequential`` holding the three layers.\n",
"    \"\"\"\n",
"    return nn.Sequential(\n",
"        nn.Linear(in_dim, out_dim),\n",
"        # The second positional argument of nn.BatchNorm is `eps`, not\n",
"        # `momentum`; passing 0.8 there made the layer divide by\n",
"        # sqrt(var + 0.8). Keep the default eps here. If a running-statistics\n",
"        # momentum was intended, pass it by keyword: `momentum=...`.\n",
"        nn.BatchNorm(out_dim),\n",
"        nn.LeakyReLU(0.2),\n",
"    )"
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 431,
2024-07-26 16:07:40 +03:00
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"    \"\"\"MLP generator: latent noise vector -> flat image vector in (-1, 1).\"\"\"\n",
"\n",
"    def __init__(self, z_dim: int = 10, im_dim: int = 784, hidden_dim: int = 256):\n",
"        super().__init__()\n",
"        # Four GenBlocks that double the width each time, then a linear\n",
"        # projection down to the image size.\n",
"        widths = [z_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8]\n",
"        blocks = [GenBlock(a, b) for a, b in zip(widths[:-1], widths[1:])]\n",
"        self.gen = nn.Sequential(*blocks, nn.Linear(widths[-1], im_dim))\n",
"\n",
"    def __call__(self, noise):\n",
"        # tanh squashes every output value into (-1, 1).\n",
"        return mx.tanh(self.gen(noise))"
2024-07-26 16:07:40 +03:00
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 432,
2024-07-26 16:07:40 +03:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-30 02:06:52 +03:00
" (layers.0): Linear(input_dims=100, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
2024-07-26 16:07:40 +03:00
" )\n",
" (layers.1): Sequential(\n",
2024-07-30 02:06:52 +03:00
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
" (layers.1): BatchNorm(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
2024-07-26 16:07:40 +03:00
" )\n",
" (layers.2): Sequential(\n",
2024-07-30 02:06:52 +03:00
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
" (layers.1): BatchNorm(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
2024-07-26 16:07:40 +03:00
" )\n",
" (layers.3): Sequential(\n",
2024-07-30 02:06:52 +03:00
" (layers.0): Linear(input_dims=1024, output_dims=2048, bias=True)\n",
" (layers.1): BatchNorm(2048, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): LeakyReLU()\n",
2024-07-26 16:07:40 +03:00
" )\n",
2024-07-30 02:06:52 +03:00
" (layers.4): Linear(input_dims=2048, output_dims=784, bias=True)\n",
2024-07-26 16:07:40 +03:00
" )\n",
")"
]
},
2024-07-30 02:17:12 +03:00
"execution_count": 432,
2024-07-26 16:07:40 +03:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Instantiate with a 100-dimensional latent space and show the layer summary\n",
"gen = Generator(z_dim=100)\n",
"gen"
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 433,
2024-07-26 16:07:40 +03:00
"metadata": {},
"outputs": [],
"source": [
2024-07-29 19:44:16 +03:00
"def get_noise(n_samples: int, z_dim: int) -> mx.array:\n",
"    \"\"\"Draw a 2D array of Gaussian noise.\n",
"\n",
"    Args:\n",
"        n_samples: Number of noise vectors (rows).\n",
"        z_dim: Length of each noise vector (columns).\n",
"\n",
"    Returns:\n",
"        An ``mx.array`` of shape ``(n_samples, z_dim)`` sampled with\n",
"        ``mx.random.normal`` at its default parameters.\n",
"    \"\"\"\n",
"    return mx.random.normal(shape=(n_samples, z_dim))"
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 434,
2024-07-29 19:44:16 +03:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-30 02:17:12 +03:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWKUlEQVR4nO3ca2zW9d3H8U9LhdJyKNVaoEABoTBkIEJtIDpAHAw2BDQ4j8MlbDFL3IzotmiyLCzOLTPRbImTaba4LFlMJuKGYAQhjoNAK1pAKKeCnAq01GIplNpy3c++ubmf9Pr8Etly5/16fL2vlp4+/J98czKZTEYAAEjK/U9/AgCA/x6MAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAEJeti986aWX7Df/+OOP7SYnJ8duJGno0KF2c9ttt9lNdXW13UyePNluVqxYYTeSNHfuXLvZsWOH3Zw9e9ZuKisr7UaSTp8+bTednZ1209bWdk2aefPm2Y0kvf/++3azceNGu1m8eLHdDBgwwG7y8rL+83OVgwcP2s306dPtpqOjw25qa2vtRpLGjRtnN7m5/v/pn3766e7f135XAMD/W4wCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAABC1hep3nnnHfvNq6qq7CblCJUkrV+/3m5Sjpk1NjbazZo1a+xm4cKFdiNJe/futZuioiK7GT9+vN2kfG6SdPnyZbspLi62mz59+tjNxYsX7aa5udluJKlfv352M2nSJLspLCy0m6amJrtpb2+3G0kqKyuzm88//9xujh8/bjc9e/a0G0latWqV3fzwhz9M+ljd4UkBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAAhJxMJpPJ5oVPPPGE/eZXrlyxm5TjbJJ05MgRu3n00UftprW11W7++te/2k3qYa0FCxbYzccff2w3ly5dspuZM2fajSStXbvWbsrLy+2mvr7ebrq6uuxm0KBBdiNJx44ds5uUA4733HOP3bz77rt2k/q7vmXLFru599577Wbp0qV284c//MFuJCk/P99ufve739lNNn/ueVIAAARGAQAQGAUAQGAUAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAIS8bF944403fpWfR6itrU3qOjo67ObDDz+0m6NHj9rNyJEj7SYnJ8dupLSrnU1NTXaTcgG3rq7ObiSppKTEbs6cOWM3p0+ftpu5c+fazfnz5+1GklpaWuxm/PjxdpNycfjmm2+2m5TLpZLUo0cPu7l8+bLdvPbaa3aT8ndIkoqLi+1m+vTpSR+rOzwpAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgJCTyWQy2bzwF7/4hf3mhYWFdtO3b1+7kaTGxka7+eCDD+zm888/t5tBgwbZzbBhw+xGkq677jq7KS8vt5vJkyfbzZNPPmk3ktTe3m43P/vZz+zm/ffft5uamhq7+cEPfmA3kvTZZ5/ZTZa/3le5dOmS3ZSVldlNytdOkkaPHm03KUf0Ug4kDh8+3G4kqaCgwG4uXLhgN7/+9a+7fQ1PCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACDkZfvClGNcU6dOtZuUw3uSdO+999rNfffdZzerVq2ym5TDgPn5+XYjSYcOHbKbW2+91W7+9a9/2c2iRYvsRpKam5vtJuXYYcrX/LHHHrOb3bt3242UdoTwueees5uUf9P27dvtpmfPnnYjSb169bKbrq4uu7l48aLdXMuDnimHAbPBkwIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIWR/EO3XqlP3mmzdvtpuZM2
fajST16dPHbg4cOGA3J0+etJuFCxfazZkzZ+xGkhYvXmw3+/bts5uPPvrIblK+3pK0YMECu1m3bp3dPP7443YzYsQIu6mvr7cbKe3o3EsvvWQ3W7dutZuJEyfazdtvv203ktTa2mo3KX9XUv7mZTIZu5GkioqKa/axusOTAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAg5mSyvKi1atMh+8xkzZtjN/v377UaSmpub7ebmm2+2m4KCArtZtmyZ3aQcgZOk/Px8u3n55Zft5q677rKbWbNm2Y0kdXR02M3tt99uN3/729/spri42G5SfoYkqb293W4+/fRTu0k5kDhhwgS7mTp1qt1I0tChQ+2mZ8+edjNnzhy7STlAKEldXV12k/K7/sILL3T7Gp4UAACBUQAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAAAhL9sX3nDDDfabHzx40G5yc9N2KuUKYt++fe1m165ddpNyhfSmm26yG0lauXKl3fzmN7+xm/LycrspLCy0GyntEumOHTvsZvDgwXaT8vNaUlJiN5K0d+9eu8nLy/pXPLzxxht289Of/tRuUq6dStLRo0ft5uTJk3Zz4cIFu6mtrbUbSZo9e7bdbN68OeljdYcnBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAABEYBABCyvpbV0tJiv3lXV5fd3HnnnXYjSevXr7ebnJwcu/nyyy/tJuVzu/vuu+1GSjtcePz4cbtJOYh3/vx5u5Gk1atX283SpUvtZtq0aXbz97//3W7WrFljN5L0k5/8xG5Gjx5tNzU1NXYzceJEu0k9DNjW1mY3N9544zVpZsyYYTeSNHz4cLtpbGxM+ljd4UkBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAACBUQAAhJxMJpPJ5oWPPfaY/ebz5s2zmxdffNFupLTjdln+068yduxYu1m4cKHdLFmyxG4kqbOz026ef/55uzlz5ozd1NfX240kjRs3zm5OnTplNxs3brSbMWPG2M211KtXL7tpbW21m/z8fLtJOQInpX1v33zzTbv54x//aDd1dXV2I0n/+Mc/7Cblb96+ffu6fQ1PCgCAwCgAAAKjAAAIjAIAIDAKAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACDkZfvCbA4p/V8HDhywm1mzZtmNJLW1tdnNpk2b7Gbbtm12U1lZaTcpxwSltH/T/v377Wbnzp12k3JoTZKWLl1qN9u3b7eblMOFLS0tdpNypE6Shg0bZjcNDQ1206NHD7tJOR5XUVFhN5L00EMP2c2AAQPsZuvWrXYze/Zsu5HSDlmm/C3KBk8KAIDAKAAAAqMAAAiMAgAgMAoAgMAoAAACowAACIwCACAwCgCAwCgAAAKjAAAIjAIAIDAKAICQk8lkMtm88PXXX7ff/JVXXrGbyZMn240k7d69227KysrsprS01G6am5vtJuVCoyRNmjTJbsaPH283R44csZvLly/bjSTl5vr/dykqKrKb0aNH203K5deUC7OStGPHDrt58MEH7WbUqFF2U1BQYDcrV660G0lavHix3XzxxRd2k3ItdtWqVXYjSc8++6zdrF271m6WL1/e7Wt4UgAABEYBABAYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAADhKz2Il3Jw7re//a3dSNLXv/71pM41cOBAu9m1a5fd3H///XYjSe+8847dnDt3zm7WrVtnN3fddZfdSFJVVZXd1NXV2U11dbXdzJkzx24uXLhgN5JUWFhoNxs2bLCbjo4Ou0k5xJiXl2c3ktTZ2Wk3+fn5djN48GC72bNnj91Iad/bMWPG2M0zzzzT7Wt4UgAABEYBAB
AYBQBAYBQAAIFRAAAERgEAEBgFAEBgFAAAgVEAAARGAQAQGAUAQGAUAAAh64tUZ86csd/8jTfesJu//OUvdiNJTz7
2024-07-29 19:44:16 +03:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Visualise a single 28x28 draw of noise as a grayscale image\n",
"img = get_noise(28, 28)\n",
"fig, ax = plt.subplots()\n",
"ax.imshow(img, cmap='gray')\n",
"ax.axis('off')\n",
"plt.show()"
2024-07-26 16:07:40 +03:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 435,
2024-07-26 16:07:40 +03:00
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Build one discriminator stage: Linear followed by LeakyReLU(0.2).\"\"\"\n",
"    linear = nn.Linear(in_dim, out_dim)\n",
"    activation = nn.LeakyReLU(negative_slope=0.2)\n",
"    return nn.Sequential(linear, activation)"
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 436,
2024-07-26 16:07:40 +03:00
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"    \"\"\"MLP discriminator mapping a flat image vector to one sigmoid output.\"\"\"\n",
"\n",
"    def __init__(self, im_dim: int = 784, hidden_dim: int = 128):\n",
"        super().__init__()\n",
"        layers = [\n",
"            DisBlock(im_dim, hidden_dim * 2),\n",
"            DisBlock(hidden_dim * 2, hidden_dim),\n",
"            nn.Linear(hidden_dim, 1),\n",
"            nn.Sigmoid(),  # the output is a probability, not a raw logit\n",
"        ]\n",
"        self.disc = nn.Sequential(*layers)\n",
"\n",
"    def __call__(self, noise):\n",
"        # The parameter is named `noise`, but it carries the image batch\n",
"        # (real or generated) to be scored.\n",
"        return self.disc(noise)"
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 437,
2024-07-26 16:07:40 +03:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
2024-07-30 02:06:52 +03:00
" (layers.0): Linear(input_dims=784, output_dims=256, bias=True)\n",
2024-07-26 16:07:40 +03:00
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.1): Sequential(\n",
" (layers.0): Linear(input_dims=256, output_dims=128, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
2024-07-30 02:06:52 +03:00
" (layers.2): Linear(input_dims=128, output_dims=1, bias=True)\n",
" (layers.3): Sigmoid()\n",
2024-07-26 16:07:40 +03:00
" )\n",
")"
]
},
2024-07-30 02:17:12 +03:00
"execution_count": 437,
2024-07-26 16:07:40 +03:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Instantiate with the default sizes and show the layer summary\n",
"disc = Discriminator()\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
2024-07-29 19:44:16 +03:00
"cell_type": "markdown",
2024-07-26 16:07:40 +03:00
"metadata": {},
"source": [
2024-07-29 19:44:16 +03:00
"### Losses"
2024-07-26 16:07:40 +03:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-29 19:44:16 +03:00
"#### Discriminator Loss"
2024-07-26 16:07:40 +03:00
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 438,
2024-07-26 16:07:40 +03:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 16:36:29 +03:00
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
"    \"\"\"Discriminator loss: average of BCE on a fake batch and on a real batch.\n",
"\n",
"    Args:\n",
"        gen: Generator used to synthesise the fake batch.\n",
"        disc: Discriminator being scored.\n",
"        real: Batch of real images, one flattened image per row.\n",
"        num_images: Number of fake images to generate.\n",
"        z_dim: Dimensionality of the generator's noise input.\n",
"\n",
"    Returns:\n",
"        Scalar loss ``(fake_loss + real_loss) / 2``.\n",
"    \"\"\"\n",
"    # get_noise already returns an mx.array, so no extra wrapping is needed.\n",
"    noise = get_noise(num_images, z_dim)\n",
"    fake_images = gen(noise)\n",
"\n",
"    # Fakes should be classified as 0.\n",
"    fake_pred = disc(fake_images)\n",
"    fake_labels = mx.zeros((fake_images.shape[0], 1))\n",
"    # The discriminator ends in a Sigmoid, so its outputs are probabilities.\n",
"    # binary_cross_entropy defaults to with_logits=True, which would apply a\n",
"    # second sigmoid; tell it the inputs are already probabilities.\n",
"    fake_loss = nn.losses.binary_cross_entropy(\n",
"        fake_pred, fake_labels, with_logits=False\n",
"    )\n",
"\n",
"    # Reals should be classified as 1.\n",
"    real_pred = disc(real)\n",
"    real_labels = mx.ones((real.shape[0], 1))\n",
"    real_loss = nn.losses.binary_cross_entropy(\n",
"        real_pred, real_labels, with_logits=False\n",
"    )\n",
"\n",
"    return (fake_loss + real_loss) / 2"
]
},
2024-07-29 19:44:16 +03:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Generator Loss"
]
},
2024-07-26 16:07:40 +03:00
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 439,
2024-07-26 16:07:40 +03:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 16:36:29 +03:00
"def gen_loss(gen, disc, num_images, z_dim):\n",
"    \"\"\"Generator loss: BCE between the discriminator's score on fakes and 1.\n",
"\n",
"    Args:\n",
"        gen: Generator being trained.\n",
"        disc: Discriminator used to score the generated images.\n",
"        num_images: Number of fake images to generate.\n",
"        z_dim: Dimensionality of the generator's noise input.\n",
"\n",
"    Returns:\n",
"        Scalar loss; it is low when the discriminator labels fakes as real.\n",
"    \"\"\"\n",
"    # get_noise already returns an mx.array, so no extra wrapping is needed.\n",
"    noise = get_noise(num_images, z_dim)\n",
"    fake_images = gen(noise)\n",
"    fake_pred = disc(fake_images)\n",
"\n",
"    # The generator wants its fakes to be labelled as real (1).\n",
"    target_labels = mx.ones((fake_images.shape[0], 1))\n",
"\n",
"    # The discriminator ends in a Sigmoid, so its outputs are probabilities.\n",
"    # binary_cross_entropy defaults to with_logits=True, which would apply a\n",
"    # second sigmoid; tell it the inputs are already probabilities.\n",
"    return nn.losses.binary_cross_entropy(\n",
"        fake_pred, target_labels, with_logits=False\n",
"    )"
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 440,
2024-07-29 19:44:16 +03:00
"metadata": {},
"outputs": [],
"source": [
"# Keep the first array returned by mnist.mnist() (the training images) and\n",
"# discard the rest.\n",
"train_images, *_ = (np.array(part) for part in mnist.mnist())"
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 441,
2024-07-26 16:07:40 +03:00
"metadata": {},
"outputs": [],
"source": [
2024-07-29 19:44:16 +03:00
"# Rescale pixel values to [-1, 1] so they match the generator's tanh output.\n",
"# Assumes mnist.mnist() yields values in [0, 1] -- TODO confirm.\n",
"# NOTE: this cell is not idempotent; re-running it rescales the data again.\n",
"train_images = 2.0 * train_images - 1.0"
2024-07-26 16:07:40 +03:00
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 442,
2024-07-29 01:24:50 +03:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-07-30 02:17:12 +03:00
"<matplotlib.image.AxesImage at 0x156d411b0>"
2024-07-29 01:24:50 +03:00
]
},
2024-07-30 02:17:12 +03:00
"execution_count": 442,
2024-07-29 01:24:50 +03:00
"metadata": {},
"output_type": "execute_result"
2024-07-29 19:44:16 +03:00
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaI0lEQVR4nO3df2jU9x3H8dfVH1d1lytBk7vUmGVF202dpWrVYP3R1cxApf4oWMtGZEPa+YOJ/cGsDNNBjdgpRdI6V0amW239Y9a6KdUMTXRkijpdRYtYjDOdCcFM72LUSMxnf4hHz1j1e975vkueD/iCufu+vY/ffuvTby75xueccwIAwMBD1gsAAHRfRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjpab2AW3V0dOjcuXMKBALy+XzWywEAeOScU0tLi/Ly8vTQQ3e+1km7CJ07d075+fnWywAA3Kf6+noNHDjwjvuk3afjAoGA9RIAAElwL3+fpyxCH3zwgQoLC/Xwww9r5MiR2rdv3z3N8Sk4AOga7uXv85REaPPmzVq8eLGWLVumI0eO6JlnnlFJSYnOnj2bipcDAGQoXyruoj1mzBg99dRTWrduXeyx73//+5o+fbrKy8vvOBuNRhUMBpO9JADAAxaJRJSVlXXHfZJ+JXTt2jUdPnxYxcXFcY8XFxertra20/5tbW2KRqNxGwCge0h6hM6fP6/r168rNzc37vHc3Fw1NjZ22r+8vFzBYDC28ZVxANB9pOwLE259Q8o5d9s3qZYuXapIJBLb6uvrU7UkAECaSfr3CfXv3189evTodNXT1NTU6epIkvx+v/x+f7KXAQDIAEm/Eurdu7dGjhypqqqquMerqqpUVFSU7JcDAGSwlNwxYcmSJfrpT3+qUaNGady4cfr973+vs2fP6tVXX03FywEAMlRKIjR79mw1NzfrN7/5jRoaGjRs2DDt2LFDBQUFqXg5AECGSsn3Cd0Pvk8IALoGk+8TAgDgXhEhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmelovAEgnPXr08DwTDAZTsJLkWLhwYUJzffv29Tzz+OOPe55ZsGCB55nf/va3nmfmzJnjeUaSrl696nlm5cqVnmfefvttzzNdBVdCAAAzRAgAYCbpESorK5PP54vbQqFQsl8GANAFpOQ9oaFDh+rvf/977ONEPs8OAOj6UhKhnj17cvUDALirlLwndOrUKeXl5amwsFAvvfSSTp8+/a37trW1KRqNxm0AgO4h6REaM2aMNm7cqJ07d+rDDz9UY2OjioqK1NzcfNv9y8vLFQwGY1t+fn6ylwQASFNJj1BJSYlmzZql4cOH67nnntP27dslSRs2bLjt/kuXLlUkEolt9fX1yV4SACBNpfybVfv166fhw4fr1KlTt33e7/fL7/enehkAgDSU8u8Tamtr05dffqlwOJzqlwIAZJikR+j1119XTU2N6urqdODAAb344ouKRqMqLS1N9ksBADJc0j8d9/XXX2vOnDk6f/68BgwYoLFjx2r//v0qKChI9ksBADJc0iP0ySefJPu3RJoaNGiQ55nevXt7nikqKvI8M378eM8zkvTII494npk1a1ZCr9XVfP31155n1q5d63lmxowZnmdaWlo8z0jSv//9b88zNTU1Cb1Wd8W94wAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQ
DMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAMz7nnLNexDdFo1EFg0HrZXQrTz75ZEJzu3fv9jzDf9vM0NHR4XnmZz/7meeZS5cueZ5JRENDQ0JzFy5c8Dxz8uTJhF6rK4pEIsrKyrrjPlwJAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwExP6wXA3tmzZxOaa25u9jzDXbRvOHDggOeZixcvep6ZPHmy5xlJunbtmueZP/3pTwm9Fro3roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT63//+l9DcG2+84Xnm+eef9zxz5MgRzzNr1671PJOoo0ePep6ZMmWK55nW1lbPM0OHDvU8I0m//OUvE5oDvOJKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAw43POOetFfFM0GlUwGLReBlIkKyvL80xLS4vnmfXr13uekaSf//znnmd+8pOfeJ75+OOPPc8AmSYSidz1/3muhAAAZogQAMCM5wjt3btX06ZNU15ennw+n7Zu3Rr3vHNOZWVlysvLU58+fTRp0iQdP348WesFAHQhniPU2tqqESNGqKKi4rbPr1q1SmvWrFFFRYUOHjyoUCikKVOmJPR5fQBA1+b5J6uWlJSopKTkts855/Tee+9p2bJlmjlzpiRpw4YNys3N1aZNm/TKK6/c32oBAF1KUt8TqqurU2Njo4qLi2OP+f1+TZw4UbW1tbedaWtrUzQajdsAAN1DUiPU2NgoScrNzY17PDc3N/bcrcrLyxUMBmNbfn5+MpcEAEhjKfnqOJ/PF/exc67TYzctXbpUkUgkttXX16diSQCANOT5PaE7CYVCkm5cEYXD4djjTU1Nna6ObvL7/fL7/clcBgAgQyT1SqiwsFChUEhVVVWxx65du6aamhoVFRUl86UAAF2A5yuhS5cu6auvvop9XFdXp6NHjyo7O1uDBg3S4sWLtWLFCg0ePFiDBw/WihUr1LdvX7388stJXTgAIPN5jtChQ4c0efLk2MdLliyRJJWWluqPf/yj3nzzTV25ckXz58/XhQsXNGbMGO3atUuBQCB5qwYAdAncwBRd0rvvvpvQ3M1/VHlRU1Pjeea5557zPNPR0eF5BrDEDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNrqkfv36JTT317/+1fPMxIkTPc+UlJR4ntm1a5fnGcASd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzAFvuGxxx7zPPOvf/3L88zFixc9z+zZs8fzzKFDhzzPSNL777/veSbN/ipBGuAGpgCAtEaEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpsB9mjFjhueZyspKzzOBQMDzTKLeeustzzMbN270PNPQ0OB5BpmDG5gCANIaEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gCBoYNG+Z5Zs2aNZ5nfvSjH3meSdT69es9z7zzzjueZ/773/96noENbmAKAEhrRAgAYMZzhPbu3atp06YpLy9PPp9PW7dujXt+7ty58vl8cdvYsWOTtV4AQBfiOUKtra0aMWKEKioqvnWfqVOnqqGhIbbt2LHjvhYJAOiaenodKCkpUUlJyR338fv9CoVCCS8KANA9pOQ9oerqauXk5GjIkCGaN2+empqavnXftrY2RaPRuA0A0D0kPUIlJSX66KOPtHv3bq1evVoHDx7Us88+q7
a2ttvuX15ermAwGNvy8/OTvSQAQJry/Om4u5k9e3bs18OGDdOoUaNUUFCg7du3a+bMmZ32X7p0qZYsWRL7OBqNEiI
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
2024-07-29 01:24:50 +03:00
}
],
"source": [
2024-07-29 19:44:16 +03:00
"# Preview the first training image; plt.show() keeps the AxesImage repr out\n",
"# of the cell output.\n",
"plt.imshow(train_images[0].reshape(28, 28), cmap='gray')\n",
"plt.show()"
2024-07-29 01:24:50 +03:00
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 443,
2024-07-26 16:07:40 +03:00
"metadata": {},
"outputs": [],
"source": [
2024-07-29 19:44:16 +03:00
"def batch_iterate(batch_size: int, ipt: np.ndarray):\n",
"    \"\"\"Yield shuffled mini-batches of `ipt`.\n",
"\n",
"    A fresh random permutation of the row indices is drawn on every call,\n",
"    and the rows are then yielded `batch_size` at a time. The final batch is\n",
"    smaller when ``len(ipt)`` is not a multiple of `batch_size`.\n",
"\n",
"    Args:\n",
"        batch_size: Maximum number of rows per batch.\n",
"        ipt: Array indexed along axis 0 by an integer index array.\n",
"\n",
"    Yields:\n",
"        ``ipt[ids]`` for each consecutive slice `ids` of the permutation.\n",
"    \"\"\"\n",
"    perm = np.random.permutation(len(ipt))\n",
"    for start in range(0, len(ipt), batch_size):\n",
"        ids = perm[start : start + batch_size]\n",
"        yield ipt[ids]"
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 444,
2024-07-29 19:44:16 +03:00
"metadata": {},
"outputs": [],
"source": [
"def show_images(imgs, num_imgs: int = 25):\n",
"    \"\"\"Plot up to `num_imgs` flattened 28x28 images on a square grid.\n",
"\n",
"    Args:\n",
"        imgs: Batch exposing `.shape` and indexable along axis 0; each row\n",
"            must reshape to (28, 28).\n",
"        num_imgs: Maximum number of images to draw. The default of 25 gives\n",
"            a 5x5 grid.\n",
"\n",
"    Nothing is drawn when the batch is empty or `num_imgs` is not positive.\n",
"    \"\"\"\n",
"    count = min(num_imgs, imgs.shape[0])\n",
"    if count <= 0:\n",
"        return\n",
"\n",
"    # Smallest square grid that can hold `count` images.\n",
"    side = int(np.ceil(np.sqrt(count)))\n",
"    fig, axes = plt.subplots(side, side, figsize=(5, 5), squeeze=False)\n",
"\n",
"    for i, ax in enumerate(axes.flat):\n",
"        ax.axis('off')\n",
"        # Leave the trailing cells blank when the grid has spare slots.\n",
"        if i < count:\n",
"            img = mx.array(imgs[i]).reshape(28, 28)\n",
"            ax.imshow(img, cmap='gray')\n",
"    plt.show()"
2024-07-26 16:07:40 +03:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-07-29 19:44:16 +03:00
"### Show the first batch of images"
2024-07-26 16:07:40 +03:00
]
},
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 445,
2024-07-26 16:07:40 +03:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-30 02:17:12 +03:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADrT0lEQVR4nOy92Y9d15Wn+d15nueIGyNjIBmcSVGzaElOp+10uZTZ1a52o5AooNEPjX5qoP+MBvq5HwrdlWh0IwtwprPtstOyZcumJpKixJkRjHm48zzP9/YDsbdukEGKkiIYcYPnAwRbYkTwnhPn7LX3Wr/1W6per9dDQUFBQUFhF1Hv9wdQUFBQUDh8KMFFQUFBQWHXUYKLgoKCgsKuowQXBQUFBYVdRwkuCgoKCgq7jhJcFBQUFBR2HSW4KCgoKCjsOkpwUVBQUFDYdZTgoqCgoKCw62if9QtVKtVefo6B4tuYGij37yuU+/fd+LamGso9/ArlGfxuPMv9U04uCgoKCgq7jhJcFBQUFBR2HSW4KCgoKCjsOkpwUVBQUFDYdZTgoqCgoKCw6yjBRUFBQUFh13lmKbLC4eBZ5ZTKDDmFb8puSHWV5+7woASXQ45Go8HhcGAwGBgbG2N8fJxOp0O5XAYgEAhgt9vpdrt0Oh2q1SrLy8uUy2Xy+TyFQoFWq0W1WqXb7e7z1SgcNKxWKzabDZ/Px/nz53E4HKhUKjQazTN9f6fTIZvNUi6XKRaLZDKZHZ+zUqlEJBKh2WwqzyEP32uj0YhGo8Fut2M0Gmm327RaLVqtFvl8nmazua+fUQkuhxytVovX68XpdPL222/z7rvvUqvVSCaTAJw5c4aRkRHa7TaNRoN0Os2//uu/srm5ycrKCmtra1SrVer1uvJSKzyG3W5naGiIkydP8r/8L/8LIyMjaLXax4KLWv0wA9/r9badTlqtFvPz88TjcTY2Nrh37x6dTuexvycSiZDNZul0Oo/9jBcRrVaL1WrFaDQSDodxu93U63VKpRK1Wo1qtaoEl0dRq9WoVCr0er2MzGazGbVajUajQaVSodVqMRgM8v9/3S6p0WhQLpdpNptkMhmq1eqhfTjVajVqtRqz2YzT6cRsNjMzM4PL5WJsbAyPx0Oj0ZBf73A4MJlMdLtd9Ho97Xab0dFRDAaD3B2lUikKhQLdbvfQBhiNRoNarZYv7E6InWG73f5GJzm1Wr3tfup0OsxmMzabDXi44HY6HeLxOOl0eteu6XlgMpnweDzyWTMajWi1WhlM4GG6TKTMHg0MGo0Gp9NJu92m0+nQaDR2DC56vZ7NzU0KhQLpdJpKpbL3F7fPqNVqtFotWq0Wu92OTqdDr9fL58fn82E0GgkGgzgcDmq1GqVSiXw+Tzqdls/qTvfzeXDggovBYECn0+Hz+RgdHcVutzM5OSlTO3q9HqvVSjgcxmQyYbPZ5GLwpICRSCT48ssvSaVS/Pa3v2V+fv5QLpQqlQqTyYTBYGBmZobXX38dj8fD+fPn8fl8eL1ePB4P7Xab8fFxACwWCyaTCXh4/ywWCz/+8Y9pNpvEYjESiQQ3btwgHo+TTCZpNpv79rDuFWLRNxqNnDp1St6bR8nlcuTzeXK5HAsLC9Tr9Wf6+SaTiWAwiMlkYnJyEo/Hw8zMDGfPnkWtVtNoNKjX6/zf//f/zb/8y7/s4pXtLSqVikAgwOnTp5mcnMRisaDVah+rvfS/l4++oxqNhnA4TCAQYGZmhpdeemnHv2ttbQ2Px0M8HueDDz5gaWlp9y/ogGEwGHA4HDgcDi5cuIDf75f/OBwOxsbG5HOr0+moVqsUCgU2NjYolUosLCxQKBQolUr78vkPRHBRqVRyxy0WO5fLRTAYxG63MzY2htPpxGg0YjAYsNvtjI+PYzab5c4bnhxcXC4XhUIBk8kkf06r1dq2gx9kRI5bnFjMZjNer5fR0VG8Xi/j4+P4fD50Oh0ajYZut4tGo6HX69Htdm
k0GvLEqFarcbvdwMMX32QykUwmsdlslMtlWZs5TKjVakwmE2azGb/fTzgcfuxrer0eRqMRvV6PSqXCbrdvOzF3Oh3a7ba8p71eT+46LRYLLpcLq9VKKBTC7/czNjbGkSNHUKlUlMtlKpUKVqsVlUo1UKdqvV6PzWbDbDbLzAJsfxd3uqb+fzcYDBgMhqf+PY1Gg2AwSLfbfeLJcpAR73D/Wmi1WnE6nTidToaGhuQ/oVAIu93OxMSEzO6o1WoqlQoOh4NutyufN5Eq2w/2NbhotVp5xAsGg1itVt544w2OHTsmX0i9Xo/L5ZJpBY1Gg16vly+3Vqv92pfR6XRy+vRpxsbGyOfzjI+Ps7S0xI0bN2i1WgO7WIqH0GazMTk5ic1m48SJE4yMjBAOhzl+/LgM1Gq1moWFBZaWlshkMty/f59ms0k4HMblcmGxWLDb7VitVmZmZrDb7Wi1WoLBICdOnOBnP/sZsViMDz74gAcPHuz3pe8qHo+HS5cu4ff7uXjxItPT0499Ta/Xo1qtyn8SicS2nPby8jLXr1+nWCyytbVFvV7nzJkznDx5Eq/Xy9GjR7FYLHi9XqxWK9Vqlfn5efL5PF988QXpdJobN24MVGAB6Ha7MvWyl59d1BhsNht6vX7P/p79wmazMTU1hdVqZWxsDK/Xi8vlYmhoCLPZTDgclkHcbDbLgNwf0MXGe2hoiO9///tMT0/zl7/8hUwmsy/P1b4HF4PBgM1mY2RkBK/Xy1//9V9z6dIluXA+C19348xmM5OTk1SrVbLZLF6vF4B79+7J9NigvdTw1W7HarUyOTmJ1+vl0qVLnDhxAofDQSAQQKVSUa/XabVabG5u8sknn7CxscGf/vQnqtUqc3NzMsAEAgG8Xq9U/LhcLhwOB71ej0uXLhGPx7l9+/ahCy52u51z584xNjbG+fPnOXLkyI5f1+l05D+PChw++ugjmevOZDK0Wi2mpqZ4++23CQaDnDlzBpPJJHeat2/f5v79+2xtbfH+++8TiUSo1WrP65J3DXGS3et3SNReTSbTM68Lg4TJZGJiYgK/3y+fQZF90Ol06HS6Ha+7PwUpTsq9Xo+zZ88yMjLCysrK87yMbTy34CKOe1qtFqfTicFgIBQKMTQ0hMPhYGJiAqfTid/vlyma70q73abdbqNWq+UvKBgMolKpWFtbw+l0UqlUKBQKtNvtXbjKvUelUmE2m9FqtVJa7PF4OHHiBC6Xi5GREXmqKxaL1Go17t69SyaT4csvv+TBgwek02kp6czn86jVallLEAElGAxy/PhxHA4HOp0Op9NJo9GQu6d2u73vapTdQgRpkV6Ah89OvV6XgaQ/5dUvAxWboGAwyNmzZykUCng8Hmq1GufPn2d8fBybzUa325UqvXq9zp07d7h16xbJZJJisTiwtax0Os3du3epVCqMj4/j9Xq3FfDhYWAQ9Rir1YrFYtnHT7z/iPS/TqfD7/cTCATw+/1cuHABl8slg4zVapXiCHE/m80mzWaTRqNBPp+n0+nIr7HZbLjd7m2ptf0cE/DcgotQzJjNZqanp/H5fFy8eJFXX30Vi8VCIBDAYDBgtVp3bWdSr9epVqsyjabT6ZidnWVqaopEIkE4HCabzVKr1QYmuGi1WtxuNzabjb/+67/mpz/9KQ6Hg6GhISl4EIElkUgQiUT4T//pP3Hnzh1yuZyUc7ZaLXq9Hpubm0SjUflA2u12UqkUgUAAjUbD9PQ0RqOR4eFh9Ho9gUAAt9tNuVyWP2PQUalUUokjnr1arUY6nabRaJBKpbadKoTIRK/XSwHK0aNH8Xq9tFotyuUynU6HUChEIBCg2WzKQH/jxg0ikQjXr1/ngw8+oFarUalUZPAaJHq9HktLS2xsbDA6OopWq8Xj8TymFjMYDITDYaxWK+Pj4y98cNFqtbIo/9prr/HWW2/hdruZnZ2VQUecVB4NEJVKRb7H9+7do1arYTKZ0Ov1HDlyBIfDAbBNXbtfPJfgol
KpMBqNuN1uLBYL4XAYv98vi5smk0nukHU6HfCVPFPkdIWEsdfryYXw68hkMmSzWSwWi6zViL/DYDBIGfMgDQESL7D
2024-07-26 16:07:40 +03:00
"text/plain": [
2024-07-29 01:24:50 +03:00
"<Figure size 500x500 with 25 Axes>"
2024-07-26 16:07:40 +03:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
2024-07-29 19:44:16 +03:00
"# Pull a single shuffled batch of 25 images and display it\n",
"first_batch = next(batch_iterate(25, train_images))\n",
"show_images(first_batch)"
]
},
2024-07-30 02:06:52 +03:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training Cycle"
]
},
2024-07-26 16:07:40 +03:00
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 446,
2024-07-27 01:09:51 +03:00
"metadata": {},
2024-07-29 01:24:50 +03:00
"outputs": [],
2024-07-27 01:09:51 +03:00
"source": [
2024-07-30 02:17:12 +03:00
"# Shared optimisation settings\n",
"lr = 2e-4\n",
"z_dim = 64\n",
"adam_betas = [0.5, 0.999]\n",
"\n",
"# Generator and its optimizer; mx.eval materialises the lazily created weights\n",
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
"gen_opt = optim.Adam(learning_rate=lr, betas=adam_betas)\n",
"\n",
"# Discriminator and its optimizer\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
"disc_opt = optim.Adam(learning_rate=lr, betas=adam_betas)"
2024-07-29 01:30:08 +03:00
]
},
2024-07-27 01:09:51 +03:00
{
"cell_type": "code",
2024-07-30 02:17:12 +03:00
"execution_count": 447,
2024-07-27 01:09:51 +03:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-30 02:17:12 +03:00
" 0%| | 0/200 [00:00<?, ?it/s]"
2024-07-28 19:18:35 +03:00
]
2024-07-26 16:07:40 +03:00
}
],
"source": [
2024-07-29 19:44:16 +03:00
"# Training hyper-parameters\n",
"n_epochs = 200\n",
"display_step = 5000  # show samples every `display_step` optimisation steps\n",
"cur_step = 0\n",
"\n",
"batch_size = 128\n",
"\n",
"# Each wrapper returns (loss, grads), with grads taken w.r.t. the trainable\n",
"# parameters of the model given as its first argument.\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"for epoch in tqdm(range(n_epochs)):\n",
"\n",
"    for real in batch_iterate(batch_size, train_images):\n",
"\n",
"        # --- Discriminator step ---\n",
"        D_loss, D_grads = D_loss_grad(gen, disc, mx.array(real), batch_size, z_dim)\n",
"        # Apply the gradients to the discriminator's parameters.\n",
"        disc_opt.update(disc, D_grads)\n",
"        # MLX is lazy: force the parameter and optimizer-state updates to run.\n",
"        mx.eval(disc.parameters(), disc_opt.state)\n",
"\n",
"        # --- Generator step ---\n",
"        G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
"        # Apply the gradients to the generator's parameters.\n",
"        gen_opt.update(gen, G_grads)\n",
"        mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
"        # Periodically report the losses and compare fake vs. real samples.\n",
"        if (cur_step + 1) % display_step == 0:\n",
"            print(\n",
"                f\"Epoch {epoch}, step {cur_step + 1}: \"\n",
"                f\"Generator loss: {G_loss.item():.4f}, \"\n",
"                f\"discriminator loss: {D_loss.item():.4f}\"\n",
"            )\n",
"            fake = gen(get_noise(batch_size, z_dim))\n",
"            show_images(fake)\n",
"            show_images(real)\n",
"        cur_step += 1"
2024-07-27 00:19:08 +03:00
]
2024-07-26 16:07:40 +03:00
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}