mlx-examples/gan/playground.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Library"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim:int,out_dim:int):\n",
" \n",
" return nn.Sequential(\n",
" nn.Linear(in_dim,out_dim),\n",
" nn.BatchNorm(out_dim),\n",
" nn.ReLU()\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"\n",
" def __init__(self, z_dim:int = 10, im_dim:int = 784, hidden_dim: int =128):\n",
" super(Generator, self).__init__()\n",
" # Build the neural network\n",
" self.gen = nn.Sequential(\n",
" GenBlock(z_dim, hidden_dim),\n",
" GenBlock(hidden_dim, hidden_dim * 2),\n",
" GenBlock(hidden_dim * 2, hidden_dim * 4),\n",
" GenBlock(hidden_dim * 4, hidden_dim * 8),\n",
"\n",
"\n",
" nn.Linear(hidden_dim * 8,im_dim),\n",
" nn.Sigmoid()\n",
" )\n",
" \n",
" def __call__(self, noise):\n",
"\n",
" return self.gen(noise)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
" (layers.0): Linear(input_dims=100, output_dims=128, bias=True)\n",
" (layers.1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.1): Sequential(\n",
" (layers.0): Linear(input_dims=128, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.2): Sequential(\n",
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
" (layers.1): BatchNorm(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.3): Sequential(\n",
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
" (layers.1): BatchNorm(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.4): Linear(input_dims=1024, output_dims=784, bias=True)\n",
" (layers.5): Sigmoid()\n",
" )\n",
")"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gen = Generator(100)\n",
"gen"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def get_noise(n_samples, z_dim):\n",
" return np.random.randn(n_samples,z_dim)"
]
},
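{
"cell_type": "markdown",
"metadata": {},
"source": [
"Noise could also be sampled directly with MLX instead of going through NumPy. A minimal sketch of a hypothetical `get_noise_mlx` helper (not used elsewhere in this notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_noise_mlx(n_samples, z_dim):\n",
"    # Sample standard-normal noise as an MLX array directly\n",
"    return mx.random.normal((n_samples, z_dim))"
]
},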
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim:int,out_dim:int):\n",
" return nn.Sequential(\n",
" nn.Linear(in_dim,out_dim),\n",
" nn.LeakyReLU(negative_slope=0.2)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"\n",
" def __init__(self,im_dim:int = 784, hidden_dim:int = 128):\n",
" super(Discriminator, self).__init__()\n",
"\n",
" self.disc = nn.Sequential(\n",
" DisBlock(im_dim, hidden_dim * 4),\n",
" DisBlock(hidden_dim * 4, hidden_dim * 2),\n",
" DisBlock(hidden_dim * 2, hidden_dim),\n",
"\n",
" nn.Linear(hidden_dim,1),\n",
" )\n",
" \n",
" def __call__(self, noise):\n",
"\n",
" return self.disc(noise)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
" (layers.0): Linear(input_dims=784, output_dims=512, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.1): Sequential(\n",
" (layers.0): Linear(input_dims=512, output_dims=256, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.2): Sequential(\n",
" (layers.0): Linear(input_dims=256, output_dims=128, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.3): Linear(input_dims=128, output_dims=1, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"disc = Discriminator()\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Set your parameters\n",
"criterion = nn.losses.binary_cross_entropy\n",
"n_epochs = 200\n",
"z_dim = 64\n",
"display_step = 500\n",
"batch_size = 128\n",
"lr = 0.00001"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
"gen_opt = optim.Adam(learning_rate=lr)\n",
"\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
"disc_opt = optim.Adam(learning_rate=lr)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Losses"
]
},
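{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two cells below implement the usual binary-cross-entropy (with logits) GAN objectives, roughly:\n",
"\n",
"$$L_D = \\tfrac{1}{2}\\big[\\mathrm{BCE}(D(x_{real}), 1) + \\mathrm{BCE}(D(G(z)), 0)\\big], \\qquad L_G = \\mathrm{BCE}(D(G(z)), 1)$$"
]
},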
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
"    # Generate a batch of fake images from random noise\n",
"    noise = mx.array(get_noise(num_images, z_dim))\n",
"    fake_images = gen(noise)\n",
"\n",
"    # The discriminator should push fake images towards label 0\n",
"    fake_disc = disc(fake_images)\n",
"    fake_labels = mx.zeros((fake_images.shape[0], 1))\n",
"    fake_loss = nn.losses.binary_cross_entropy(fake_disc, fake_labels, with_logits=True)\n",
"\n",
"    # ... and real images towards label 1\n",
"    real_disc = disc(real)\n",
"    real_labels = mx.ones((real.shape[0], 1))\n",
"    real_loss = nn.losses.binary_cross_entropy(real_disc, real_labels, with_logits=True)\n",
"\n",
"    disc_loss = (fake_loss + real_loss) / 2\n",
"\n",
"    return disc_loss"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def gen_loss(gen, disc, num_images, z_dim):\n",
"\n",
"    # Generate fakes and score them with the discriminator\n",
"    noise = mx.array(get_noise(num_images, z_dim))\n",
"    fake_images = gen(noise)\n",
"    fake_disc = disc(fake_images)\n",
"\n",
"    # The generator wants its fakes to be classified as real (label 1)\n",
"    fake_labels = mx.ones((fake_images.shape[0], 1))\n",
"\n",
"    gen_loss = nn.losses.binary_cross_entropy(fake_disc, fake_labels, with_logits=True)\n",
"\n",
"    return gen_loss"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"train_images, _, test_images, _ = map(\n",
" mx.array, getattr(mnist, 'mnist')()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def batch_iterate(batch_size:int, ipt:list):\n",
" perm = mx.array(np.random.permutation(len(ipt)))\n",
" for s in range(0, ipt.size, batch_size):\n",
" ids = perm[s : s + batch_size]\n",
" yield ipt[ids]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### show batch of images"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 400x400 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for X in batch_iterate(16, train_images):\n",
" fig,axes = plt.subplots(4, 4, figsize=(4, 4))\n",
"\n",
" for i, ax in enumerate(axes.flat):\n",
" img = mx.array(X[i]).reshape(28,28)\n",
" ax.imshow(img,cmap='gray')\n",
" ax.axis('off')\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def show_images(imgs, num_imgs: int = 25):\n",
"    # Display the first 25 images in a 5x5 grid\n",
"    fig, axes = plt.subplots(5, 5, figsize=(4, 4))\n",
"\n",
"    for i, ax in enumerate(axes.flat):\n",
"        img = mx.array(imgs[i]).reshape(28, 28)\n",
"        ax.imshow(img, cmap='gray')\n",
"        ax.axis('off')\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(0.738084, dtype=float32)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"z_dim = 64\n",
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
"gen_opt = optim.Adam(learning_rate=lr)\n",
"\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
"disc_opt = optim.Adam(learning_rate=lr)\n",
"\n",
"g_loss = gen_loss(gen, disc, 8, z_dim)\n",
"g_loss\n"
]
},
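{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check: for an untrained discriminator whose logits stay near zero, the expected BCE-with-logits value is roughly $-\\ln \\sigma(0) = \\ln 2 \\approx 0.693$, so a generator loss around 0.7 before training is about what one would expect."
]
},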
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"60000"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train_images)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/200 [00:00<?, ?it/s]"
]
}
],
"source": [
"batch_size = 16\n",
"display_step = 50\n",
"cur_step = 0\n",
"mean_generator_loss = 0\n",
"mean_discriminator_loss = 0\n",
"\n",
"# Wrap the losses so MLX returns both the loss value and the parameter gradients\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"for epoch in tqdm(range(n_epochs)):\n",
"\n",
"    for real in batch_iterate(batch_size, train_images):\n",
"\n",
"        # Discriminator step: loss and gradients w.r.t. the discriminator parameters\n",
"        D_loss, D_grads = D_loss_grad(gen, disc, real, batch_size, z_dim)\n",
"\n",
"        # Apply the gradients to the discriminator\n",
"        disc_opt.update(disc, D_grads)\n",
"\n",
"        # Force evaluation of the lazy parameter and optimizer updates\n",
"        mx.eval(disc.parameters(), disc_opt.state)\n",
"\n",
"        # Generator step: loss and gradients w.r.t. the generator parameters\n",
"        G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
"\n",
"        # Apply the gradients to the generator\n",
"        gen_opt.update(gen, G_grads)\n",
"\n",
"        # Force evaluation of the lazy parameter and optimizer updates\n",
"        mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
"        if cur_step % display_step == 0 and cur_step > 0:\n",
"            print(f\"Epoch {epoch}, step {cur_step}: Generator loss: {G_loss}, discriminator loss: {D_loss}\")\n",
"            fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
"            fake = gen(fake_noise)\n",
"            show_images(fake)\n",
"            show_images(real)\n",
"        cur_step += 1"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}