2024-07-26 16:07:40 +03:00
|
|
|
{
|
|
|
|
|
"cells": [
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"# Import Library"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 1,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"import mnist"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-27 01:09:51 +03:00
|
|
|
"execution_count": 3,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"import mlx.core as mx\n",
|
|
|
|
|
"import mlx.nn as nn\n",
|
|
|
|
|
"import mlx.optimizers as optim\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"from tqdm import tqdm\n",
|
|
|
|
|
"import numpy as np\n",
|
|
|
|
|
"import matplotlib.pyplot as plt"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"# GAN Architecture"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"## Generator 👨🏻🎨"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-27 01:09:51 +03:00
|
|
|
"execution_count": 4,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"def GenBlock(in_dim:int,out_dim:int):\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" return nn.Sequential(\n",
|
|
|
|
|
" nn.Linear(in_dim,out_dim),\n",
|
|
|
|
|
" nn.BatchNorm(out_dim),\n",
|
|
|
|
|
" nn.ReLU()\n",
|
|
|
|
|
" )"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-27 01:09:51 +03:00
|
|
|
"execution_count": 5,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"class Generator(nn.Module):\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" def __init__(self, z_dim:int = 10, im_dim:int = 784, hidden_dim: int =128):\n",
|
|
|
|
|
" super(Generator, self).__init__()\n",
|
|
|
|
|
" # Build the neural network\n",
|
|
|
|
|
" self.gen = nn.Sequential(\n",
|
|
|
|
|
" GenBlock(z_dim, hidden_dim),\n",
|
|
|
|
|
" GenBlock(hidden_dim, hidden_dim * 2),\n",
|
|
|
|
|
" GenBlock(hidden_dim * 2, hidden_dim * 4),\n",
|
|
|
|
|
" GenBlock(hidden_dim * 4, hidden_dim * 8),\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" nn.Linear(hidden_dim * 8,im_dim),\n",
|
|
|
|
|
" nn.Sigmoid()\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" def __call__(self, noise):\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" return self.gen(noise)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 6,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"Generator(\n",
|
|
|
|
|
" (gen): Sequential(\n",
|
|
|
|
|
" (layers.0): Sequential(\n",
|
|
|
|
|
" (layers.0): Linear(input_dims=100, output_dims=128, bias=True)\n",
|
|
|
|
|
" (layers.1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
|
|
|
" (layers.2): ReLU()\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" (layers.1): Sequential(\n",
|
|
|
|
|
" (layers.0): Linear(input_dims=128, output_dims=256, bias=True)\n",
|
|
|
|
|
" (layers.1): BatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
|
|
|
" (layers.2): ReLU()\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" (layers.2): Sequential(\n",
|
|
|
|
|
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
|
|
|
|
|
" (layers.1): BatchNorm(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
|
|
|
" (layers.2): ReLU()\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" (layers.3): Sequential(\n",
|
|
|
|
|
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
|
|
|
|
|
" (layers.1): BatchNorm(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
|
|
|
|
|
" (layers.2): ReLU()\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" (layers.4): Linear(input_dims=1024, output_dims=784, bias=True)\n",
|
|
|
|
|
" (layers.5): Sigmoid()\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
")"
|
|
|
|
|
]
|
|
|
|
|
},
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 6,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"gen = Generator(100)\n",
|
|
|
|
|
"gen"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 7,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"def get_noise(n_samples, z_dim):\n",
|
|
|
|
|
" return np.random.randn(n_samples,z_dim)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"## Discriminator 🕵🏻♂️"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 8,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"def DisBlock(in_dim:int,out_dim:int):\n",
|
|
|
|
|
" return nn.Sequential(\n",
|
|
|
|
|
" nn.Linear(in_dim,out_dim),\n",
|
|
|
|
|
" nn.LeakyReLU(negative_slope=0.2)\n",
|
|
|
|
|
" )"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 9,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"class Discriminator(nn.Module):\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" def __init__(self,im_dim:int = 784, hidden_dim:int = 128):\n",
|
|
|
|
|
" super(Discriminator, self).__init__()\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" self.disc = nn.Sequential(\n",
|
|
|
|
|
" DisBlock(im_dim, hidden_dim * 4),\n",
|
|
|
|
|
" DisBlock(hidden_dim * 4, hidden_dim * 2),\n",
|
|
|
|
|
" DisBlock(hidden_dim * 2, hidden_dim),\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" nn.Linear(hidden_dim,1),\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" def __call__(self, noise):\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" return self.disc(noise)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 10,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"Discriminator(\n",
|
|
|
|
|
" (disc): Sequential(\n",
|
|
|
|
|
" (layers.0): Sequential(\n",
|
|
|
|
|
" (layers.0): Linear(input_dims=784, output_dims=512, bias=True)\n",
|
|
|
|
|
" (layers.1): LeakyReLU()\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" (layers.1): Sequential(\n",
|
|
|
|
|
" (layers.0): Linear(input_dims=512, output_dims=256, bias=True)\n",
|
|
|
|
|
" (layers.1): LeakyReLU()\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" (layers.2): Sequential(\n",
|
|
|
|
|
" (layers.0): Linear(input_dims=256, output_dims=128, bias=True)\n",
|
|
|
|
|
" (layers.1): LeakyReLU()\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
" (layers.3): Linear(input_dims=128, output_dims=1, bias=True)\n",
|
|
|
|
|
" )\n",
|
|
|
|
|
")"
|
|
|
|
|
]
|
|
|
|
|
},
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 10,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"disc = Discriminator()\n",
|
|
|
|
|
"disc"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"# Model Training 🏋🏻♂️"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 11,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"# Set your parameters\n",
|
|
|
|
|
"criterion = nn.losses.binary_cross_entropy\n",
|
|
|
|
|
"n_epochs = 200\n",
|
|
|
|
|
"z_dim = 64\n",
|
|
|
|
|
"display_step = 500\n",
|
|
|
|
|
"batch_size = 128\n",
|
|
|
|
|
"lr = 0.00001"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 12,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
2024-07-27 01:09:51 +03:00
|
|
|
"outputs": [],
|
2024-07-26 16:07:40 +03:00
|
|
|
"source": [
|
|
|
|
|
"gen = Generator(z_dim)\n",
|
2024-07-26 16:36:29 +03:00
|
|
|
"mx.eval(gen.parameters())\n",
|
2024-07-26 16:07:40 +03:00
|
|
|
"gen_opt = optim.Adam(learning_rate=lr)\n",
|
2024-07-26 16:36:29 +03:00
|
|
|
"\n",
|
2024-07-26 16:07:40 +03:00
|
|
|
"disc = Discriminator()\n",
|
2024-07-26 16:36:29 +03:00
|
|
|
"mx.eval(disc.parameters())\n",
|
2024-07-26 16:07:40 +03:00
|
|
|
"disc_opt = optim.Adam(learning_rate=lr)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"## Losses"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 13,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
2024-07-26 16:36:29 +03:00
|
|
|
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
|
2024-07-26 16:07:40 +03:00
|
|
|
" noise = mx.array(get_noise(num_images, z_dim))\n",
|
|
|
|
|
" fake_images = gen(noise)\n",
|
2024-07-27 01:09:51 +03:00
|
|
|
" \n",
|
2024-07-26 16:07:40 +03:00
|
|
|
" fake_disc = disc(fake_images)\n",
|
|
|
|
|
" \n",
|
2024-07-27 00:19:08 +03:00
|
|
|
" fake_labels = mx.zeros((fake_images.shape[0],1))\n",
|
2024-07-26 16:36:29 +03:00
|
|
|
" fake_loss = nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True)\n",
|
2024-07-26 16:07:40 +03:00
|
|
|
" \n",
|
|
|
|
|
" real_disc = disc(real)\n",
|
2024-07-27 00:19:08 +03:00
|
|
|
" real_labels = mx.ones((real.shape[0],1))\n",
|
|
|
|
|
"\n",
|
2024-07-27 01:09:51 +03:00
|
|
|
" real_loss = nn.losses.binary_cross_entropy(real_disc,real_labels,with_logits=True)\n",
|
2024-07-26 16:07:40 +03:00
|
|
|
"\n",
|
|
|
|
|
" disc_loss = (fake_loss + real_loss) / 2\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" return disc_loss"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 14,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
2024-07-26 16:36:29 +03:00
|
|
|
"def gen_loss(gen, disc, num_images, z_dim):\n",
|
2024-07-26 16:07:40 +03:00
|
|
|
"\n",
|
|
|
|
|
" noise = mx.array(get_noise(num_images, z_dim))\n",
|
|
|
|
|
" fake_images = gen(noise)\n",
|
|
|
|
|
" fake_disc = disc(fake_images)\n",
|
|
|
|
|
"\n",
|
2024-07-27 00:19:08 +03:00
|
|
|
" fake_labels = mx.ones((fake_images.shape[0],1))\n",
|
2024-07-26 16:07:40 +03:00
|
|
|
" \n",
|
2024-07-26 16:36:29 +03:00
|
|
|
" gen_loss = nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True)\n",
|
2024-07-26 16:07:40 +03:00
|
|
|
"\n",
|
|
|
|
|
" return gen_loss"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 15,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"train_images, _, test_images, _ = map(\n",
|
|
|
|
|
" mx.array, getattr(mnist, 'mnist')()\n",
|
|
|
|
|
")"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 16,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"def batch_iterate(batch_size:int, ipt:list):\n",
|
|
|
|
|
" perm = mx.array(np.random.permutation(len(ipt)))\n",
|
|
|
|
|
" for s in range(0, ipt.size, batch_size):\n",
|
|
|
|
|
" ids = perm[s : s + batch_size]\n",
|
|
|
|
|
" yield ipt[ids]"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "markdown",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"source": [
|
|
|
|
|
"### show batch of images"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 17,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
2024-07-28 01:10:19 +03:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAUkAAAFICAYAAADd1gwNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACLYElEQVR4nOz92W+cV5rniX9i3/c9uAV3arUk27Lk9J6VzqWqq6qr0Oiu7K6bLqD7soEB5o+Yu8EAczHAzAAzg5pGY6qzXFmVma5ypjO9S5ZsbZRIcV9j3/c9fhf6nWNSomjZWhik3g9g2CaDwXgP3/c5z3mW76Pq9Xo9FBQUFBT2RH3QH0BBQUGhn1GMpIKCgsI+KEZSQUFBYR8UI6mgoKCwD4qRVFBQUNgHxUgqKCgo7INiJBUUFBT2QTGSCgoKCvugGEkFBQWFfdA+6gtVKtXT/ByHiu/bpKSs4Tcoa/j4KGv4+DzKGiqepIKCgsI+KEZSQUFBYR8UI6mgoKCwD4qRVFBQUNgHxUgqKCgo7INiJBUUFBT2QTGSCgoKCvvwyHWSB8VeNV2KmLqCgsLDUKlUT9RG9J2RNBqN6HQ6XC4X4XAYk8nEyMgIVquVTCZDLpcjn8+zvr5OrVajWq3SbDYP+mMrKCgcEFqtFoPBgNlsZnR0FLPZTLlcpl6vUywWiUajtNvt7//+T/CzPjYqlQqTyYTFYmFsbIwLFy7g8/l46623CAaDLCwssLCwwNraGh988AG5XI5Op6MYSQWF5xi9Xo/NZsPv9/P666/j9/uJx+Ok02m2trZIpVJHw0iqVCq0Wi2BQIBQKMTY2BjDw8N4PB4cDgcWiwW3200oFALg1KlT5PN5YrEY+XyeZrNJrVaj1WpRKBRotVoHfEUKCgpPG5VKhc1mY3BwkGAwyNDQED6fT3qX1WoVjUbzWL+jL4ykMJAWi4Uf/vCHvP766wQCASYmJjAYDFitVrRaLWNjYwwMDNBoNHjjjTeo1+usr6+TyWSIRqMsLS2RyWS4dOkS6XT6oC9LQUHhKaJWq1Gr1UxOTvKzn/2MQCDAxYsXcbvdxGIxUqkUJpOJP/zhD5TL5e/9e/rCSMI9Q6lWq7HZbPh8PtxuN3a7Ha1Wi0qlotPpoNVq0Wg0mM1m3G437XYbrVaLy+VCp9NRKpWAe+63gsKTQKPRoNfr5f15P91ul16vJ//d6XTodDoH8EmfP9RqNRqNBpvNxsDAAH6/H5/Ph9PppFar0Wg0MJvNe/7dvgt9YyQ7nQ71ep1PP/2UeDxOKBTi2LFj6PV6DAYDGo1GZqxcLhfHjh3DaDTi9/txOp30ej3S6TTdbheTyYROp6PT6dDtdg/4yhQOM9PT07z55ps4HA4CgQAmk0l+r1arEY/HqdfrJJNJSqUSGxsb3Llz57FiYArfjlqtxu12Y7VaiUQiTE1N4XK5MBqNdLtdotEos7OzrKysPHborS+MpNiJG40Gd+7cYX19naGhIXK5HEajEYvFglarla8dHh4mEolgNBpxOp1oNBpKpRIej4dKpbLLqCpGUuFxGBwc5Kc//Sl+v5/p6WmcTqf8Xj6f5+7duxSLRZaWlkgmk6jVau7evasYyaeMOHW6XC6CwSCDg4NYrVb0ej29Xo9MJsPKygrxePyx/xZ9YSRFTFKv1zMzM8PIyIjcuTudDslkkmq1SrvdptVqUSqVGB0dxev1MjAwgNPpxGazMTIygslk4gc/+AEjIyPEYjHi8TiNRoNCoaAYzG9BpVKh0WjQaDRYrVZ0Oh0WiwWTyYTRaMTtdqPT6TAajXLTgnsbV6vVot1uc+fOHZaXlw/wKh4fnU7H2bNnGR0d5YUXXmBgYAC73f5AGEev1+PxeDCbzQD4/X7UajWNRoNcLsf8/DzFYvEgLqEvUKvVsmbxST57arUanU6Hx+MhHA7jdrt33ZOdTodsNsvGxgapVOqxwx99YSTVarVM47/99tv88Ic/pFarUSgUSK
VSfPHFF6yurlIulymXy4yNjWG1WhkYGOCNN97A6XTi9/txu91Uq1VGRkbI5/P84Q9/4IsvviCdTlOtVmk0Ggd9qX2NRqPBYDBgMpmIRCLYbDaGh4cJh8P4/X5Onz4tY8ZWq1WGP9rtNpVKhXK5zP/8P//PrKysHOqCf5PJxH/4D/+Bv/qrv0Kn02EymWT86/7XRSIRer0ek5OTdLtdTp8+zdmzZ1lbW+N//V//1+faSGo0GrRaLe12+4kZSbVajVarxWg0EolEpFNlNpvR6XR0u11arRYbGxt89dVXlEqlxy4R7AsjCd8kbsxmM06nE51Oh0qlotVqYbfbsdlsVKtVqtUqxWKRRCKBWq0mlUrh9XoxGAwYjUaMRiMejweDwUA4HGZwcBCNRkM0GgWg1WopHiX31lun06FWq7FYLBgMBmkQjEYjQ0NDOBwOgsEgPp9vVymWXq+XN6ter6fb7cq1F17VYUKr1cprcTgcuFwuQqEQXq+XTqdDu92m2WzuOs3sPMKp1WpMJhN6vZ5Go0GtVqNerz+XR+6dpxGn04nJZKLdbsv1qFartFqt772JarVabDYbNpsNj8eDz+fDZrOh0WjodrsUCgXK5TLFYpFKpUKj0XjsDbtvjKS4EHGcs9vtjI2NUalUMBqNJBIJfvOb3/DP//zPpNNpfvvb32Kz2cjn88zMzDA+Ps7Zs2fR6XR4vV5pKF944QUZSE8kEsRiMQqFwgFf7cFjMpkIhULYbDZef/11pqamMJlM2O12DAYDXq8Xo9FIs9mk2WyiUqmo1WpUKhWuX79OoVAgHA4zOjqKyWSSN6tOpzvoS/vOuN1u/H4/ExMT/Jt/828Ih8NMTU0BUCwW2d7eJp/Pc+nSJeLxOIlEgkQiIX/eYDDw8ssvMzo6yu3bt/noo48olUrEYrGDuqQDQ6/X4/V6sVqtvPnmmxw/fpxSqUQymSSXy/H5558Tj8dptVrfK6Hicrk4d+4cgUCAd999l1OnTkkjmclk+N3vfsf29jbXr18nk8k8keRtXxlJYShVKpV8UO12O7VajUAgwLVr19BoNDQaDdbX1zGZTITDYVQqFUajkZmZGVQqlayr1Ov1hEIh2u02fr+fdrtNNptFrVbv+n3PEyqVSq6vw+HA4/Fw/PhxXnzxRUwmk/Ti7XY7Op2OdDpNNpul0WiQz+epVCpsbGzIG93lctHtdvH5fGg0mscutzgITCYTXq+XkZERXn/9dYaGhoB792S9XiebzZJIJLh16xZra2usr6+zvr6+6+dFNcWdO3e4du3acxva0Wq1WK1WnE4nExMTvPjii2SzWba3t0kmk8zOzkrj9X0wmUwMDg4SDoeJRCJEIhH5PWEXlpeXZdXBE7mmJ/Iuj4mII5TLZT788EOi0ShjY2PywbVarVitVt555x3MZjPb29t89tlnVKtVVldXKRQKpNNp0um0TOCImGUwGGR4eJg///M/J5vNsri4SDKZZG1tjbm5OZrNJo1G40gfwUUtqdFoZHx8nNHRURwOh1ynqakp7HY7pVKJ2dlZ6vU68XhcxoWLxaI8MrVaLdLpNJVKBbvdLmOTV65coVQqsba2dqg2H5VKxdDQEBcvXmR8fByTyUSn02F5eZlEIsHy8jJffvkl+Xyeubk58vn8AyeRVqvF3bt3ZQfY83jM1mq16HQ6QqEQ7777LsFgkDNnzjAwMIDD4cDhcOB0OvF6vaTTaXq93nfaSETt9NTUFBcvXiQYDOL1egFkt10qlWJxcZG7d++Sy+We3LU9sXd6DER2tNVq8cEHH/Db3/6W8+fP0263CYVCvPbaa4RCIaxWKy+++CJXrlxhbm6OarXK8vKy7Om+efOmdMeDwSCvvfaaLBcaHh6m0WgwPz9PPB7nD3/4A7FYjEql8kQDy/2ITqfD7/fjcrl45513+NGPfoTdbmdoaAidTketVqPZbBKPx7l58ybxeJyPPvqIeDxOtVqlVqsB34RE9Ho9Go2G6elpaVy/+O
ILtra2WF1dPchL/c6oVCpGRkZ488038Xq9mM1m2u02d+/e5caNG9y4cYP333+fWq0mTx/3bwLtdpu5uTnm5+ef2xO
|
2024-07-26 16:07:40 +03:00
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 400x400 with 16 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"for X in batch_iterate(16, train_images):\n",
|
|
|
|
|
" fig,axes = plt.subplots(4, 4, figsize=(4, 4))\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" for i, ax in enumerate(axes.flat):\n",
|
|
|
|
|
" img = mx.array(X[i]).reshape(28,28)\n",
|
|
|
|
|
" ax.imshow(img,cmap='gray')\n",
|
|
|
|
|
" ax.axis('off')\n",
|
|
|
|
|
" break"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": null,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"def show_images(imgs:list[int]):\n",
|
|
|
|
|
" fig,axes = plt.subplots(4, 4, figsize=(4, 4))\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" for i, ax in enumerate(axes.flat):\n",
|
|
|
|
|
" img = mx.array(imgs[i]).reshape(28,28)\n",
|
|
|
|
|
" ax.imshow(img,cmap='gray')\n",
|
|
|
|
|
" ax.axis('off')\n",
|
|
|
|
|
" plt.show()"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 18,
|
2024-07-27 01:09:51 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
2024-07-28 01:10:19 +03:00
|
|
|
"array(0.683622, dtype=float32)"
|
2024-07-27 01:09:51 +03:00
|
|
|
]
|
|
|
|
|
},
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 18,
|
2024-07-27 01:09:51 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"z_dim = 64\n",
|
|
|
|
|
"gen = Generator(z_dim)\n",
|
|
|
|
|
"mx.eval(gen.parameters())\n",
|
|
|
|
|
"gen_opt = optim.Adam(learning_rate=lr)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"disc = Discriminator()\n",
|
|
|
|
|
"mx.eval(disc.parameters())\n",
|
|
|
|
|
"disc_opt = optim.Adam(learning_rate=lr)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"g_loss = gen_loss(gen, disc, 8, z_dim)\n",
|
|
|
|
|
"g_loss\n"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 20,
|
2024-07-26 16:07:40 +03:00
|
|
|
"metadata": {},
|
2024-07-27 01:20:00 +03:00
|
|
|
"outputs": [
|
|
|
|
|
{
|
2024-07-28 01:10:19 +03:00
|
|
|
"data": {
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"60000"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"execution_count": 20,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
2024-07-27 01:20:00 +03:00
|
|
|
}
|
|
|
|
|
],
|
2024-07-27 01:09:51 +03:00
|
|
|
"source": [
|
2024-07-28 01:10:19 +03:00
|
|
|
"len(train_images)"
|
2024-07-27 01:09:51 +03:00
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2024-07-28 01:10:19 +03:00
|
|
|
"execution_count": 65,
|
2024-07-27 01:09:51 +03:00
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stderr",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2024-07-28 01:10:19 +03:00
|
|
|
" 3%|▎ | 6/200 [02:24<1:16:40, 23.71s/it]"
|
2024-07-27 00:19:08 +03:00
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2024-07-28 01:10:19 +03:00
|
|
|
"Step 5: Generator loss: array(8.15901, dtype=float32), discriminator loss: array(nan, dtype=float32)\n"
|
2024-07-27 00:19:08 +03:00
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"name": "stderr",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2024-07-28 01:10:19 +03:00
|
|
|
" 6%|▌ | 11/200 [04:22<1:13:56, 23.47s/it]"
|
2024-07-27 00:19:08 +03:00
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2024-07-28 01:10:19 +03:00
|
|
|
"Step 10: Generator loss: array(8.52206, dtype=float32), discriminator loss: array(nan, dtype=float32)\n"
|
2024-07-27 00:19:08 +03:00
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"name": "stderr",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2024-07-28 01:10:19 +03:00
|
|
|
" 8%|▊ | 16/200 [06:18<1:11:26, 23.30s/it]"
|
2024-07-27 00:19:08 +03:00
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2024-07-28 01:10:19 +03:00
|
|
|
"Step 15: Generator loss: array(8.47402, dtype=float32), discriminator loss: array(nan, dtype=float32)\n"
|
2024-07-27 00:19:08 +03:00
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"name": "stderr",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2024-07-28 01:10:19 +03:00
|
|
|
" 8%|▊ | 16/200 [06:31<1:15:05, 24.49s/it]\n"
|
2024-07-27 00:19:08 +03:00
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"ename": "KeyboardInterrupt",
|
|
|
|
|
"evalue": "",
|
2024-07-26 16:07:40 +03:00
|
|
|
"output_type": "error",
|
|
|
|
|
"traceback": [
|
|
|
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
2024-07-27 00:19:08 +03:00
|
|
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
2024-07-28 01:10:19 +03:00
|
|
|
"Cell \u001b[0;32mIn[65], line 38\u001b[0m\n\u001b[1;32m 35\u001b[0m gen_opt\u001b[38;5;241m.\u001b[39mupdate(gen, G_grads)\n\u001b[1;32m 37\u001b[0m \u001b[38;5;66;03m# Update gradients\u001b[39;00m\n\u001b[0;32m---> 38\u001b[0m \u001b[43mmx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meval\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgen\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparameters\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgen_opt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstate\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cur_step \u001b[38;5;241m%\u001b[39m display_step \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m cur_step \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mStep \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: Generator loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mG_loss\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, discriminator loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mD_loss\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
|
2024-07-27 00:19:08 +03:00
|
|
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
2024-07-26 16:07:40 +03:00
|
|
|
]
|
2024-07-28 01:10:19 +03:00
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAUkAAAFICAYAAADd1gwNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACQe0lEQVR4nOy9d5Rc13ng+XuVc1VXVeeAbqCRA0ESBMGAwEyKQYGmLGkkUR6au+OxpbVnZe+xbNkzuw5zjnc1HFkeh/HYEhVMyhRFSRQDQDBBJJFTIzTQOYeqrpzT2z/ge1ndABogAXQ3qt/vnD4kuqsL9T5897v3flFRVVVFQ0NDQ+OC6Ob7A2hoaGgsZDQjqaGhoTELmpHU0NDQmAXNSGpoaGjMgmYkNTQ0NGZBM5IaGhoas6AZSQ0NDY1Z0IykhoaGxixoRlJDQ0NjFgyX+0JFUa7l57iu+LhFSpoMP0ST4ZWjyfDKuRwZaidJDQ0NjVnQjKSGhobGLGhGUkNDQ2MWNCOpoaGhMQuakdTQ0NCYBc1IamhoaMyCZiQ1NDQ0ZkEzkhoaGhqzoBlJDQ0NjVnQjKSGhsaCYSFWA2lGUkNDY8GwEOcSakZSQ0NDYxY0I6mhoaExC9e9kVQUBZ1OtyB9GQsVTVYa842iKPJLoNPp0Ov16HS68147n1z3RlKn08kvjctnpuLNtyJqLD4upIMzDzzi/+dTPy+7n+R8IHYag8GATqfDbDZjNpsxGAzYbDZMJhNerxen00mpVKJYLJLJZBgbGyOTyVBfX099fT2BQIDjx4+TTqcXpGP4WmMymaiursZqtQLnnONOp5Pm5mYMBgO5XI58Pk8mkyGVSpHP54lGoxSLRZqbm2loaGB8fJzDhw+TSqXm+WnmjwudcBRFwWg00tbWht/vB6BUKpHP5wkEAmSzWfx+P1VVVYRCIc6ePUsmk5mPjz+viIOMxWJBr9djMpkwmUxYrVbq6uowGo3T5CtkODk5STabpbGxkdraWgKBAEeOHCGVSs3ZWl6wRlIooF6vx263YzKZqKqqkkaxubkZp9PJhg0baG5ulkYyGAyya9cuJicneeCBB9ixYwcffPABf/EXf8HY2BilUmlRGUqdTofNZmPDhg3U1dVRLBYplUosXbqUT33qU7jdbiYnJ4nFYoRCIUZHR4nFYpw5c4ZkMsmjjz7K/fffz7vvvss3vvGNRbvRCF0UC1lVVXQ6HQaDAafTyQMPPMDmzZtRVZVisUgkEmHfvn1MTU2xadMmbrjhBo4dO8bf//3fk81mF5UMxTXaaDTi8/mwWq243W7cbjeNjY3cdddduFwuSqUSpVJJGtRIJMKePXuYmprioYce4p577uFXv/oV3/zmN0mn08DcRMMXnJEUx22bzYbb7cZgMOB2uzEajXi9XqqqqrDZbFRXV2Oz2XA6ndhsNuDDI3ljYyNmsxm/34/L5cJuty/K67jT6cTr9eLxeGhubqa2tpZCoUCxWKSurg6fz4fD4WBqaopcLoder5eyrq+vJ5PJyNdYrdZFJUOj0Yher8diseBwOKa5dAwGA3q9nmKxSC6Xw2az4fV6qa6ulgvdbDbT2NiIzWajqqpK3oAWk1tDHHQcDgderxeLxYLf78dms+FyuXA6nTQ0NFBVVYXD4SCRSJDP5zEajXJN19fXY7Vaqa6uxuVyYbPZ5lwPF5SRVBQFu92OxWLhxhtv5JFHHsHpdOLxeDCbzdhsNmw2G5lMhmAwSKFQoFAoMDAwgNvtxu/3U11dzeOPPy7fMxKJEIvFKBQKi+IUKfw6Op2OLVu28NRTT1FVVUV1dTUWi4VCoUA+n8dkMqHT6QiHw+zatYt9+/axcuVKbr31VhoaGti0aRMmk4lSqUR3dzcjIyPk8/mKlx+AXq+ntrYWj8fD+vXrue+++zAajSSTSfL5PNXV1fj9fsbGxti7dy+FQoE1a9bQ1NQEIPVsxYoVFItFent7OXPmDENDQ+RyuU
UhQ0VRMJlMGAwGtmzZwhe+8AXcbjdOp1N+32AwUCqVyGazZLNZzp49y/DwMI2NjSxfvhyr1crDDz+M0WhEURRGRkYIBAJSDxfldVv4d6xWK42NjWzevBmPx4PH48FkMmE0GjGbzUQiEXkdDAQCJBIJTCaTPIGK3Wd0dJTx8XHy+fyiMZDwof+noaGBrVu3UlVVJY1iNpsll8tRKBRIJpMkk0n6+vo4fvw4brcbs9mM2+2mvb0dh8NBf38/g4ODxONxSqXSPD/h3CA2a6/Xy9KlS7n99tsxm83yxN3Y2EhDQwP9/f1MTEyQSqXweDzYbDb5b2A0GnG5XOj1eiYnJ4lEIiQSCYrF4jw/3bWn3FVmNBppaGjg9ttvl7dAYRxVVSUajdLb20sikSAYDDI8PIzFYqG1tRWHw0Fraytut5uRkRHGx8dJp9NzrocLxkgKodbU1FBfX09zczN+vx+9Xs+pU6eIRqMEAgECgQCpVIpgMEgulyOZTJLNZtm0aROtra2kUil2795NMpmUzl673Y5er0dRlIo1lOIEKWTodrupra2Vm0ckEiGbzXLo0CEOHz5MoVAgl8uRyWQ4evQo0WiUaDRKLBYjn88zPj4OnDtV6fV6qZiVLEM4f4GLE4+Q3ejoqNSlSCTCqVOnyOfzDAwM4Pf72bBhA3fffTfpdJq33nqLYDCIXq+nubmZWCyGwbBgltw1Q8jP5/Phdruprq6WLq+hoSFSqRRnz56lq6uLVCpFOBwmnU4zMDBAMBjE4/FgtVpRVZXDhw9TKpVwOBwyWCtccovyJKnX66mvr2flypW0trZSU1NDMpmko6ODrq4uDh48yNGjR6edCsWx22Kx8MgjjxCPx3n99dcZGBjgM5/5DCtXrpQ+pUpf4CKQUFdXR1NTE/X19ZjNZgDC4TDRaJRXX32V7373uxSLRSk74UcTRjISidDZ2Uk8Hmf16tUsX7684o1keaqJkKOIwBoMBpLJJPv37+fYsWNMTEwwMTEhZSd+T1EUnnzySe677z5SqRQ/+clP6Ojo4JFHHuHee+8lGAxWvJEs32R8Ph/19fXU1tbicDhQVZXBwUHGx8f5xS9+wcsvvzxND8VXe3s7NpuNfD7PgQMHZPBr48aN04zkXOnivP6LuVwu3G63PAGZTCba29tZvnw5Op2OgwcPkkgkGBwclNfqi/kjJicnpQF1uVwsWbKEVCrF0aNH6enpoVAozKlg5wpx0jOZTLjdbqxWK8uWLWPp0qWYzWaOHDkCwNjYGPF4nJGREXK53AWvfdFolJ6eHoxGI3a7HbvdDkAgECASiUiFrjR8Ph9+vx9VVclms+j1epYvX05bWxsmk4kjR44QiUTk1TqTyZznnxW6NT4+zr59+8hms1RXV7NmzRpKpRKnT59mcHCQfD4/j0967fB4PHi9XpllotfrpR6aTCYOHz5MPp/n1KlTTE1NMTk5SS6Xu+DVORQKceLECblR+Xw+0uk03d3djI6OUigU5vTZFPUytf5qR+UUReHGG2+UAQKr1YrVamXbtm2sXbuWXbt28Td/8zfE43H5dweDQYLB4DQjKXYVl8uF1+ulqamJL3/5yyxZsoTXX3+dN998c5pxEAGcK2EhzTu22WzSf7Z+/Xp8Ph8PPvggN954I3v27OHZZ58lFosRj8fJ5XKEQiFCodB5n0v44dxuN01NTfzmb/4mS5cu5ejRo3R0dDAyMsL7779PMpm8KoZyochQURTuueceHnjgATKZDMPDwyiKImW4a9cuvvOd7xCLxUilUhQKBVKp1LR8UeFL1+v1MqOgubmZ3/zN36StrY1/+Zd/4ac//am8WhYKhYqT4e23387WrVvJZrPSxfDII49w880388477/DP//zPRKNRucGEw2HC4TBw7jnKA45Chg0NDXzpS19iyZIl7Nu3j8OHD0/LeRbpbFfC5chwzk+SiqJgNptlzlRjY6M8SZrNZlwuFw6Hg2w2S3d3N4lEQgZustnsee8lAj
r5fJ6xsTEsFgsul4v6+npyuRz9/f3k8/k5332uJTqdDqvVKk98DodDnoZ8Pp9MlRKRVeHzEek/F3o/RVHIZDLkcjk
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 400x400 with 16 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAUkAAAFICAYAAADd1gwNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB/D0lEQVR4nO29d3ic1Zn3/33K9KKZ0ah3WbJccbdjY2OqA2w2cUiykEAIW5JsyO4mBEJINpvdZNns+25JWdL2yv5CTDoswYZAbIoBG1vGtmxZtlzVpZnR9N7b7w+/5zAaS2P1mZHO57p0gaVnZp7nnnPuc5+7HS6dTqfBYDAYjHHh830DDAaDUcgwJclgMBg5YEqSwWAwcsCUJIPBYOSAKUkGg8HIAVOSDAaDkQOmJBkMBiMHTEkyGAxGDpiSZDAYjByIk72Q47i5vI+iYrpFSkyG78FkOHOYDGfOZGTILEkGg8HIAVOSDAaDkQOmJBkMBiMHTEkyGAxGDpiSZDAYjBwwJclgMBg5YEqSwWAwcsCUJIPByCscxxV07uakk8kZDAZjtqmtrcWtt94KhUKBjo4ODAwMIBqNIhgMIpVK5fv2ADAlyWAw8khjYyMefvhhGAwG/PCHP0QkEoHb7UY4HEY6nZ52VdFssiiUpFarRX19PXiex9DQEDweT75vqagRRRE8zyOZTCKZTOb7dooOURTR2NgIg8EAq9WK4eHhgrGa5ptAIIDLly9Dq9XCZrMhHA4jHo9PqBzlcjkMBgM4joPb7UYoFJrze+Qme1piIfsMrsfmzZvx6KOPQiqV4nvf+x7efvvtGb3fYq6Z5XkeGo0GcrkcoVAIgUBgWvJYzDI0GAx44okncPvtt+O5557D9773PYTD4Sm/T7HLkOM4KJVKVFVVQRAEOJ1OBAIBJJNJxOPxcV/T2NiIXbt2QRAEvPHGG7h8+fKM7mEyMlwUlqRcLkdVVRVkMhmUSmW+b6foEQSBWpOLFY7jpq2keJ6H0WhEQ0MDtYoWK9FoFGazGQAQi8WQSCRyXi+KIkpKSiCKIqRS6XzcYvFakuR+JnP71dXV2LJlC0RRxPHjxzE4ODijzy72FXwmcBwHqVQKURQRj8cRi8Wm9T6LWYZyuRw33ngjmpqacP78eRw/fvy6ymE8ilGGmZ+dTqfR0tKCe+65BwqFAvv27UNnZ2fO15eUlKCpqQmCIKC/vx8ul2tG9zMZGRatkiRWTD6cu8U4OAsNJsOZU4wyJOk+ZN7eeOON+Ld/+zdotVp84xvfwAsvvDCv97OottsNDQ1YtWoVIpEITp8+PeMVhsGYK0RRRE1NDUpKSuB0OmGxWBZN4IbneXAch1QqhXQ6DYfDgTfeeAMKhQLDw8P5vr1xKVolmbkCcByHTZs24fHHH4fVasU//MM/MCU5h0zF1bEYmKp/UiaTYePGjVi+fDk6Ojpgs9kWhZLkOA6CIIDneSQSCaRSKfT19eE73/kOOI5DMBjM9y2OS9EoyesNxEQigVAohHA4vCgGHKMwmMzWNbOihFhQJIK7mFKoOI6DRCKBIAhIp9NIJBKIx+MFn5JXFD5JURTBcRySySRVgNnWTHV1NVpaWhCNRnHx4kV4vd45u59i9AVlkn0f2VZ59r95nkc6nZ7VxWchyzAbuVwOmUyGRCKBcDgMQRBQVVUFrVYLl8uF0dHRacm22GQolUpRVVUFhUIBh8MBh8ORl/vIpOh9kmQFzvRjZP+dPKTFYoHNZqMOYfJacs1CtS6zrZSJyE7XyTU4MuVa6HW14zFZmczk+vH+fyKZCoIAqVQKnudporTdbofL5UI8Hl+wYxMYa8zwPA+FQgG1Wg2fzzep1403budbXgWpJDmOQ2VlJZYsWQIAGB0dRTgchs/ng9/vp9dlD8pUKgWlUonGxkaUlJSgsbERdXV16Ovrw/79+6/7xRQCU7HcOI7D5s2bsW3bNlgsFrz66qtwu91j5MLzPHQ6HW677TZUVV
Xh+PHjOH78OIDJraKCIEAikdB7ycwmKFSf5IYNG3DjjTfCZrPhtddey2mxjCfD8fzZEokEFRUVkMvlcLvd9JrryYJsMRUKBaqqqrB+/XrwPI+3334bfX19BSvDmSIIAmpqamAwGOB2uzEyMoJEIgGn0wm/3z9mHmej1+tRUVGBkpISrFq1Cnq9nvoy+/v76VyeL2VZcEqSrOrV1dW45ZZbkEgkqHM7mUxS4WYPLvJvuVyONWvWoL6+HjfffDO2bduGN954A+3t7UWhJHmehyiK1Hd1PYtv27Zt+PKXv4yOjg6cOnUKHo/nGivQYDDg3nvvxcaNG/H9738fJ06cGHeAZVpU5D1I0i5xdWTfV6FNchLEe/TRR3Hu3DmcPn36ukoyU4YnT54cV0lKpVLU1tZCr9fjypUrcLlck352iUQCpVKJZcuW4eGHH4YoirBarejt7S04+c0WgiCgvr4ebW1t6OnpwejoKKLRKP0ucj23Xq/H8uXLUV9fj/vvvx/Nzc2QSqWQSqV47bXXcPz4cVrpNR/yy6uSlMlkkMlkSCaTCIfD4DgOLS0tqKysRGtrK5qampBMJhGLxeByueB0OuF0OuHxeNDb24tIJELfi1hgMpkM1dXVtDaWOIqLZcsol8uhVquRSCTg9Xqvm2RssVjQ0dGBS5cuUXlIpVIIgoBUKkUDBMPDw9BoNPB4PNQ/lkgkrqvwiGIsKSlBbW0teJ5HT09PQfiTJoJYbiqVCjqdDjqdDqFQaMLE9/FkmE0qlUIoFIIoivR9sheViSZsIpFAJBJBIpGAVCqFRCKhPvaFqiR5nkdlZSVaWlqgUCjA8zyi0Sh8Ph8SiQRUKhUUCgX8fj8GBgYQiUToQqzT6dDa2kqr5FKpFDiOgyiK0Gq1aG5uhiiKGB0dnRfDJ69K0mAwoKamBoFAAAMDA5BIJPjUpz6Fe+65B3K5HCqVCul0GnfccQeNhMXjcbz77rv41re+NSavilhger0eO3bswLp166BUKscEe4oBo9GI5uZm+P1+XLx4EYFAYMJrU6kUXn31VZw8eRKRSAQ2mw3A1aoErVaLQCAAp9MJl8uF3/3ud9Dr9bBardDpdIjFYrRONlNG2ZM2mUwiGo2itbUVX/va1yCVSvEv//IvePXVV+dOCDOEWG56vR4tLS0AgP7+flit1muuzZbheNcAV0vmBgcHIYoirbMmExcAXXCySafTCAQCiEQi8Pl8kEqlkMlkSKfTORs5FCOZC4ZUKsXmzZvxsY99jC7UoVAI3d3d8Hg8WLJkCZqbm9HZ2Yn//M//hMlkQjgcRiwWw/Lly3HffffR+U9cSBKJBFVVVfjoRz8Ki8WCF154AV1dXXP+XHlVksShLZfLodVqIZfL0dDQgLa2NnpNMpmETCajDu50Og2DwQBBEK55P2JNKpVKaLVaxGIxeDweBAKBolGUEokEWq0WqVRqUrXRLpdrzPaQ53kIgkB9OAAQj8fpqhuLxejfr2ddcxwHmUwGhUIBnU6H8vJy+n0VIqIoQhRFunsg/08s64nIluF4pFIphMNhmuOXSa7gVqZPUiaT0d8XShuw2UYQBDpeysrKUFFRQf3awWAQoVAITqcTNTU1qKyshMFgoN8PCdCKokgzAqLRKJLJJEKhENLpNCKRCFQqFTQaDSQSybw8U16VpNvtRjweR1NTEx577DHU1NRg8+bNY65xOBz4+c9/jkuXLlElabFY4HQ6x1yXSqUQj8fh8/lw9uxZxGIxnDlzBufPn8fIyAjcbvd8Ptq04DiOLhiJRGLKDSTIRPX5fLTlFJGL0+mE1+ulW3Dy+4kmK8/zkEgkuPPOO3HHHXcgEAjgZz/7GbxeL7q7u2fleWcTmUyGZcuWwWg0orGxETzPIxKJYGBgAFeuXJlxShjJ68uM/JN8RwB0S5gdNJNIJLj77ruxa9culJeXI5lMwmazTbitL2bS6TQaGxvx53/+56ivr8fGjRupawG4uohVVlZCqVTi2LFjeO
qppzA6Ooq+vj4EAgGaYH7s2DF885vfRG1tLT7+8Y+jrq4O+/fvx8GDBxEMBuF0OhEOh2fcg2Gy5FVJBoNBBINBtLa
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 400x400 with 16 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAUkAAAFICAYAAADd1gwNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACPVElEQVR4nOy9eXjcV3no/5l9X6WRRrtkW973xFv2QDayE2hJgUJ4gPaS0vZC+9z2ttxfS6G0l94WQtrC5fY2KUmBJJBAAtkTk8RZvMW7ZFn7NhpJo9n39feH7zkZ2bLixLE0kr6f59HjRBqN5/v6nPe8511VpVKphIKCgoLCjKjn+wMoKCgoVDKKklRQUFCYBUVJKigoKMyCoiQVFBQUZkFRkgoKCgqzoChJBQUFhVlQlKSCgoLCLChKUkFBQWEWFCWpoKCgMAva832hSqW6mJ9jQfF+i5QUGb6DIsMLR5HhhXM+MlQsSQUFBYVZUJSkgoKCwiwoSlJBQUFhFhQlqaCgoDALipJUUFBQmAVFSSooKCjMgqIkFRQUFGZBUZIKCgoKs6AoSQUFBYVZUJSkgoLCgmUuqofOuyxxrlGpVKjVp3V4sVikVCpRW1uLx+MhHo/j8/nIZrPz/CmXDg6Hg7a2NvR6PcViEYDJyUmGh4fl/yu8g16vx+PxYDQaCQaDhEIhYPqmXioz+Mr3cqFQAKCmpgaPx0MikbigvTwXMqxYJanVajGZTJRKJVKpFKVSiWuuuYa77rqLjo4OfvCDHzA+Pr4kF918sG7dOr72ta9RV1dHoVCgUCjw85//nPvvv59UKjXfH6/icDqd3H777TQ1NfH888/zyiuvoFKp0GpPbzlx8Is/FzNarRaDwUCpVCKdTgNw5ZVXcuedd9LR0cH//b//l4mJife0l1Uq1ZzJrWKVZPnpYzKZUKvV1NXV0d7eTjQaRa/Xo1Kp5Be8I1jFsrkw1Go1BoMBjUZDqVSiVCpRVVXFypUraWhoIJ1Ok81mqaqqkrLXaDSoVCqKxaIif04rBqfTSXV1NSaTSX5/pjVbKpVQqVTodLppMiwWi4vmtqRWq1GpVHIve71eVqxYQSwWO2svC3nAzHt5rht0VKySLBQKpFIpqqqquOOOO1i+fDnr16/H4/FQVVWF3W7HZrPJRWY0GrFareTzeSYnJ8lkMvP9CAuW2tpaPv3pT7Ns2TICgQCRSISmpiZisRg9PT288cYb9Pf38/bbb5PP5zGbzTQ3N2M2m/H5fExMTMh/l6WGONxjsRgvvPACNpuN/v5+KYt8Po9Wq8Xr9WKz2TAajZjNZpxOJ9u2bcPpdOLz+QgEAgwODrJnzx6SyeQ8P9WFUSgUSKfTuN1ubrnlFlpbW9m4cSM1NTW43W4cDgfRaFSuGYPBgNlsJp/PEwwGz9rLc72uKlpJFgoFjEYj119/PZdddhl6vR6dTofNZsNisWAymSgUCpRKJaxWKzU1NWQyGSKRiKIkLwCXy8Xtt9/Orl27GBwcZGRkBJVKRSqVIhqN8pvf/Ib9+/cTiUSkkmxqasLpdJJMJgkEApRKJel/Wqi8nyudsIZSqRSHDh2adlgImWi1Wlwul1SUbrebxsZG7r77bhoaGujs7KS3t5cDBw6wf//+Ba8khUVsMBi45ppr2L59O2azGbPZjN1ux2QynbWXXS4X2WyWWCw273u5YpVkc3MzmzZtor6+HrPZTDQaxeFwYDabsdlstLW1odVqGRkZIRgMotFomJqaIp/Pk8vl5PvMpe9ioSKuQsIK0ul08vpssVjweDxMTk5y5MgRAoEAPp+PeDxOPp9Hp9PhcDhYv349tbW1TE5O0t3dvShk/l6eQcjPZrPJK3ZjYyNGo5HOzk66urpQq9VotVo0Gg3xeJzJyUnsdjuNjY3U19djNBrRaDSo1Wr552Lo/djU1CQtR51Ox9TUlHRH2Gw2mpqaAPD7/YTDYSnLfD5PPp8H5jfgVbFK8pJLLuGv/u
qvsFgsBINB/H4/er2e6upqPB4Pu3btorm5meeee46RkRFSqRTBYBB4J4ImNr1wkCucjUqlQqPRyC+9Xi/9RiqViqqqKpxOJxMTE/zyl79kZGSEsbExotEoWq0Wo9GI1+vlIx/5CCtWrKCvr49XXnllUSjJ2Sj3KYp1Jnxtmzdvpr6+nltvvRWv18t9991Hd3c3arUao9GIWq0mEAgwNTXFihUr2Lx5Mx6PB7vdjkajQavVSmW6GNi6dSt/8Rd/gcFgYGBggP7+fiwWCy0tLXg8HrZs2YLH4+HVV1/F5/ORTCYJh8PT9m25v3JJKkmDwUBVVRVarZZkMkkmk8FgMJDL5Ugmk4yPj5NIJHA6ncBp8z2TyZBOp8nn89MUoIggCgf4Yt+s7xej0YjL5UKj0RCLxUin0/J7LpcLnU4HIC2aUqlENBolHA6TzWblYi2VSuTzecLhMMFgUEYvFzPC7SOuzyqVSl4b6+vrqa+vp6amBq1WSz6fR6PRYDabgXeUa6FQoFgskkwmZXqQcCGNjo4yPj5OOBxecIe72MsajUbuZZ1ORzabpVgsEggECIfDNDc3A+9cxTOZzLS9XCwW5V5Wq9XzGhCsCCXZ1tbGvffei8fjYd++fXR3d5PL5bj//vvJ5XKMjo6SyWS45557WL58OWNjY7z44ov09/czMTEx7b20Wi0ejwe9Xk84HCYSiczTU1U2y5cv5wtf+AJms5mHHnqIN954g5aWFq677joaGxuprq6e9vpcLkc0GiUWi8koZT6fJ5VK0dPTw7e//W0sFgt9fX2L+mBSqVR4vV6amprIZDJEo1F0Oh3XXXcda9eupba2ltbWVqLRKK+88grj4+NMTEywefNmwuEww8PDZDIZqSwPHTrE6OgoOp0Os9ksr+KpVIpIJEIikZjnJ35vtLW18fu///tUV1fz1ltvcerUKdLpNP/8z/9MNpuVe1mn07F582bGx8d5/fXX6e/vZ2pqatp76XQ6ampqZK6psC7nmopQkk6nk+3bt9PU1EQ8HiedTuP3+3n77belJZnNZrnuuuukM3dwcJDBwcGzggNqtVo6gsUCW8yb9v3icDi49NJLsdvtPPvss6hUKhwOB8uXL8fr9cqk8fKorEj9ET5L4f+NRCIcOHBgnp/o4iOu1TabDa/XSzqdRq1Wo9frWbNmDTt27JABmZGREXw+HydOnMBsNstDZ2RkhGKxKN0Z4+PjjIyMzPOTfXCIddXU1EQ4HCaZTOL3+zl8+DCpVIrJyUny+Tx+v59MJkM8Hmd0dJTR0dEZ97II8MTj8XnLmJh3JVnuz9Hr9axfv57q6mr27t3LsWPHZICgVCrxxhtv8N3vfpfBwUGCwSDFYlFerUU0vFAoEA6HicfjJJPJJakg381/o1KpSKfTDAwM4PF4uOqqq2TO2ssvv4xer+fAgQNYLBZ53T516hSRSIRcLievQgs9ev1ulAf9DAYDTU1N2O12duzYwY4dOygWi6RSKbRaLZs2bcLj8UiXUTgcxufzMTQ0hF6vl1dvu92O2WyWLo6Fdp2eDbHudDodRqORzZs3U1tby759++js7CSZTEo3w1tvvcV9993H0NDQtL2sVqspFAoyaBMMBonH4yQSiXnby/OqJMsTSIWS3LRpE1u2bJGLKB6PS8Hu2bOHN998k2KxKCPYBoMBnU5HLpcjlUpRKBRkAGcpKkh45+CZyScr5J1MJunv7yebzXLNNdfQ0tLCww8/zF/91V8Rj8dlMrlOp0Ov15PJZAiHw3IBL3bOTPg2GAy0t7fT2NjIhz/8YW644QYZYCl/bTKZJBKJEA6HGR0dZXBwUL6fiGRrNBpSqdS0LIyFjpCBCDyZTCa2bt0qn/WnP/0pyWRSrsmZ9rLRaESv15PNZqXBUwl7eV6VpHjwSCTCoUOHCAQC6HQ6tFot3d3dZDIZmTslnOTCehG/WygU5OkDp6tz2trasFgs8sqzFBFRV61We1b5m/hTWEpjY2Mkk0lZQy
sUoXCWi2v1Ujx0HA4H1dXV2O12qqursVqt8pAxGAyyWkT4DwOBAMPDwwwODhKPx+X7CGWQSqVkesti4lx7WaPR0NP
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 400x400 with 16 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
2024-07-26 16:07:40 +03:00
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2024-07-28 01:10:19 +03:00
|
|
|
"batch_size = 16\n",
"display_step = 5\n",
"cur_step = 0\n",
"mean_generator_loss = 0\n",
"mean_discriminator_loss = 0\n",
"\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"for epoch in tqdm(range(200)):\n",
"\n",
"    # NOTE(review): only the first 50 training images are used; this looks like a debug subset - widen it for a full run\n",
"    for real in batch_iterate(batch_size, train_images[:50]):\n",
"        # The last batch of an epoch can be smaller than batch_size, so pass the real batch size to the losses\n",
"        # (assumes disc_loss/gen_loss use this argument as the number of generated samples - confirm)\n",
"        cur_batch_size = len(real)\n",
"\n",
"        # Discriminator step: loss + gradients, apply them, then force MLX's lazy computation to run\n",
"        D_loss, D_grads = D_loss_grad(gen, disc, real, cur_batch_size, z_dim)\n",
"        disc_opt.update(disc, D_grads)\n",
"        mx.eval(disc.parameters(), disc_opt.state)\n",
"\n",
"        # Generator step: same pattern against the just-updated discriminator\n",
"        G_loss, G_grads = G_loss_grad(gen, disc, cur_batch_size, z_dim)\n",
"        gen_opt.update(gen, G_grads)\n",
"        mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
"        # Running averages of both losses over the current display window\n",
"        mean_discriminator_loss += D_loss.item() / display_step\n",
"        mean_generator_loss += G_loss.item() / display_step\n",
"\n",
"        if cur_step % display_step == 0 and cur_step > 0:\n",
"            print(f\"Epoch {epoch}, step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}\")\n",
"            # Preview a full batch_size grid of samples regardless of the current batch size\n",
"            fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
"            fake = gen(fake_noise)\n",
"            show_images(fake)\n",
"            # Start a fresh averaging window\n",
"            mean_generator_loss = 0\n",
"            mean_discriminator_loss = 0\n",
"        cur_step += 1\n"
|
2024-07-27 00:19:08 +03:00
|
|
|
]
|
2024-07-26 16:07:40 +03:00
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"metadata": {
|
|
|
|
|
"kernelspec": {
|
|
|
|
|
"display_name": "base",
|
|
|
|
|
"language": "python",
|
|
|
|
|
"name": "python3"
|
|
|
|
|
},
|
|
|
|
|
"language_info": {
|
|
|
|
|
"codemirror_mode": {
|
|
|
|
|
"name": "ipython",
|
|
|
|
|
"version": 3
|
|
|
|
|
},
|
|
|
|
|
"file_extension": ".py",
|
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
|
"name": "python",
|
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
|
"version": "3.10.10"
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"nbformat": 4,
|
|
|
|
|
"nbformat_minor": 2
|
|
|
|
|
}
|