mlx-examples/gan/playground.ipynb

479 lines
66 KiB
Plaintext
Raw Normal View History

2024-07-26 21:07:40 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Libraries"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim:int,out_dim:int):\n",
" \n",
" return nn.Sequential(\n",
" nn.Linear(in_dim,out_dim),\n",
" nn.BatchNorm(out_dim),\n",
" nn.ReLU()\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"\n",
" def __init__(self, z_dim:int = 10, im_dim:int = 784, hidden_dim: int =128):\n",
" super(Generator, self).__init__()\n",
" # Build the neural network\n",
" self.gen = nn.Sequential(\n",
" GenBlock(z_dim, hidden_dim),\n",
" GenBlock(hidden_dim, hidden_dim * 2),\n",
" GenBlock(hidden_dim * 2, hidden_dim * 4),\n",
" GenBlock(hidden_dim * 4, hidden_dim * 8),\n",
"\n",
"\n",
" nn.Linear(hidden_dim * 8,im_dim),\n",
" nn.Sigmoid()\n",
" )\n",
" \n",
" def __call__(self, noise):\n",
"\n",
" return self.gen(noise)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
" (layers.0): Linear(input_dims=100, output_dims=128, bias=True)\n",
" (layers.1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.1): Sequential(\n",
" (layers.0): Linear(input_dims=128, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.2): Sequential(\n",
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
" (layers.1): BatchNorm(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.3): Sequential(\n",
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
" (layers.1): BatchNorm(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.4): Linear(input_dims=1024, output_dims=784, bias=True)\n",
" (layers.5): Sigmoid()\n",
" )\n",
")"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gen = Generator(100)\n",
"gen"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def get_noise(n_samples, z_dim):\n",
"\n",
" return np.random.randn(n_samples,z_dim)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim:int,out_dim:int):\n",
" \n",
" return nn.Sequential(\n",
" nn.Linear(in_dim,out_dim),\n",
" nn.LeakyReLU(negative_slope=0.2)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"\n",
" def __init__(self,im_dim:int = 784, hidden_dim:int = 128):\n",
" super(Discriminator, self).__init__()\n",
"\n",
" self.disc = nn.Sequential(\n",
" DisBlock(im_dim, hidden_dim * 4),\n",
" DisBlock(hidden_dim * 4, hidden_dim * 2),\n",
" DisBlock(hidden_dim * 2, hidden_dim),\n",
"\n",
" nn.Linear(hidden_dim,1),\n",
" )\n",
" \n",
" def __call__(self, noise):\n",
"\n",
" return self.disc(noise)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
" (layers.0): Linear(input_dims=784, output_dims=512, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.1): Sequential(\n",
" (layers.0): Linear(input_dims=512, output_dims=256, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.2): Sequential(\n",
" (layers.0): Linear(input_dims=256, output_dims=128, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.3): Linear(input_dims=128, output_dims=1, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"disc = Discriminator()\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Set your parameters\n",
"criterion = nn.losses.binary_cross_entropy\n",
"n_epochs = 200\n",
"z_dim = 64\n",
"display_step = 500\n",
"batch_size = 128\n",
"lr = 0.00001"
]
},
{
"cell_type": "code",
2024-07-26 21:36:29 +08:00
"execution_count": 197,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"gen = Generator(z_dim)\n",
2024-07-26 21:36:29 +08:00
"mx.eval(gen.parameters())\n",
2024-07-26 21:07:40 +08:00
"gen_opt = optim.Adam(learning_rate=lr)\n",
2024-07-26 21:36:29 +08:00
"\n",
2024-07-26 21:07:40 +08:00
"disc = Discriminator()\n",
2024-07-26 21:36:29 +08:00
"mx.eval(disc.parameters())\n",
2024-07-26 21:07:40 +08:00
"disc_opt = optim.Adam(learning_rate=lr)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Losses"
]
},
{
"cell_type": "code",
2024-07-26 21:36:29 +08:00
"execution_count": 198,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
2024-07-26 21:07:40 +08:00
" noise = mx.array(get_noise(num_images, z_dim))\n",
" fake_images = gen(noise)\n",
" \n",
" fake_disc = disc(fake_images)\n",
" \n",
" fake_labels = mx.zeros((len(fake_images),1))\n",
2024-07-26 21:36:29 +08:00
" fake_loss = nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True)\n",
2024-07-26 21:07:40 +08:00
" \n",
" real_disc = disc(real)\n",
" real_labels = mx.ones((len(real),1))\n",
2024-07-26 21:36:29 +08:00
" real_loss = nn.losses.binary_cross_entropy(real_disc,real_labels,with_logits=True)\n",
2024-07-26 21:07:40 +08:00
"\n",
" disc_loss = (fake_loss + real_loss) / 2\n",
"\n",
" return disc_loss"
]
},
{
"cell_type": "code",
2024-07-26 21:36:29 +08:00
"execution_count": 199,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def gen_loss(gen, disc, num_images, z_dim):\n",
2024-07-26 21:07:40 +08:00
"\n",
" noise = mx.array(get_noise(num_images, z_dim))\n",
" fake_images = gen(noise)\n",
" \n",
" fake_disc = disc(fake_images)\n",
"\n",
" fake_labels = mx.ones((fake_images.shape[0],1))  # .size is an int property in MLX, not callable; use shape[0]\n",
" \n",
2024-07-26 21:36:29 +08:00
" gen_loss = nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True)\n",
2024-07-26 21:07:40 +08:00
"\n",
" return gen_loss"
]
},
{
"cell_type": "code",
2024-07-26 21:36:29 +08:00
"execution_count": 200,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"train_images, _, test_images, _ = map(\n",
" mx.array, getattr(mnist, 'mnist')()\n",
")"
]
},
{
"cell_type": "code",
2024-07-26 21:36:29 +08:00
"execution_count": 201,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def batch_iterate(batch_size:int, ipt:list):\n",
" perm = mx.array(np.random.permutation(len(ipt)))\n",
" for s in range(0, len(ipt), batch_size):  # len(ipt) = number of samples; ipt.size is the total element count\n",
" ids = perm[s : s + batch_size]\n",
" yield ipt[ids]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### show batch of images"
]
},
{
"cell_type": "code",
2024-07-26 21:36:29 +08:00
"execution_count": 202,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-26 21:36:29 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAUkAAAFICAYAAADd1gwNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACX60lEQVR4nOy992+cWXan/1TOOVcxB5GicupR556R2x6PPV4H7Ngwdm0Y2AX2f1qsscB60+xixmv7O7Fn2p2mk1o5MpPFYuWc8/eHxr1dVGCrWyJVpN4HINQtFkv1Xr7vueee8Dmqfr/fR0FBQUHhkaif9wdQUFBQGGYUI6mgoKCwC4qRVFBQUNgFxUgqKCgo7IJiJBUUFBR2QTGSCgoKCrugGEkFBQWFXVCMpIKCgsIuKEZSQUFBYRe0T/pClUq1l5/jQPFtm5SUNfwKZQ2fHmUNn54nWUPFk1RQUFDYBcVIKigoKOyCYiQVFBQUdkExkgoKCgq7oBhJBQUFhV1QjKSCgoLCLihGUkFBQWEXFCOpoKCgsAuKkVRQUFDYBcVIKigoKOzCE7clPgl6vR6j0YjBYMDlcqHX69Hr9eh0uke+vl6v02g06PV6dDod+v0+3W6XXq9Hq9WiVqs9sm1IvG7wTwUFBYW94JkaSY/Hw+joKOFwmEuXLuHz+QgGg7hcLvka0Tfa7/fZ2NhgbW2NVqtFuVym3W5TqVRoNBokEgmWlpZot9sP/TudTodqtUqn06FWq9FqtZ7lZSgoKChInomRVKvVqFQqLBYLHo+HYDDI9PQ0wWCQSCSC1+uVrx1srjeZTKhUKprNJsVikXa7TalUol6vA5DNZh9pADudDlqtlna7Lb1I4YEqXqXCoxD3qEajQa/XP/T9fr8vvwQqlYp+v7/jvhL/3+v15P8rfHtUKhVarRaVSiV/R+K/+/2+PGEKBn8fD/6+9oqnNpJarZZgMIjVauXVV1/le9/7Hm63myNHjmC1WjGbzTte3+/35c3n9XrRarV0Oh2azSb9fp9Wq0Wn06FSqZDL5eTNKH4GoNlsUigUqNfr3Lt3j3g8TiKRYHl5mU6nQ6vVUm5eBYlarSYUCuFyuZidneW1117DZDKhVn8Zkm82mzQaDRqNBrlcjk6ng8lkQqfTUS6XSafTtFotGo0G7XabeDxOPB6n2+3Sbrfp9Xq022263e5zvtKDh9vt5uTJk9jtdvx+P1arFafTic/no1arcePGDXK5nHx9tVollUrRaDRIJpNUKpU9/4xPbSQ1Gg1erxefz8fJkye5dOkSJpNpx034qN0ZwOFwYLfbv/bfGPQ++/0+9XqdXC5HtVrF6/WyvLzMnTt3iMVi1Ot1Op2OcsMqSFQqFS6Xi4mJCb7zne/wt3/7tzgcDjQaDSqVinK5LL82NzdpNps4HA5MJhOJRILV1VUajQaFQoFGo4FWq5XhoUajQbfblV8K3wybzcaJEycIBALMzMzg9/sJhULMzMyQy+X453/+Z6LRqHx9Lpfj/v37lMtlSqXSwTCSBoOBY8eOMTc3x8zMDAaDQbrPT8Kg0Xzc9x/8f61Wi9lsRq1WMzk5iclkwul04nQ6KRaL3L59m0KhQK1Wo1arPdX1PW/UajWBQAC3243D4SAUCtHpdFhZWaFYLNJsNqUX/jj0ej1+vx+j0Sj/ThxrtFotHo8Hq9VKq9WSybStrS1qtdq+3Yh7Sb/fp1gssr29TTKZlJ6JzWZDr9ej1Wrl2gQCAelJ6vV61Go1Go2GdrtNrVaj3W7j8/mYnJyk3W5TLpdpNptsbW2Ry+Uol8tkMhnlJDOA0+kkGAxiMBiw2WzodDosFgsWiwW/38/Zs2dxOBwEg0Hsdjt2ux2VSoXRaGRqagqHwyHfq1KpEAqFyOfzFItF0un0nn/+pzaSVquV73//+/ze7/0eJpMJi8Wy56KeOp1OeqFut5tut0u9XqdcLrO1tcV/+S//hcXFRaLR6IE3khqNhuPHj3
Pq1Cnm5+e5dOkStVqNf/iHf+DOnTtks9mvfShtNhuvvfYawWBQ/p1arZbG4dy5c0xMTFAoFEilUqRSKX7+858TjUZZWlo68Eay1+tJAxkIBKS3qNVq0el0sgLDZrPhdruBrzaRSCTC/Pz8jjhYt9ul0+lQr9dJp9OUy2U++OADlpaWWFpaolAoPDLh+KISiUS4dOkSTqeT6elpnE4no6OjjIyMyBixRqNBrVbv+LJarbz88ss7PHQRD06lUty/f5/bt2/v+ed/ZtltEXQVgVWRcBkMroobT7wevrxosSji+4Nfu/178KURAeTN3mw2CYVCVCoVisUiiURi3wK8e4FKpcJgMGCxWLDb7Xi9XhqNBuFwmHK5jNlsxmKxfK2RjEQiBAKBHe+r0+kwGAz4/X68Xu+OUq1IJEK/3yebzZJMJg98rFcch2u1GtlsFp1Oh9frlWGhwYTBIFrtzkdk8PpFgtFkMhEMBqlUKmQyGXQ6Hb1e74U8fqtUKumBC288FAoRDodxuVyEQiHp3FitVvr9Pu12m3a7TbVa3bG5qNVqdDqdNJhWq1V+r9FoPLa08Fnz1Eay2Wxy584dzGYz4+PjTE9P02w2icVi1Go1Op0OnU5Hei5i8YRx6/f70v3WarXyuC5qLp/UK9VqtfJm/eEPf0gmk+HHP/4xsViMdrv9tUfSYUasnVgzq9XKH/7hH/L666/TarVoNpu7/rxOp8PtdmMwGICvkmfiy+l0AmA2m2WCw263UyqV+PWvf43ZbCabzbK8vPy1/9awE4vF+PWvf00oFEKtVjMxMYHVasVms33j9zIYDLjdbux2O6+//jonTpzAYrFw+/ZtGePsdDp7cBXDi9lsZmRkBJvNxvnz55mZmWF0dJRTp05hNBoxm81otVqKxaI8pWxtbVEqlfj8889ZW1uT72UwGBgdHcVut/MHf/AHvPXWW89l9MRTG8lut0smkyEajWK1WhkZGaFWq5HJZKhWqzQaDVqt1g632mKxoNPppNHS6/X0ej10Oh3dbleWaOj1+h0P9KBH+iAivmaxWJieniYQCBAMBuWu/nWxz4OA8NK1Wi0TExNP9T4P0uv10Gg0mM1mTCaTjFEuLy9z9+5d+f2DTrVaZWNjg2azSTabxePxyE36wQQh8NCJRtyPgNzw+/2+LHW7efMmZrOZZrP5kFd62BFepMvlwuVyMT8/z+nTpwkEAkxPT6PVauU9nM/nyWaz5PN5VldXyeVy/O53v+PWrVvy/cxmM0eOHMHn83Hq1KmH/j1xAt3rk+JTG8lGo8HVq1fZ3t7m3r17XL16lVarJdP09Xpd3jDiITMYDDtuPL1ej91ul56kOAaazWZ0Oh0ej0fuQiaTCZvNxsjICAaDQR6VBGq1GovFglqtxu124/f7ZXBdlBMdJLrdLisrK7RaLdLptLw+nU6HRqPB5XLh8Xjk2j3qwex0OuTzeRqNBtVq9bFxWrfbTTgclp68TqcjEAgwPz+PVqvl5s2be325e06lUmFtbY10Oo1WqyUQCDA6OsrY2Jh8Tb/fJ5PJUCqV8Hg8TExMoNfrMRgMaDQafD7fjtpfYEcNptFoxGg0vlBGcnR0lPHxcYLBIC+//DIej4fZ2VlCoRBWqxW1Wk25XOaTTz4hkUiwsbHBxsYG1WpVVqp8kySM0WjktddeQ6/Xs7i4yJUrV/bMa38mRvLKlSuoVCocDgcul4tut0ur1aLb7coOmkEe9ASFByhuMq1WK5MKFouF2dlZnE6nLDUSR0KNRoNOp9vxfhqNBqvVisFgwOv1EgwG0Wq1ZDKZA3n06Xa7LC4usrKywtbWFq1WS3p6er2emZkZjhw5gl6vx+l0PjJOU6/XWV1dpVAokE6nH3szzs7OYjKZMJvNcrMKhUIsLCzQarX2LQa0l5TLZSqVChqNhrW1NfR6PXNzcxw/flzeR91ulzt37rC1tcXRo0d54403sFgsOBwOjEYjR48exePxyNeLWKa4H180I6lSqR
gfH+fNN99kYmKCH/7wh3i93odyC8VikZ///OfcuHGD9fV1Njc3dxTqfxMnxmQy8dZbb3H06FH+6Z/+iRs3bgyvkYS
2024-07-26 21:07:40 +08:00
"text/plain": [
"<Figure size 400x400 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for X in batch_iterate(16, train_images):\n",
" fig,axes = plt.subplots(4, 4, figsize=(4, 4))\n",
"\n",
" for i, ax in enumerate(axes.flat):\n",
" img = mx.array(X[i]).reshape(28,28)\n",
" ax.imshow(img,cmap='gray')\n",
" ax.axis('off')\n",
" break"
]
},
{
"cell_type": "code",
2024-07-26 21:36:29 +08:00
"execution_count": 203,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"0it [00:00, ?it/s]\n"
]
},
{
"ename": "TypeError",
2024-07-26 21:36:29 +08:00
"evalue": "'bool' object is not callable",
2024-07-26 21:07:40 +08:00
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
2024-07-26 21:36:29 +08:00
"Cell \u001b[0;32mIn[203], line 28\u001b[0m\n\u001b[1;32m 23\u001b[0m disc_opt\u001b[38;5;241m.\u001b[39mupdate(disc, D_grads)\n\u001b[1;32m 25\u001b[0m \u001b[38;5;66;03m# Update gradients\u001b[39;00m\n\u001b[0;32m---> 28\u001b[0m G_loss,G_grads \u001b[38;5;241m=\u001b[39m \u001b[43mG_loss_grad\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgen\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdisc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mz_dim\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;66;03m# Update optimizer\u001b[39;00m\n\u001b[1;32m 31\u001b[0m gen_opt\u001b[38;5;241m.\u001b[39mupdate(gen, G_grads)\n",
"File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/utils.py:34\u001b[0m, in \u001b[0;36mvalue_and_grad.<locals>.wrapped_value_grad_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(fn)\n\u001b[1;32m 33\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapped_value_grad_fn\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 34\u001b[0m value, grad \u001b[38;5;241m=\u001b[39m \u001b[43mvalue_grad_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainable_parameters\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 35\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m value, grad\n",
"File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/utils.py:28\u001b[0m, in \u001b[0;36mvalue_and_grad.<locals>.inner_fn\u001b[0;34m(params, *args, **kwargs)\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minner_fn\u001b[39m(params, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 27\u001b[0m model\u001b[38;5;241m.\u001b[39mupdate(params)\n\u001b[0;32m---> 28\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mTypeError\u001b[0m: 'bool' object is not callable"
2024-07-26 21:07:40 +08:00
]
}
],
"source": [
"batch_size = 8\n",
"cur_step = 0\n",
"mean_generator_loss = 0\n",
"mean_discriminator_loss = 0\n",
"test_generator = True # Whether the generator should be tested\n",
"gen_loss_value = None  # renamed from gen_loss: assigning a bool to that name shadowed the gen_loss() function and caused the TypeError below\n",
"error = False\n",
"\n",
2024-07-26 21:36:29 +08:00
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"\n",
2024-07-26 21:07:40 +08:00
"for epoch in range(n_epochs):\n",
" \n",
" # Dataloader returns the batches\n",
" for real in tqdm(batch_iterate(batch_size, train_images)):\n",
"\n",
" # Flatten the batch of real images from the dataset\n",
" \n",
2024-07-26 21:36:29 +08:00
" D_loss,D_grads = D_loss_grad(gen, disc, real, batch_size, z_dim)\n",
2024-07-26 21:07:40 +08:00
"\n",
" # Update optimizer\n",
2024-07-26 21:36:29 +08:00
" disc_opt.update(disc, D_grads)\n",
2024-07-26 21:07:40 +08:00
" \n",
2024-07-26 21:36:29 +08:00
" # Update gradients\n",
" \n",
" \n",
" G_loss,G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
" \n",
" # Update optimizer\n",
" gen_opt.update(gen, G_grads)\n",
" \n",
" # Update gradients\n",
2024-07-26 21:07:40 +08:00
"\n",
2024-07-26 21:36:29 +08:00
" \n",
2024-07-26 21:07:40 +08:00
"\n",
2024-07-26 21:36:29 +08:00
" # # Keep track of the average discriminator loss\n",
" # mean_discriminator_loss += disc_loss.item() / display_step\n",
2024-07-26 21:07:40 +08:00
"\n",
2024-07-26 21:36:29 +08:00
" # # Keep track of the average generator loss\n",
" # mean_generator_loss += gen_loss.item() / display_step\n",
2024-07-26 21:07:40 +08:00
"\n",
2024-07-26 21:36:29 +08:00
" # ### Visualization code ###\n",
" # if cur_step % display_step == 0 and cur_step > 0:\n",
" # print(f\"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}\")\n",
" # fake_noise = get_noise(cur_batch_size, z_dim, device=device)\n",
" # fake = gen(fake_noise)\n",
" # show_tensor_images(fake)\n",
" # show_tensor_images(real)\n",
" # mean_generator_loss = 0\n",
" # mean_discriminator_loss = 0\n",
" # cur_step += 1\n"
2024-07-26 21:07:40 +08:00
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}