mlx-examples/gan/playground.ipynb

563 lines
122 KiB
Plaintext
Raw Normal View History

2024-07-26 21:07:40 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Libraries"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 23,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 24,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 25,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Generator building block: Linear -> BatchNorm -> ReLU.\n",
"\n",
"    The Linear layer maps the last dimension from `in_dim` to `out_dim`;\n",
"    BatchNorm and ReLU then act on those `out_dim` features.\n",
"    \"\"\"\n",
"    layers = [\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.BatchNorm(out_dim),\n",
"        nn.ReLU(),\n",
"    ]\n",
"    return nn.Sequential(*layers)"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 26,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"    \"\"\"MLP generator mapping a noise vector of length `z_dim` to a flattened image.\n",
"\n",
"    Four GenBlocks widen the features from `z_dim` through `hidden_dim`, `2*hidden_dim`,\n",
"    `4*hidden_dim` to `8*hidden_dim`; a final Linear projects to `im_dim` and a\n",
"    Sigmoid squashes every output into (0, 1).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, z_dim: int = 10, im_dim: int = 784, hidden_dim: int = 128):\n",
"        super().__init__()\n",
"        widths = [z_dim] + [hidden_dim * m for m in (1, 2, 4, 8)]\n",
"        # Blocks are built in order so parameter initialization order is unchanged.\n",
"        blocks = [GenBlock(a, b) for a, b in zip(widths[:-1], widths[1:])]\n",
"        # The attribute name `gen` is kept so parameter keys stay the same.\n",
"        self.gen = nn.Sequential(*blocks, nn.Linear(widths[-1], im_dim), nn.Sigmoid())\n",
"\n",
"    def __call__(self, noise):\n",
"        return self.gen(noise)"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 27,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
" (layers.0): Linear(input_dims=100, output_dims=128, bias=True)\n",
" (layers.1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.1): Sequential(\n",
" (layers.0): Linear(input_dims=128, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.2): Sequential(\n",
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
" (layers.1): BatchNorm(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.3): Sequential(\n",
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
" (layers.1): BatchNorm(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.4): Linear(input_dims=1024, output_dims=784, bias=True)\n",
" (layers.5): Sigmoid()\n",
" )\n",
")"
]
},
2024-07-28 22:35:36 +08:00
"execution_count": 27,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Inspect the generator architecture for a 100-dimensional noise input.\n",
"gen = Generator(z_dim=100)\n",
"gen"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 28,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def get_noise(n_samples, z_dim):\n",
"    \"\"\"Return an (n_samples, z_dim) NumPy array of standard-normal noise.\n",
"\n",
"    Uses NumPy's global RNG; the loss functions wrap the result in `mx.array`.\n",
"    \"\"\"\n",
"    shape = (n_samples, z_dim)\n",
"    return np.random.randn(*shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 29,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Discriminator building block: Linear followed by LeakyReLU (slope 0.2).\"\"\"\n",
"    linear = nn.Linear(in_dim, out_dim)\n",
"    activation = nn.LeakyReLU(negative_slope=0.2)\n",
"    return nn.Sequential(linear, activation)"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 30,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"    \"\"\"MLP discriminator mapping a flattened image to one raw score (logit).\n",
"\n",
"    Three DisBlocks narrow the features from `im_dim` to `4*hidden_dim`,\n",
"    `2*hidden_dim` and `hidden_dim`; a final Linear outputs a single value.\n",
"    No sigmoid is applied here (the loss functions use `with_logits=True`).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, im_dim: int = 784, hidden_dim: int = 128):\n",
"        super().__init__()\n",
"        widths = [im_dim] + [hidden_dim * m for m in (4, 2, 1)]\n",
"        blocks = [DisBlock(a, b) for a, b in zip(widths[:-1], widths[1:])]\n",
"        # The attribute name `disc` is kept so parameter keys stay the same.\n",
"        self.disc = nn.Sequential(*blocks, nn.Linear(widths[-1], 1))\n",
"\n",
"    def __call__(self, noise):\n",
"        # `noise` is the batch of (real or generated) flattened images to score.\n",
"        return self.disc(noise)"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 31,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
" (layers.0): Linear(input_dims=784, output_dims=512, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.1): Sequential(\n",
" (layers.0): Linear(input_dims=512, output_dims=256, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.2): Sequential(\n",
" (layers.0): Linear(input_dims=256, output_dims=128, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.3): Linear(input_dims=128, output_dims=1, bias=True)\n",
" )\n",
")"
]
},
2024-07-28 22:35:36 +08:00
"execution_count": 31,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Inspect the discriminator architecture with its default sizes.\n",
"disc = Discriminator(im_dim=784, hidden_dim=128)\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 32,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"# Training hyperparameters\n",
"criterion = nn.losses.binary_cross_entropy  # the loss cells call this function directly\n",
"n_epochs = 200\n",
"z_dim = 64            # length of the generator's noise vector\n",
"display_step = 500    # training steps between progress reports\n",
"batch_size = 128\n",
"lr = 1e-5"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 33,
2024-07-26 21:07:40 +08:00
"metadata": {},
2024-07-27 06:09:51 +08:00
"outputs": [],
2024-07-26 21:07:40 +08:00
"source": [
"# Fresh generator/discriminator pair (generator built first) and one Adam optimizer each.\n",
"gen, disc = Generator(z_dim), Discriminator()\n",
"# Force the lazily-built parameter arrays to be computed now.\n",
"mx.eval(gen.parameters(), disc.parameters())\n",
"\n",
"gen_opt = optim.Adam(learning_rate=lr)\n",
"disc_opt = optim.Adam(learning_rate=lr)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Losses"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 34,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
"    \"\"\"Discriminator loss: mean of the fake-batch and real-batch BCE terms.\n",
"\n",
"    Generates `num_images` fakes from fresh noise and scores them against label 0,\n",
"    scores the `real` batch against label 1 (discriminator outputs are logits,\n",
"    hence `with_logits=True`), and returns the average of the two losses.\n",
"    \"\"\"\n",
"    noise = mx.array(get_noise(num_images, z_dim))\n",
"    fake_logits = disc(gen(noise))\n",
"    fake_term = nn.losses.binary_cross_entropy(\n",
"        fake_logits, mx.zeros((fake_logits.shape[0], 1)), with_logits=True\n",
"    )\n",
"\n",
"    real_logits = disc(real)\n",
"    real_term = nn.losses.binary_cross_entropy(\n",
"        real_logits, mx.ones((real.shape[0], 1)), with_logits=True\n",
"    )\n",
"\n",
"    return (fake_term + real_term) / 2"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 35,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def gen_loss(gen, disc, num_images, z_dim):\n",
"    \"\"\"Generator loss: BCE of the discriminator's logits on fresh fakes vs. label 1.\n",
"\n",
"    The loss is low when the discriminator scores the generated samples as real.\n",
"    \"\"\"\n",
"    noise = mx.array(get_noise(num_images, z_dim))\n",
"    logits = disc(gen(noise))\n",
"    targets = mx.ones((logits.shape[0], 1))\n",
"    return nn.losses.binary_cross_entropy(logits, targets, with_logits=True)"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 36,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"# mnist.mnist() returns four arrays: images at positions 0 and 2 are kept,\n",
"# the other two (presumably labels) are discarded.\n",
"loaded = mnist.mnist()\n",
"train_images, _, test_images, _ = (mx.array(part) for part in loaded)"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 37,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def batch_iterate(batch_size: int, ipt: mx.array):\n",
"    \"\"\"Yield shuffled mini-batches of rows of `ipt`.\n",
"\n",
"    Args:\n",
"        batch_size: number of rows per batch.\n",
"        ipt: array whose first axis indexes samples.\n",
"\n",
"    Rows are visited in a fresh random order on every call. The last batch holds the\n",
"    remaining `len(ipt) % batch_size` rows when the count does not divide evenly.\n",
"    \"\"\"\n",
"    n_rows = len(ipt)\n",
"    perm = mx.array(np.random.permutation(n_rows))\n",
"    # Step over rows (len), not `ipt.size` (total element count): iterating up to\n",
"    # `size` produced thousands of empty batches, i.e. NaN losses and an empty `real`.\n",
"    for start in range(0, n_rows, batch_size):\n",
"        yield ipt[perm[start : start + batch_size]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Show a batch of training images"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 38,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-28 22:35:36 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAUkAAAFICAYAAADd1gwNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACQiklEQVR4nOy9V49l2ZXf+bvee3/D+4zISFuZVZXFYpHFIkU2DdjdYKuBgQYcYCToC+gLCHrUsx4ECIIwwmB6ultSCxTZ7C4y2SzDrGKl95HhzfXe+3vnIWfvupEm0oc9PyDAYmbEzXN2nLP22sv8l6rX6/VQUFBQUHgi6r2+AAUFBYX9jGIkFRQUFHZAMZIKCgoKO6AYSQUFBYUdUIykgoKCwg4oRlJBQUFhBxQjqaCgoLADipFUUFBQ2AHFSCooKCjsgPZ5v1GlUr3J6zhQvGyTkrKGX6Os4aujrOGr8zxrqHiSCgoKCjugGEkFBQWFHVCMpIKCgsIOPHdMUkFBQQHAZDJht9sxmUyMjY3hdDrpdDq0Wi1yuRy3bt2iUqns9WW+NhQjqaCg8EI4nU6OHTvG4OAgP//5zzl+/DiVSoVKpcLVq1f5t//23ypGUkFB4eghsuJms5mBgQEGBgYIhUIEg0HK5TLlchmn04lGo9njK329KEZSQUHhuVCr1ajVaqanp/n5z3+O3+9nYGCAXq+37euwceCM5PPUeB3GX5SCwl6jUqlQq9W43W5OnTqFx+MBONQGEvaxkRS/EK1Wi91uR6/X4/F48Hg8aLVajEYjWq0Wm82GyWSiWq1SKpWoVCosLi5SLBap1WrUarW9vhUFBYUDzL42ksIYBgIB7HY7MzMzzM7OYjQacblcmEwmhoaGcLvdpNNpNjc3icfj/OIXv2BjY4N0Oq0YSQUFhVdiXxhJlUqFRqNBo9Fgt9sxm82YTCbcbjcGg4FwOIzFYmFsbIzBwUEMBgMOhwO9Xo/D4cBsNmO32/H5fKhUKsbHxzEYDAAUCgW63S7dbneP71LhoKPVauVx0+12y9BPp9MhlUpRKpXodDp0Op09vtLdR61Wo9Fo0Ol0WCwWLBYLjUaDdru915f2yuwLI6nRaOSx+Z133mFqaorR0VHOnj2L0WjEZDKh1WrR6/UYjUb5CxH0ej1MJhOBQIBarUYgECCXy/G3f/u3JJNJms0m1Wr10MZMFN48YgM3Go1873vf46OPPpJGslqt8nd/93dcu3ZNhn2O2rOm1WoxmUw4nU5GR0fpdrtEo1Gy2exeX9ors6dGUsQd9Xo9NpsNq9VKMBhkZGSEiYkJZmdnMRgMqNVqVCoV3W5XBoiFd9hoNGi1Wuh0OgwGAxqNhnA4jN1ux+PxYDQaAajVakfuwX0W4iXXaDSoVCq5roDieT+CSqXCYDBgNpsJBAJMTU2hVj9sWCuXy/j9fqxWK91ul3K5fGifNfGMtNttOp2OdFZE5lun02E2m7FarWi1+8IHe2X29C7cbjcej4dwOMy3vvUtvF4vU1NThMNhHA4HRqORTqfD1tYWlUqFXC5HLpejVquRTqdpNBpkMhkqlQqjo6OcOnUKq9XKwMAANpuN6elpzp49SyKR4Pbt29Tr9b283X2FVqvFbDZjMBhkXDefz5NKpWg0GmSzWZrN5l5f5r7BYDAwMTGBz+djZmaGkZERaSQbjQbf//73mZiY4MqVK/z617+m0Wjs8RW/fsTGWSwWWVpaolgsylCYRqORm4jX66VSqRCLxfb4il8Pe2okrVYr4XCYmZkZvv/97xMOh3G73VitVvk9zWaTVColEzObm5sUi0VWVlaoVCpEo1FyuRznz59Hr9cTDAYZHR3F7XYzODjI5OQkWq2WhYUFxUj2odFosFgsmM1mxsfHGRkZIRqNAg89o1KppBjJPnQ6HaFQiJGREYaGhggEAtJIttttzp07x+joKLVajd/85jeH0kiKU5x477rdLm63G4vFIvMKIl/gdDrR6X
R7fcmvhV03kiqVCp1Oh1arZWhoiDNnzjA8PIzL5cJsNtNsNsnn8xSLRRKJBMVikevXr5NOp0mn02QyGarVKslkknq9TqlUotFoUC6XyefzmM1mWq0WgPyl6XS6A6uhp1KpsNlsmM1mnE4ng4OD2+KxkUiEWCxGs9mkUqk885is0+nQ6XT4fD7Onj2L0+lkdnaWYDCIz+fDYrHIdT5MrWUvi0ajQa/Xy1a86elpgsHgtudJrVZjs9mAhxv/QX3Wnodut0u9XieVSqHT6Q7lZvAou24k1Wo1FosFk8nEqVOn+LM/+zNcLheDg4Po9XoSiQT5fJ4HDx7w2WefkU6n+eqrr0ilUrTbbVqtFr1eT2YQO50OvV6PfD5PNBpFo9FII2kwGLBarRiNxgP74KrVakKhEOFwmPn5ef7kT/4Es9kMPNzZf/Ob3/DrX/+aQqHA5ubmMx9aIU4wOzvLz3/+cwYGBggGgzgcDqLRKIuLi6yurnLv3j1SqdRu3OK+RlRQhMNhPvzwQ86fP4/BYHjMSPr9frxeL16vV3qYhw0RZxUnuWazycmTJ/f4qt48e+JJisCu0+nE5XJhs9lkkLdUKpFKpUgkEsRiMbLZLLlcTpbyPC0g3mq1qNVqNBqNbd6USqU6sAYSHl6/y+ViaGiIcDhMIBDAbDbLjcLpdGK1Wmk0Gk99OUXNqShfCYfDDA4O4vP5cLvd6PV6er0ezWaTcrlMtVo9kmUsT0LEbs1mMxaLZVso6FFEIlKUsx3WNRSnPYvF8lhIRq1WYzQaMZvNSuLmZTEYDBw7doxwOCyVRDQaDe12m3q9zldffcXVq1dZX1/n2rVr1Gq1ZxpIeJi9TiQST/zFHWT0ej3f/va3+bM/+zPsdruMhXU6HdrtNoODg4yPj2MymdjY2Hhi8bxWqyUYDGKxWPiTP/kTfvCDH+ByuRgbG0Oj0bC8vEwymeTq1av85je/IZfLEY/H9+Bu9x9Wq1XGIU0m0xO/R8Tp6vU67XYbh8MBPNzwD0Od4KMkk0k+//xzRkZG+P73v7/t76xWK3NzczidTu7cubNHV/h62XUjqdVqcbvdhEIhvF4vNpuNXq9HtVql2WySTCZZWVlha2uLWCwmj87PQhjZer1+qHZwtVrN0NAQp0+flqVQwutrt9tYLBbsdjvFYvGpnqQIcTidTqanp3nvvffQarVotVrq9Tr5fJ7NzU2Wl5e5ceMG1Wr1udf9sCOO2zabDY1GQ6/Xe+xk0uv1aLVaNBoNer0eRqORer1+aGO61WqVSCSCXq9/zCHR6XR4PB46nY4svzvo7Mlx22g0YrVa0ev1ALJW0mw2Mz09TavV4tq1aywsLDz3y2qz2RgeHmZgYEDu+J1OR9ZRHuS6NXGM6y9evnz5MvF4nMuXL3P58mXy+fxTs/cGg4HJyUkGBwcJhUJotVqazSaxWIxCocAf/vAHrl27xsbGhtxkDvJ6vU5arRalUolcLsfKygpqtVrGHgUqlQqTyYRGo8HtdjMwMIDBYKBcLh+qU83zIDQWer2ebBARuYSDyp4kbkwm0zYjKQp1NRoNs7OzuN1uyuXyC8U0bDYbo6Oj8gGFh96lMJIHGWEkBdVqlc8//5ybN29y//597t+/v6NhMxqNTE1NMTs7y8DAAFqtllKpxPr6OolEgn/6p3/ik08+2VZMrvCQVqtFoVBAr9ezvLxMq9Xi2LFjeDweuWkJI2kymfB6vXKNt7a2KJfLe3wHu4tOp8Pr9aLRaHA6nZhMJtmeeFA33l03kiLh0Gq1nvhCNptNGdt5nkU1GAxotVqcTid+vx+PxyPrs2q1Gvl8nnK5fKhe/m63S61Wo1wuP9PzE10QLpcLr9crM+P1ep1oNEo0GpU9xwqPI5JZWq2WpaUlSqWSNIaiU6x/A2u321Kl+yiuqTgVGo1G7HY7brebUql0oDvedt1IiratJx0PRSnP1tYWuVzumYZNZGvtdjtTU1OcO3dOZst7vR6pVIp79+6RzW
YPvDfZT6fTkcmVneJewkBaLBampqY4efIkTqcTlUpFJpPhd7/7Hevr67KIXOFxyuUya2traDQaVlZWMBgMspDa5/M
2024-07-26 21:07:40 +08:00
"text/plain": [
"<Figure size 400x400 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Preview 16 shuffled training digits in a 4x4 grid.\n",
"X = next(batch_iterate(16, train_images))\n",
"fig, axes = plt.subplots(4, 4, figsize=(4, 4))\n",
"\n",
"for idx, ax in enumerate(axes.flat):\n",
"    ax.imshow(X[idx].reshape(28, 28), cmap='gray')\n",
"    ax.axis('off')"
]
},
{
"cell_type": "code",
2024-07-28 22:35:36 +08:00
"execution_count": 39,
2024-07-28 06:10:19 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-28 22:22:40 +08:00
"def show_images(imgs, num_imgs: int = 25):\n",
"    \"\"\"Plot up to `num_imgs` flattened 28x28 images from `imgs` in a square grid.\n",
"\n",
"    Args:\n",
"        imgs: array-like whose first axis indexes images; each image must hold\n",
"            28*28 values.\n",
"        num_imgs: maximum number of images to draw (default 25, i.e. a 5x5 grid).\n",
"\n",
"    Draws nothing when `imgs` is empty.\n",
"    \"\"\"\n",
"    n_show = min(num_imgs, len(imgs))\n",
"    if n_show == 0:\n",
"        return\n",
"    side = int(np.ceil(np.sqrt(n_show)))\n",
"    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.\n",
"    fig, axes = plt.subplots(side, side, figsize=(4, 4), squeeze=False)\n",
"\n",
"    for i, ax in enumerate(axes.flat):\n",
"        ax.axis('off')\n",
"        # Only the first `n_show` axes get an image; the rest stay blank instead of\n",
"        # indexing past the end of `imgs`.\n",
"        if i < n_show:\n",
"            ax.imshow(np.array(imgs[i]).reshape(28, 28), cmap='gray')\n",
"    plt.show()"
]
},
{
"cell_type": "code",
2024-07-28 22:56:26 +08:00
"execution_count": 43,
2024-07-27 06:09:51 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2024-07-28 22:56:26 +08:00
"array(0.675341, dtype=float32)"
2024-07-27 06:09:51 +08:00
]
},
2024-07-28 22:56:26 +08:00
"execution_count": 43,
2024-07-27 06:09:51 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Smoke test: rebuild both models (generator first) and evaluate the generator loss on 8 samples.\n",
"z_dim = 64\n",
"gen, disc = Generator(z_dim), Discriminator()\n",
"mx.eval(gen.parameters(), disc.parameters())\n",
"gen_opt = optim.Adam(learning_rate=lr)\n",
"disc_opt = optim.Adam(learning_rate=lr)\n",
"\n",
"g_loss = gen_loss(gen, disc, 8, z_dim)\n",
"g_loss"
]
},
{
"cell_type": "code",
2024-07-28 22:56:26 +08:00
"execution_count": 44,
2024-07-26 21:07:40 +08:00
"metadata": {},
2024-07-27 06:20:00 +08:00
"outputs": [
{
2024-07-28 06:10:19 +08:00
"data": {
"text/plain": [
"60000"
]
},
2024-07-28 22:56:26 +08:00
"execution_count": 44,
2024-07-28 06:10:19 +08:00
"metadata": {},
"output_type": "execute_result"
2024-07-27 06:20:00 +08:00
}
],
2024-07-27 06:09:51 +08:00
"source": [
2024-07-28 06:10:19 +08:00
"train_images.shape[0]  # number of training images"
2024-07-27 06:09:51 +08:00
]
},
{
"cell_type": "code",
2024-07-28 22:56:26 +08:00
"execution_count": 45,
2024-07-27 06:09:51 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-29 00:18:35 +08:00
" 10%|█ | 20/200 [1:17:04<11:32:48, 230.94s/it]"
2024-07-27 05:19:08 +08:00
]
2024-07-29 00:18:35 +08:00
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 20: Generator loss: array(16.7247, dtype=float32), discriminator loss: array(nan, dtype=float32)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAUkAAAFICAYAAADd1gwNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACOYElEQVR4nO29eZRcZ3nn/7331q19r953tVqtVkuyZGuxsOUtYAwxCVEIYLIQHJLB2Qhhcg4nZ8jM/AiTOYEQBzKADSTE9jBhCRAnAdtgkMCRLVlCu6ylu6VW713V3bXvy/v7o3le3aqurq7qru6ult7POXVsVVfVve9yn/d5n+2VGGMMAoFAICiKvN43IBAIBLWMEJICgUBQAiEkBQKBoARCSAoEAkEJhJAUCASCEgghKRAIBCUQQlIgEAhKIISkQCAQlEAISYFAICiBrtwPSpK0mvexbpSbcCTaL9p/K1JJwt3t2gdCkxQIBIISCCEpEAgEJRBCUiAQCEoghKRAIBCUQAhJgUAgKIEQkgKBQFACISQFAoGgBGsqJA8ePIjHH38cXV1da3lZQY0gxn8her0ejzzyCA4dOgSHw7Het7PqbMg5wMoEwIpesiyzf/iHf2DxeJwdOnRoxb9Xrddatb9WX2L813f8nU4n+/73v8/OnTvH+vr6arb9t/McKDvjZqUwxvDSSy/B5/NhYGBgrS4rqBHE+BcnmUziO9/5DpxOJ+bm5tb7dlaVjToHpJ+vEEt/8DZNSSJE+0X7b0XKbT9w+/bBmmmSgvJwuVz47d/+bSSTSTz33HOIRCIlP3/gwAG8/e1vx49//GP85Cc/WaO7FAhuH4SQrDHcbjf++I//GOFwGN/97neXFJJvetOb8N//+39HNpsVQlIgWAU2rJDs7+/Hm9/8Zpw+fRr/+Z//ud63UzVmZmbwyU9+EqlUCqFQaMnPHz58GH/6p3+K1157bQ3urnbo7+/HQw89hLNnz95S4y8on76+Pj4HXn311VW7zoYVkps2bcKv/dqvIZfL3VIPSTAYxFe/+tWyPivLMs6ePYszZ86s7k3VIJ2dnfjVX/3VDTn+kiRBkiTkcrn1vpUNzaZNm3Do0CEwxlZVSG5Yx017ezt27NiB69ev4/Lly8v+nY1quL///vvxgQ98AD/4wQ/w9a9/fdm/s1Hb39bWhh07dmB4eHhDjb+iKHjiiSewc+dOfO5zn8Mbb7xRld9dLhvZcdPa2srnwJUrV5b9O0v2wWrESEmSxBRFYZIkrXsM1FKvtYoRq/QlSRJTVZXpdLqif3/88cdZIpFgn/rUp2qu/WL8F3+pqsq+/vWvM7/fz9785jdvmPbfznNgVTTJ7du3Y//+/Th9+nTNbwXLbP6ar6KdnZ34i7/4C0xPT+Ov/uqvEI1G8/7e1dWFO++8E0NDQzh37tyyr7Ma7d+xYwfuvvtu/OxnPxPjX4Asy9i/fz8aGxvx2muvwev1LriOwWAAMB9DWe79LZdKfr+SPujr68Odd96J8+fP48KFC8u5tTVjqT5YFZukx+PB9u3bMTU1hatXryKVSiGTyazGpW4JjEYjJElCIpHgA2az2fDQQw9heHgYOp0OiqLAYDAgm80imUxieHgYw8PD63vji+DxeNDf3y/Gvwi5XA7Hjh3Le087/gCg0+lWXTiuNk6nE319fZiZmcG1a9c29hxYDVXb5XKxrVu3sj/8wz9k3/72t9k73vGOdVepF3utRvsredlsNvbZz36W/fM//zPr6uri71utVnb//fezvXv3Mp1Ox/bu3cv+/d//nX384x9niqLUdPtp/P/oj/6Ifec737ntx7/U1tNuty8Yf0VRqjrG1Wh/pX3gdDpZT08P+/3f//0NPwdWpcCF3+/HlStXoKoqdu/ejfr6+tW4zC2BoijYunUr7rjjDhiNRsiyDIfDAVmW8corr+DkyZPIZDKw2WzYvXs3uru7a86AXgiNv06nE+OPm9tUi8UCl8sFvV
7P31dVFXv37sXBgwdhsVgAANlsFtlsdt3utxoEAgEMDg7CYDBs/DmwmitpS0sL27t3L2tsbFz31WKx12q2v5yXTqdjO3bsYHv27GEmk4m1trayf/mXf2HPPvssc7vd/HMOh4Pt37+f9fb2VtUYLsZ/bcZflmX28Y9/nP30pz9lDz74IH/f6XSyf/3Xf2UnTpxgW7durdn2385zoKo2Sb1eD5PJhGQyiUQigYmJCUxMTFTzEguwWCyw2+0Ih8N52SlOpxMmkwl+v5/bemqRTCaTZ9jW6XRobW1FMplES0sLZFnG7OwsgsEgXn/99XW806UR4784jDG43W50dnaira0NTU1N8Pv9yGazmJqaAgCk0+mqXW+92q+qKsxm85rOAbPZDJvNhkgkkufgrFofVHMVueuuu9j//J//kz3yyCNrtgq8733vYydOnGC/+7u/y9+TJIl94hOfYD/72c/YW97ylqqspGvVHr1ez3p7e9nb3/529sorr7DnnnuO2Wy2dVtFKx3///E//seaj//Jkyc3xPi3tray3bt3s6985Svs+PHj7MCBA0yWZdbe3s66u7uZXq+vynWo/SdPnqxa+8vtg927d7M///M/X/K61Xz92q/9Gnv55ZfZb/3Wb63KHKiqJinLMlRVhSyXNnU2NTXBbDZjcnIS8Xh8RdfU6XQwGAzQ6fKboqoqjEYjFEVZ0e+vNalUClevXoXRaERPTw9kWd4wbZBlGTqdbk3HX1XVmh//uro6OBwOeL1eeL1exONxbn/O5XIYHR1d9LtGoxFNTU1IpVKYnJws2+tN/bLUWKwGer1+yX6v5hygyI/Ca1ZtDlRzFTEajayuro5ZLJZFP6OqKvv85z/Pzp8/z+69994VryJ2u511dXUxp9OZ9359fT3r6uoqeS8VNH/NVkV63XHHHWxycpIdPXp0Qduq+RLjv/rj//GPf5wNDAywd73rXQwAa2xsZF1dXcxkMpX8niRJbM+ePezUqVPsH//xH5nRaCz7mtVuf7l9YDAYmMfjWdM5YLPZWEdHB3M4HKvSB1XVJBOJxJJ7f8YY4vE4otFoVTx4oVAIoVAITqcTXV1dmJubQygUgs/ng8/nW/HvrzV6vR7d3d3o6OjA1atX4fP50NfXh1gshlwuh2AwWFLzWE/E+BfHYDDA4XBwr/b09HTZ381ms4hEItDpdNixYwe8Xi9GRkaW/N56tT+ZTCKZTJb8TLXnQDgcRjgcXr05sBYraeHL6XSyxsbGqtlgALBDhw6x5557rmJ72Hq0v9Srs7OTHTt2jP3kJz9h27dvZ4888gi7cuUKGx0dZQMDA+xzn/vcoqmKy3ndSuP/zDPPsLe+9a011/7f/M3fZM899xy7//77K/6uqqqsoaGBHTp0iA0MDLAvfvGL6zL+G2EOvOtd72Jf+9rX2Nve9raq9sGqZNy0tLSgq6sLIyMjGBsbW/D3QCBQ9m9ZrVZs3boVkUgEAwMDi1ZOcbvd6O7u3rCHKen1evT396O3txft7e2Ix+PYvHkzLBYLRkZGoNfrkUwm4fV6az4bo9rj39fXh0gkgqtXry46/k6ns2bH32g0wmazQVXVir+bTqfh9XoxPj6OoaGhirTQ9aSac8BisaCnpwfRaBRDQ0OLzn+Xy4VNmzZVfw6sxiryxBNPMJ/Px/7sz/5sxavDXXfdxc6cOcP+6Z/+qaRN5uMf/ziLxWLsj/7oj1ZlJV1pO5Z6NTU1saNHj7JgMMiy2SzLZDIsEAiwl19+mXV2djKn08kcDgczm81VvW6tj/+ePXvY+fPn2bPPPrthx//Tn/40i8Vi7P3vf39Zn5ckaUEsrE6nYw6Hg1mtVqYoCpNleU3Hfz3nwB133MEOHz7M/v7v/54ZDIaScyAajbI//MM/rGofrIomOTU1hddffx3j4+Mr/q1IJIIzZ86U1CIB4Pr16/jRj37E7XUWiwUGgwGRSASpVGrF97HapFIpnD9/HtlsFnv37kUmk8GpU6dw7t
w5zMzMLChwUctUe/xPnz5dUosEanv8BwcHceTIEUxOTi75WbPZjH379iGTyeDEiRP83p1OJ+688074fD5cuHCh5mt
"text/plain": [
"<Figure size 400x400 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 10%|█ | 20/200 [1:21:00<12:09:01, 243.01s/it]\n"
]
},
{
"ename": "ValueError",
"evalue": "[take] Cannot do a non-empty take from an array with zero elements.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[45], line 37\u001b[0m\n\u001b[1;32m 35\u001b[0m fake \u001b[38;5;241m=\u001b[39m gen(fake_noise)\n\u001b[1;32m 36\u001b[0m show_images(fake)\n\u001b[0;32m---> 37\u001b[0m \u001b[43mshow_images\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreal\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 38\u001b[0m cur_step \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n",
"Cell \u001b[0;32mIn[39], line 5\u001b[0m, in \u001b[0;36mshow_images\u001b[0;34m(imgs, num_imgs)\u001b[0m\n\u001b[1;32m 2\u001b[0m fig,axes \u001b[38;5;241m=\u001b[39m plt\u001b[38;5;241m.\u001b[39msubplots(\u001b[38;5;241m5\u001b[39m, \u001b[38;5;241m5\u001b[39m, figsize\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m4\u001b[39m))\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, ax \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(axes\u001b[38;5;241m.\u001b[39mflat):\n\u001b[0;32m----> 5\u001b[0m img \u001b[38;5;241m=\u001b[39m mx\u001b[38;5;241m.\u001b[39marray(\u001b[43mimgs\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m)\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m28\u001b[39m,\u001b[38;5;241m28\u001b[39m)\n\u001b[1;32m 6\u001b[0m ax\u001b[38;5;241m.\u001b[39mimshow(img,cmap\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mgray\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 7\u001b[0m ax\u001b[38;5;241m.\u001b[39maxis(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124moff\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
"\u001b[0;31mValueError\u001b[0m: [take] Cannot do a non-empty take from an array with zero elements."
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWEAAAFlCAYAAAA6blnBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAc+UlEQVR4nO3dT2gTfR7H8U/sY1N0aUTFrpW2eOkWFISnUqi4oh4KBWXdkwd1C6489GQfeuqyLGWlPenuwqIe7GFdtKAs0p685FD815OywYOIKNJUrYr/JlF5WrC/PSypTTOTZNomv2Tm/YIeOmaYeY/h23QymUaMMUYAACvW2N4BAAgzhjAAWMQQBgCLGMIAYBFDGAAsYggDgEUMYQCwiCEMABYxhAHAIoYwAFjkewjfvn1bhw8fVmNjoyKRiMbHx0uwW5WLfvrpD29/Kfgewl++fNGuXbt0/vz5UuxPxaOffvrD218KP/hdobu7W93d3aXYl6pAP/30h7e/FHwPYb9mZ2c1Ozu78P38/Lw+fPigTZs2KRKJlHrzJff161c5jqN0Oq3GxkatWZP9ywX99NMf3P4MY4znMShm5WWTZMbGxvI+ZnBw0EgKxdf09DT99NMf0n6vY1BIxJjl3084EolobGxMR44c8XzM0p+EjuOoublZ09PTqq+vX+6mK0IsFtPo6Kj27dunpqYmffr0SbFYLOsx9NNPfzD7F0ulUp7HoJCSn46IRqOKRqM5y+vr6wPxn7Bu3bqFDrdfr+inn/7g9i+1nFMsXCcMABb5fiX8+fNnPX36dOH758+fK5FIaOPGjWpubl7VnatEbv0PHz60uEflRT/9Ye4vCb8nkScmJlxPSPf09BS1vuM4RpJxHMfvpiuCV3+xTfTTT3/19ntZSZfvV8L79++XCfHfBnXrT6VSvk/GVyv66Q9zfylwThgALGIIA4BFDGEAsIghDAAWMYQBwCKGMABYxBAGAIsYwgBgEUMYACxiCAOARQxhALCIIQwAFjGEAcAihjAAWMQQBgCLGMIAYBFDGAAsYggDgEUMYQCwiCEMABYxhAHAIoYwAFjEEAYAixjCAGARQxgALGIIA4BFDGEAsIghDAAWMYQBwCKGMABYxBAGAIsYwgBgEUMYACxiCAOARQxhALCIIQwAFjGEAcAihjAAWMQQBgCLGMIAYBFDGAAsYggDgEXLGsIXL17U9u3bVVdXp/b2dt25c2e196uiLe2fnJy0vUtlRT/9Ye5fdcana9eumbVr15qRkRHz6NEj09fXZ9avX2+mpqaKWt9xHCPJOI7jd9MVwau/2Cb66ae/evu9rKTL9xDu6Ogwvb29Wcva2trMwMBAUetX+3+CW39ra2tonoT00x/mfi8r6frBz6vmubk5PXjwQAMDA1nLu7q6PH8lmZ2d1ezs7ML3juNIklKplJ9NV4RM/+nTp7P2f+/evXry5ImMMTnr0E8//cHozyfT43YMCvIzsV++fGkkmXv37mUtHx4eNq2tra7rDA4OGkmh+Hr27Bn99NMf0n6vY1BIxJjiR/erV6+0bds2TU5OqrOzc2H58PCwrly5osePH+ess/Qn4adPn9TS0qJkMqlYLFbspivCzMyM2traFI/H1dHRsbB8aGhIZ8+e1cePH7Vhw4asdeinn/5g9OfjOI6am5tdj0Ehvk5HbN68WTU1NXr9+nXW8rdv36qhocF1nWg0qmg0mrM8Foupvr7ez+atq6urU01NjdLpdNa+p9NpSdKaNbkXm9BPP/3B6C+G2zEouI6fB9fW1qq9vV3xeDxreTwe1549e3xvvNp49U9MTFjao/Kin/4w95eKr1fCktTf368TJ05o9+7d6uzs1KVLl5RMJtXb21uK/as4bv0vXrywvVtlQz/9Ye4vBd9D+OjRo3r//r3OnDmjmZkZ7dy5Uzdv3lRLS0tR60ejUQ0ODrr+ilIN3PrHx8d19+7doprop5/+6u33spIuX2/MAQBWl+
+zyLdv39bhw4fV2NioSCSi8fHxEuxW5aKffvrD218Kvofwly9ftGvXLp0/f74U+1Px6Kef/vD2l4Lvc8Ld3d3q7u4uxb5UBfrppz+8/aXgewj7tfRi7fn5eX348EGbNm1SJBIp9eZL7uvXr3IcR+l0Wo2NjTnXCdJPP/3B7c8wxngeg2JWXjZJZmxsLO9jwvSxxZGREfrpD22/lDsPwtbv9hwoZEVXR0QiEY2NjenIkSOej8n8JLxx44Z++uknDQ0NaWBgQNPT01X/iZlYLKbR0VHt27dPTU1Nunr1qo4dO5b1GPrpD0u/pJx5sPiV8I0bN3Tq1CnNz88Hon+xVCrl+RwoyPfYXkQuP/m8ZG6BF6Rb2WX6M02jo6Oej6Wf/qD3F5oHHR0d5uTJk4HpX6yY54CXsvx5o8wt8Lq6usqxuYpDP/1h7pe+H4ODBw/a3pWK4/uNuc+fP+vp06cL3z9//lyJREIbN25Uc3Oz6zrv3r3Tt2/fPG/yU03c+h8+fJh3HfrpD3O/9P0YbNmypZS7V5V8D+H79+/rwIEDC9/39/dLknp6enT58uW86wbh3VCv/mLQT3+1W0m/FIxjsNp8n47Yv3+/zP//LFLWV74B7HULzGrk1p/5awFe6Kc/zP3S92Pw5s2bMuxldSnLOWGvW+CFBf30h7lf+n4MuO1lrpJ/WCMjcwu8HTt2lGuTZTc1NeV5fpx++oPeL+V/j6i/v1/Hjx+3tGflke854GllF2b4c+HCBdPU1BS4S1QWX6IjyfT09Lg+jn76w9Cf7xicO3cucP3GFP8ccFP2W1mmUinFYjE5jhOYi7X9NNFPP/3B6pdW1lWWc8IAAHcMYQCwiCEMABYxhAHAIoYwAFjEEAYAixjCAGARQxgALGIIA4BFDGEAsIghDAAWMYQBwCKGMABYxBAGAIsYwgBgEUMYACxiCAOARQxhALCIIQwAFjGEAcAihjAAWMQQBgCLGMIAYBFDGAAsYggDgEUMYQCwiCEMABYxhAHAIoYwAFjEEAYAixjCAGARQxgALGIIA4BFDGEAsIghDAAWMYQBwCKGMABYxBAGAIuWNYQvXryo7du3q66uTu3t7bpz585q71dFW9o/OTlpe5fKin76w9y/6oxP165dM2vXrjUjIyPm0aNHpq+vz6xfv95MTU0Vtb7jOEaScRzH76Yrgld/sU30009/9fZ7WUmX7yHc0dFhent7s5a1tbWZgYGBotav9v8Et/7W1tbQPAnppz/M/V5W0vWDn1fNc3NzevDggQYGBrKWd3V1ef5KMjs7q9nZ2YXvHceRJKVSKT+brgiZ/tOnT2ft/969e/XkyRMZY3LWoZ9++oPRn0+mx+0YFORnYr98+dJIMvfu3ctaPjw8bFpbW13XGRwcNJJC8fXs2TP66ac/pP1ex6CQiDHFj+5Xr15p27ZtmpycVGdn58Ly4eFhXblyRY8fP85ZZ+lPwk+fPqmlpUXJZFKxWKzYTVeEmZkZtbW1KR6Pq6OjY2H50NCQzp49q48fP2rDhg1Z69BPP/3B6M/HcRw1Nze7HoNCfJ2O2Lx5s2pqavT69eus5W/fvlVDQ4PrOtFoVNFoNGd5LBZTfX29n81bV1dXp5qaGqXT6ax9T6fTkqQ1a3IvNqGffvqD0V8Mt2NQcB0/D66trVV7e7vi8XjW8ng8rj179vjeeLXx6p+YmLC0R+VFP/1h7i8VX6+EJam/v18nTpzQ7t271dnZqUuXLimZTKq3t7cU+1dx3PpfvHhhe7fKhn76w9xfCr6H8NGjR/X+/XudOXNGMzMz2rlzp27evKmWlpai1o9GoxocHHT9FaUauPWPj4/r7t27RTXRTz/91dvvZUVdft/Ju3Xrljl06JDZunWrkWTGxsZ8vxtYzeinn/7w9peC77PIX7580a5du3T+/Hn/Ez8A6Kef/vD2l4Lv0xHd3d3q7u4uxb5UBfrppz+8/aXAXdQAwCLfr4T9Wnqx9vz8vD58+KBNmzYpEo
mUevMl9/XrVzmOo3Q6rcbGxpzrBOmnn/7g9mcYYzyPQTErL5uKODEfpo8tjoyM0E9/aPul3HkQtn6350Ahvj62vFQ
"text/plain": [
"<Figure size 400x400 with 25 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
2024-07-26 21:07:40 +08:00
}
],
"source": [
2024-07-28 22:35:36 +08:00
"# Train the GAN on only the first 500 training images (quick experiment).\n",
"n_epochs = 50\n",
"n_train = 500\n",
"batch_size = 16\n",
"display_step = 5\n",
"cur_step = 0\n",
"mean_generator_loss = 0.0\n",
"mean_discriminator_loss = 0.0\n",
"\n",
"# value_and_grad differentiates w.r.t. the given model's trainable parameters only,\n",
"# so the discriminator step does not update the generator and vice versa.\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"for epoch in tqdm(range(n_epochs)):\n",
"    for real in batch_iterate(batch_size, train_images[:n_train]):\n",
"        # Discriminator step: compute loss/grads, apply them, then materialize.\n",
"        D_loss, D_grads = D_loss_grad(gen, disc, real, batch_size, z_dim)\n",
"        disc_opt.update(disc, D_grads)\n",
"        mx.eval(disc.parameters(), disc_opt.state)\n",
"\n",
"        # Generator step.\n",
"        G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)\n",
"        gen_opt.update(gen, G_grads)\n",
"        mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
"        # Running means of the losses over the current display window.\n",
"        mean_discriminator_loss += D_loss.item() / display_step\n",
"        mean_generator_loss += G_loss.item() / display_step\n",
"\n",
"        if cur_step % display_step == 0 and cur_step > 0:\n",
"            print(f\"Epoch {epoch}, step {cur_step}: Generator loss: {mean_generator_loss:.4f}, discriminator loss: {mean_discriminator_loss:.4f}\")\n",
"            fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
"            fake = gen(fake_noise)\n",
"            show_images(fake)\n",
"            show_images(real)\n",
"            mean_generator_loss = 0.0\n",
"            mean_discriminator_loss = 0.0\n",
"        cur_step += 1"
2024-07-27 05:19:08 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}