Updating GAN Code...

This commit is contained in:
Shubbair 2024-07-27 01:20:00 +03:00
parent f8b7094fb8
commit 8b1713737a

View File

@ -402,9 +402,29 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 77,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[77], line 22\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m5\u001b[39m):\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m real \u001b[38;5;129;01min\u001b[39;00m batch_iterate(batch_size, train_images):\n\u001b[1;32m 13\u001b[0m \n\u001b[1;32m 14\u001b[0m \u001b[38;5;66;03m# real = real.reshape(-1)\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# print(len(real))\u001b[39;00m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# break\u001b[39;00m\n\u001b[0;32m---> 22\u001b[0m D_loss,D_grads \u001b[38;5;241m=\u001b[39m \u001b[43mD_loss_grad\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgen\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdisc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreal\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mz_dim\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;66;03m# Update optimizer\u001b[39;00m\n\u001b[1;32m 25\u001b[0m disc_opt\u001b[38;5;241m.\u001b[39mupdate(disc, D_grads)\n",
"File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/utils.py:34\u001b[0m, in \u001b[0;36mvalue_and_grad.<locals>.wrapped_value_grad_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(fn)\n\u001b[1;32m 33\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapped_value_grad_fn\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 34\u001b[0m value, grad \u001b[38;5;241m=\u001b[39m \u001b[43mvalue_grad_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainable_parameters\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 35\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m value, grad\n",
"File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/utils.py:28\u001b[0m, in \u001b[0;36mvalue_and_grad.<locals>.inner_fn\u001b[0;34m(params, *args, **kwargs)\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minner_fn\u001b[39m(params, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 27\u001b[0m model\u001b[38;5;241m.\u001b[39mupdate(params)\n\u001b[0;32m---> 28\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[71], line 3\u001b[0m, in \u001b[0;36mdisc_loss\u001b[0;34m(gen, disc, real, num_images, z_dim)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdisc_loss\u001b[39m(gen, disc, real, num_images, z_dim):\n\u001b[1;32m 2\u001b[0m noise \u001b[38;5;241m=\u001b[39m mx\u001b[38;5;241m.\u001b[39marray(get_noise(num_images, z_dim))\n\u001b[0;32m----> 3\u001b[0m fake_images \u001b[38;5;241m=\u001b[39m \u001b[43mgen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnoise\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m fake_disc \u001b[38;5;241m=\u001b[39m disc(fake_images)\n\u001b[1;32m 7\u001b[0m fake_labels \u001b[38;5;241m=\u001b[39m mx\u001b[38;5;241m.\u001b[39mzeros((fake_images\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m],\u001b[38;5;241m1\u001b[39m))\n",
"Cell \u001b[0;32mIn[5], line 19\u001b[0m, in \u001b[0;36mGenerator.__call__\u001b[0;34m(self, noise)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, noise):\n\u001b[0;32m---> 19\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnoise\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/layers/containers.py:23\u001b[0m, in \u001b[0;36mSequential.__call__\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[1;32m 22\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[0;32m---> 23\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n",
"File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/layers/containers.py:23\u001b[0m, in \u001b[0;36mSequential.__call__\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[1;32m 22\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[0;32m---> 23\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n",
"File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/layers/normalization.py:357\u001b[0m, in \u001b[0;36mBatchNorm.__call__\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 355\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrack_running_stats:\n\u001b[1;32m 356\u001b[0m mu \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmomentum\n\u001b[0;32m--> 357\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrunning_mean\u001b[49m \u001b[38;5;241m=\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m mu) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrunning_mean \u001b[38;5;241m+\u001b[39m mu \u001b[38;5;241m*\u001b[39m mean\n\u001b[1;32m 358\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrunning_var \u001b[38;5;241m=\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m mu) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrunning_var \u001b[38;5;241m+\u001b[39m mu \u001b[38;5;241m*\u001b[39m var\n\u001b[1;32m 359\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrack_running_stats:\n",
"File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/layers/base.py:139\u001b[0m, in \u001b[0;36mModule.__setattr__\u001b[0;34m(self, key, val)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28msuper\u001b[39m(Module, \u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__getattribute__\u001b[39m(key)\n\u001b[0;32m--> 139\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__setattr__\u001b[39m(\u001b[38;5;28mself\u001b[39m, key: \u001b[38;5;28mstr\u001b[39m, val: Any):\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(val, (mx\u001b[38;5;241m.\u001b[39marray, \u001b[38;5;28mdict\u001b[39m, \u001b[38;5;28mlist\u001b[39m, \u001b[38;5;28mtuple\u001b[39m)):\n\u001b[1;32m 141\u001b[0m \u001b[38;5;66;03m# If attribute was previously set but not in the\u001b[39;00m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;66;03m# dictionary, delete it so we pick it up in future\u001b[39;00m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;66;03m# calls to __getattr__\u001b[39;00m\n\u001b[1;32m 144\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(\u001b[38;5;28mself\u001b[39m, key) \u001b[38;5;129;01mand\u001b[39;00m key \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"batch_size = 8\n",
"cur_step = 0\n",