Updating GAN Code...

This commit is contained in:
Shubbair 2024-07-27 01:19:50 +03:00
parent 147cb3d2bc
commit f8b7094fb8

View File

@ -402,29 +402,20 @@
},
{
"cell_type": "code",
"execution_count": 74,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
}
],
"outputs": [],
"source": [
"batch_size = 8\n",
"cur_step = 0\n",
"mean_generator_loss = 0\n",
"mean_discriminator_loss = 0\n",
"error = False\n",
"\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"\n",
"for epoch in tqdm(range(1)):\n",
"for epoch in tqdm(range(5)):\n",
" \n",
" for real in batch_iterate(batch_size, train_images):\n",
" \n",
@ -442,6 +433,8 @@
" disc_opt.update(disc, D_grads)\n",
" \n",
" # Update gradients\n",
" mx.eval(disc.parameters(), disc_opt.state)\n",
"\n",
" \n",
" G_loss,G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
" \n",
@ -449,7 +442,9 @@
" gen_opt.update(gen, G_grads)\n",
" \n",
" # Update gradients\n",
" mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
" print(f\"Step {epoch}: Generator loss: {G_loss}, discriminator loss: {D_loss}\")\n",
"\n",
" # # Keep track of the average discriminator loss\n",
" # mean_discriminator_loss += disc_loss.item() / display_step\n",