diff --git a/gan/playground.ipynb b/gan/playground.ipynb index d1352ed6..400d07b5 100644 --- a/gan/playground.ipynb +++ b/gan/playground.ipynb @@ -402,9 +402,29 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 77, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[77], line 22\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m5\u001b[39m):\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m real \u001b[38;5;129;01min\u001b[39;00m batch_iterate(batch_size, train_images):\n\u001b[1;32m 13\u001b[0m \n\u001b[1;32m 14\u001b[0m \u001b[38;5;66;03m# real = real.reshape(-1)\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# print(len(real))\u001b[39;00m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# break\u001b[39;00m\n\u001b[0;32m---> 22\u001b[0m D_loss,D_grads \u001b[38;5;241m=\u001b[39m \u001b[43mD_loss_grad\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgen\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdisc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreal\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mz_dim\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;66;03m# Update optimizer\u001b[39;00m\n\u001b[1;32m 25\u001b[0m disc_opt\u001b[38;5;241m.\u001b[39mupdate(disc, D_grads)\n", + "File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/utils.py:34\u001b[0m, in \u001b[0;36mvalue_and_grad..wrapped_value_grad_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(fn)\n\u001b[1;32m 33\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapped_value_grad_fn\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 34\u001b[0m value, grad \u001b[38;5;241m=\u001b[39m \u001b[43mvalue_grad_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainable_parameters\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 35\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m value, grad\n", + "File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/utils.py:28\u001b[0m, in \u001b[0;36mvalue_and_grad..inner_fn\u001b[0;34m(params, *args, **kwargs)\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minner_fn\u001b[39m(params, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 27\u001b[0m model\u001b[38;5;241m.\u001b[39mupdate(params)\n\u001b[0;32m---> 28\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[71], line 3\u001b[0m, in \u001b[0;36mdisc_loss\u001b[0;34m(gen, disc, real, num_images, z_dim)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdisc_loss\u001b[39m(gen, disc, real, num_images, z_dim):\n\u001b[1;32m 2\u001b[0m noise \u001b[38;5;241m=\u001b[39m mx\u001b[38;5;241m.\u001b[39marray(get_noise(num_images, z_dim))\n\u001b[0;32m----> 3\u001b[0m fake_images \u001b[38;5;241m=\u001b[39m \u001b[43mgen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnoise\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m fake_disc \u001b[38;5;241m=\u001b[39m disc(fake_images)\n\u001b[1;32m 7\u001b[0m fake_labels \u001b[38;5;241m=\u001b[39m mx\u001b[38;5;241m.\u001b[39mzeros((fake_images\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m],\u001b[38;5;241m1\u001b[39m))\n", + "Cell \u001b[0;32mIn[5], line 19\u001b[0m, in \u001b[0;36mGenerator.__call__\u001b[0;34m(self, noise)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, noise):\n\u001b[0;32m---> 19\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnoise\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/layers/containers.py:23\u001b[0m, in \u001b[0;36mSequential.__call__\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[1;32m 22\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[0;32m---> 23\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n", + "File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/layers/containers.py:23\u001b[0m, in \u001b[0;36mSequential.__call__\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[1;32m 22\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[0;32m---> 23\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n", + "File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/layers/normalization.py:357\u001b[0m, in \u001b[0;36mBatchNorm.__call__\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 355\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrack_running_stats:\n\u001b[1;32m 356\u001b[0m mu \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmomentum\n\u001b[0;32m--> 357\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrunning_mean\u001b[49m \u001b[38;5;241m=\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m mu) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrunning_mean \u001b[38;5;241m+\u001b[39m mu \u001b[38;5;241m*\u001b[39m mean\n\u001b[1;32m 358\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrunning_var \u001b[38;5;241m=\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m mu) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrunning_var \u001b[38;5;241m+\u001b[39m mu \u001b[38;5;241m*\u001b[39m var\n\u001b[1;32m 359\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrack_running_stats:\n", + "File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/layers/base.py:139\u001b[0m, in \u001b[0;36mModule.__setattr__\u001b[0;34m(self, key, val)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28msuper\u001b[39m(Module, \u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__getattribute__\u001b[39m(key)\n\u001b[0;32m--> 139\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__setattr__\u001b[39m(\u001b[38;5;28mself\u001b[39m, key: \u001b[38;5;28mstr\u001b[39m, val: Any):\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(val, (mx\u001b[38;5;241m.\u001b[39marray, \u001b[38;5;28mdict\u001b[39m, \u001b[38;5;28mlist\u001b[39m, \u001b[38;5;28mtuple\u001b[39m)):\n\u001b[1;32m 141\u001b[0m \u001b[38;5;66;03m# If attribute was previously set but not in the\u001b[39;00m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;66;03m# dictionary, delete it so we pick it up in future\u001b[39;00m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;66;03m# calls to __getattr__\u001b[39;00m\n\u001b[1;32m 144\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(\u001b[38;5;28mself\u001b[39m, key) \u001b[38;5;129;01mand\u001b[39;00m key \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], "source": [ "batch_size = 8\n", "cur_step = 0\n",