mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Updating GAN Code...
This commit is contained in:
parent
147cb3d2bc
commit
f8b7094fb8
@ -402,29 +402,20 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 74,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" 0%| | 0/1 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"batch_size = 8\n",
|
"batch_size = 8\n",
|
||||||
"cur_step = 0\n",
|
"cur_step = 0\n",
|
||||||
"mean_generator_loss = 0\n",
|
"mean_generator_loss = 0\n",
|
||||||
"mean_discriminator_loss = 0\n",
|
"mean_discriminator_loss = 0\n",
|
||||||
"error = False\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
|
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
|
||||||
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
|
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for epoch in tqdm(range(1)):\n",
|
"for epoch in tqdm(range(5)):\n",
|
||||||
" \n",
|
" \n",
|
||||||
" for real in batch_iterate(batch_size, train_images):\n",
|
" for real in batch_iterate(batch_size, train_images):\n",
|
||||||
" \n",
|
" \n",
|
||||||
@ -442,6 +433,8 @@
|
|||||||
" disc_opt.update(disc, D_grads)\n",
|
" disc_opt.update(disc, D_grads)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Update gradients\n",
|
" # Update gradients\n",
|
||||||
|
" mx.eval(disc.parameters(), disc_opt.state)\n",
|
||||||
|
"\n",
|
||||||
" \n",
|
" \n",
|
||||||
" G_loss,G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
|
" G_loss,G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
|
||||||
" \n",
|
" \n",
|
||||||
@ -449,7 +442,9 @@
|
|||||||
" gen_opt.update(gen, G_grads)\n",
|
" gen_opt.update(gen, G_grads)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # Update gradients\n",
|
" # Update gradients\n",
|
||||||
|
" mx.eval(gen.parameters(), gen_opt.state)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
" print(f\"Step {epoch}: Generator loss: {G_loss}, discriminator loss: {D_loss}\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # # Keep track of the average discriminator loss\n",
|
" # # Keep track of the average discriminator loss\n",
|
||||||
" # mean_discriminator_loss += disc_loss.item() / display_step\n",
|
" # mean_discriminator_loss += disc_loss.item() / display_step\n",
|
||||||
|
Loading…
Reference in New Issue
Block a user