mlx-examples/gan/playground.ipynb

567 lines
112 KiB
Plaintext
Raw Normal View History

2024-07-26 21:07:40 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import Libraries"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 46,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mnist"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 47,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GAN Architecture"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator 👨🏻‍🎨"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 48,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def GenBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Build one generator stage: Linear -> BatchNorm -> ReLU.\n",
"\n",
"    Args:\n",
"        in_dim: number of input features of the linear layer.\n",
"        out_dim: number of output features; also the BatchNorm width.\n",
"\n",
"    Returns:\n",
"        An nn.Sequential holding the three layers in that order.\n",
"    \"\"\"\n",
"    stage = [\n",
"        nn.Linear(in_dim, out_dim),\n",
"        nn.BatchNorm(out_dim),\n",
"        nn.ReLU(),\n",
"    ]\n",
"    return nn.Sequential(*stage)"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 49,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"    \"\"\"Fully connected generator mapping a noise vector to a flat image.\n",
"\n",
"    Four GenBlock stages widen the features z_dim -> h -> 2h -> 4h -> 8h\n",
"    (h = hidden_dim); a final Linear + Sigmoid projects to im_dim values,\n",
"    each squashed into (0, 1).\n",
"\n",
"    Args:\n",
"        z_dim: length of the input noise vector.\n",
"        im_dim: number of output values per sample.\n",
"        hidden_dim: width of the first hidden stage.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, z_dim: int = 10, im_dim: int = 784, hidden_dim: int = 128):\n",
"        super().__init__()\n",
"        # Feature width entering / leaving each GenBlock.\n",
"        widths = [z_dim] + [hidden_dim * scale for scale in (1, 2, 4, 8)]\n",
"        stages = [GenBlock(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]\n",
"        self.gen = nn.Sequential(\n",
"            *stages,\n",
"            nn.Linear(widths[-1], im_dim),\n",
"            nn.Sigmoid(),\n",
"        )\n",
"\n",
"    def __call__(self, noise):\n",
"        \"\"\"Run a batch of noise vectors through the network and return the result.\"\"\"\n",
"        return self.gen(noise)"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 50,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Generator(\n",
" (gen): Sequential(\n",
" (layers.0): Sequential(\n",
" (layers.0): Linear(input_dims=100, output_dims=128, bias=True)\n",
" (layers.1): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.1): Sequential(\n",
" (layers.0): Linear(input_dims=128, output_dims=256, bias=True)\n",
" (layers.1): BatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.2): Sequential(\n",
" (layers.0): Linear(input_dims=256, output_dims=512, bias=True)\n",
" (layers.1): BatchNorm(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.3): Sequential(\n",
" (layers.0): Linear(input_dims=512, output_dims=1024, bias=True)\n",
" (layers.1): BatchNorm(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers.2): ReLU()\n",
" )\n",
" (layers.4): Linear(input_dims=1024, output_dims=784, bias=True)\n",
" (layers.5): Sigmoid()\n",
" )\n",
")"
]
},
2024-07-29 06:24:50 +08:00
"execution_count": 50,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gen = Generator(100)\n",
"gen"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 51,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def get_noise(n_samples, z_dim):\n",
"    \"\"\"Draw an (n_samples, z_dim) NumPy array of standard-normal noise.\"\"\"\n",
"    noise_shape = (n_samples, z_dim)\n",
"    return np.random.standard_normal(size=noise_shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator 🕵🏻‍♂️"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 52,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"def DisBlock(in_dim: int, out_dim: int):\n",
"    \"\"\"Build one discriminator stage: Linear -> LeakyReLU(negative_slope=0.2).\n",
"\n",
"    Args:\n",
"        in_dim: number of input features of the linear layer.\n",
"        out_dim: number of output features of the linear layer.\n",
"\n",
"    Returns:\n",
"        An nn.Sequential holding the two layers in that order.\n",
"    \"\"\"\n",
"    projection = nn.Linear(in_dim, out_dim)\n",
"    activation = nn.LeakyReLU(negative_slope=0.2)\n",
"    return nn.Sequential(projection, activation)"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 53,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"class Discriminator(nn.Module):\n",
"    \"\"\"Fully connected discriminator producing one raw score per sample.\n",
"\n",
"    Three DisBlock stages narrow the features im_dim -> 4h -> 2h -> h\n",
"    (h = hidden_dim); the last Linear emits a single un-squashed value,\n",
"    so no sigmoid is applied inside the network.\n",
"\n",
"    Args:\n",
"        im_dim: number of input values per sample.\n",
"        hidden_dim: width of the last hidden stage.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, im_dim: int = 784, hidden_dim: int = 128):\n",
"        super().__init__()\n",
"        # Feature width entering / leaving each DisBlock.\n",
"        widths = [im_dim] + [hidden_dim * scale for scale in (4, 2, 1)]\n",
"        stages = [DisBlock(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]\n",
"        self.disc = nn.Sequential(*stages, nn.Linear(widths[-1], 1))\n",
"\n",
"    def __call__(self, noise):\n",
"        \"\"\"Score a batch; despite its name, `noise` is the batch of samples to judge.\"\"\"\n",
"        return self.disc(noise)"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 54,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discriminator(\n",
" (disc): Sequential(\n",
" (layers.0): Sequential(\n",
" (layers.0): Linear(input_dims=784, output_dims=512, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.1): Sequential(\n",
" (layers.0): Linear(input_dims=512, output_dims=256, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.2): Sequential(\n",
" (layers.0): Linear(input_dims=256, output_dims=128, bias=True)\n",
" (layers.1): LeakyReLU()\n",
" )\n",
" (layers.3): Linear(input_dims=128, output_dims=1, bias=True)\n",
" )\n",
")"
]
},
2024-07-29 06:24:50 +08:00
"execution_count": 54,
2024-07-26 21:07:40 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"disc = Discriminator()\n",
"disc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Training 🏋🏻‍♂️"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 55,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"# Set your parameters\n",
"# NOTE(review): `criterion` is not referenced by disc_loss / gen_loss in this\n",
"# notebook; both call nn.losses.binary_cross_entropy directly.\n",
"criterion = nn.losses.binary_cross_entropy\n",
"# NOTE(review): the training cell further down does not read n_epochs.\n",
"n_epochs = 200\n",
"# Length of each noise vector; passed to Generator(...) and get_noise(...).\n",
"z_dim = 64\n",
"# display_step and batch_size are both assigned again in the training cell.\n",
"display_step = 500\n",
"batch_size = 128\n",
"# Learning rate handed to both Adam optimizers.\n",
"lr = 0.00001"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 56,
2024-07-26 21:07:40 +08:00
"metadata": {},
2024-07-27 06:09:51 +08:00
"outputs": [],
2024-07-26 21:07:40 +08:00
"source": [
"# Generator and its optimizer. MLX evaluates lazily, so mx.eval is called to\n",
"# materialise the freshly initialised parameters right away.\n",
"gen = Generator(z_dim)\n",
2024-07-26 21:36:29 +08:00
"mx.eval(gen.parameters())\n",
2024-07-26 21:07:40 +08:00
"gen_opt = optim.Adam(learning_rate=lr)\n",
2024-07-26 21:36:29 +08:00
"\n",
2024-07-26 21:07:40 +08:00
"# Discriminator and its optimizer, set up the same way (default sizes).\n",
"disc = Discriminator()\n",
2024-07-26 21:36:29 +08:00
"mx.eval(disc.parameters())\n",
2024-07-26 21:07:40 +08:00
"disc_opt = optim.Adam(learning_rate=lr)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Losses"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 57,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def disc_loss(gen, disc, real, num_images, z_dim):\n",
"    \"\"\"Discriminator loss: average of the BCE on generated and on real samples.\n",
"\n",
"    Args:\n",
"        gen: generator, called on a fresh noise batch.\n",
"        disc: discriminator; its outputs are treated as logits.\n",
"        real: batch of real samples; must expose .shape and be accepted by disc.\n",
"        num_images: number of fake samples to generate.\n",
"        z_dim: length of each noise vector.\n",
"\n",
"    Returns:\n",
"        (fake_loss + real_loss) / 2, with fakes labelled 0 and reals labelled 1.\n",
"    \"\"\"\n",
"\n",
"    def bce_on_logits(logits, targets):\n",
"        return nn.losses.binary_cross_entropy(logits, targets, with_logits=True)\n",
"\n",
"    # Fake half: the discriminator should answer 0 for generated samples.\n",
"    generated = gen(mx.array(get_noise(num_images, z_dim)))\n",
"    loss_on_fake = bce_on_logits(disc(generated), mx.zeros((generated.shape[0], 1)))\n",
"\n",
"    # Real half: the discriminator should answer 1 for dataset samples.\n",
"    loss_on_real = bce_on_logits(disc(real), mx.ones((real.shape[0], 1)))\n",
"\n",
"    return (loss_on_fake + loss_on_real) / 2"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 58,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-26 21:36:29 +08:00
"def gen_loss(gen, disc, num_images, z_dim):\n",
"    \"\"\"Generator loss: BCE of the discriminator's logits on fakes against label 1.\n",
"\n",
"    Args:\n",
"        gen: generator, called on a fresh noise batch.\n",
"        disc: discriminator; its outputs are treated as logits.\n",
"        num_images: number of fake samples to generate.\n",
"        z_dim: length of each noise vector.\n",
"\n",
"    Returns:\n",
"        The binary cross-entropy between disc(gen(noise)) and an all-ones target.\n",
"    \"\"\"\n",
"    latent = mx.array(get_noise(num_images, z_dim))\n",
"    generated = gen(latent)\n",
"    logits = disc(generated)\n",
"    # Target 1: the loss is low when the discriminator rates the fakes as real.\n",
"    wanted = mx.ones((generated.shape[0], 1))\n",
"    return nn.losses.binary_cross_entropy(logits, wanted, with_logits=True)"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 59,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
"# Load MNIST as mx.arrays; the 2nd and 4th returned arrays are discarded here.\n",
"train_images, _, test_images, _ = (mx.array(split) for split in mnist.mnist())"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 133,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"60000"
]
},
"execution_count": 133,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check: length of train_images along its first axis.\n",
"len(train_images)"
]
},
{
"cell_type": "code",
"execution_count": 60,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-29 06:24:50 +08:00
"def batch_iterate(batch_size: int, ipt):\n",
"    \"\"\"Yield shuffled mini-batches that together cover `ipt` exactly once.\n",
"\n",
"    Args:\n",
"        batch_size: maximum number of samples per batch; the final batch is\n",
"            smaller when len(ipt) is not a multiple of batch_size.\n",
"        ipt: the samples, indexed along their first axis. An mx.array in this\n",
"            notebook; plain lists and tuples are supported as well.\n",
"\n",
"    Yields:\n",
"        For a list or tuple: a list of the selected items.\n",
"        Otherwise: `ipt` indexed by an mx.array of row ids, i.e. a single\n",
"        array per batch, so callers can read .shape and pass it to a model.\n",
"    \"\"\"\n",
"    perm = np.random.permutation(len(ipt))\n",
"    for start in range(0, len(ipt), batch_size):\n",
"        ids = perm[start : start + batch_size]\n",
"        if isinstance(ipt, (list, tuple)):\n",
"            yield [ipt[i] for i in ids]\n",
"        else:\n",
"            # One gather instead of len(ids) separate single-row lookups.\n",
"            yield ipt[mx.array(ids)]"
2024-07-26 21:07:40 +08:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Show a batch of images"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 128,
2024-07-26 21:07:40 +08:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-29 06:24:50 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAGVCAYAAAAyrrwGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAD+KElEQVR4nOy913Ok15nf/+mcc0Q3choMJg/JYRaDREkridrV2lrf2S6Xy75z+V+w/wKX7Zu9sW+2vLVBtlcbZFKiKM2QHFLD4SRMQM7onHP+XczvHDZmMOQEDNAA3k8VigENoN/T532f86Tvo+p0Oh0UFBQUFBR2EfV+vwEFBQUFhcOHYlwUFBQUFHYdxbgoKCgoKOw6inFRUFBQUNh1FOOioKCgoLDrKMZFQUFBQWHXUYyLgoKCgsKuoxgXBQUFBYVdRzEuCgoKCgq7jvZxX6hSqZ7n+zhQPI2ogbJ+X6Os37PxtKIayhp+jbIHn43HWT/Fc1FQUFBQ2HUU46KgoKCgsOsoxkVBQUFBYddRjIuCgoKCwq7z2Al9BQUFBYVnR61WY7FY0Ol08msnWq0WjUaDdrtNs9mk3W4D95PpnU6HRqNBp9Oh1Wo9dZHH80QxLgoKCgp7iM1m4/vf/z4jIyMMDAwwPDy84+uSySSrq6uUSiXW1tYoFosAtNttyuUym5ubVKtV0uk05XJ5D6/g8VCMi4LCLvCoMtVePFEq7C96vZ6xsTFOnz7N9PQ0p06d2nH/bGxscOvWLXK5HAaDgUwmA9w3LtlsllKpRLFYpFgsUqlUgN7ab4pxUVB4QjQaDQaDAZ1ORzAYxOl0EgwGGR0dRaPRAPdDGtevX2dmZoZarUaxWJRhDQWFbjqdzjbj0m636XQ66PV6fD4fdrsdm81GvV6XrykWi5w8eVJ6LqVSiUgkwuLiIpVKhWQySa1W24/LkSjGRUHhCdFoNFitVsxmM6dOnWJ4eJgXX3yRP/qjP8JgMMh4+J//+Z+TSCQoFApUKhXFuChsQ+ROHvx/rVaLdruNXq8nEAigVqtxOp3o9Xr5ulqtRjabpdFoUC6XqdfrXLlyhX/8x38klUpRLBYV46Lw9Hi9XrxeL1qtFqPRiFqtplQqUavVKBQKJBKJx36gWSwWHA4Her1ebuRarUatVqNUKhGPx2UCsZdc7/3AZDIxODiIw+FgdHSUoaEh/H4/RqMRvV5Pp9NBo9EQDAaZmpoiGo2SzWapVCrbErMKR5NWq0UmkyEWi2Gz2TCZTJhMJlwuF1qtlmazSavVIp1Os7a2hkajIRQKYTabsVgsWCwWtFotZrOZZrMpfyYYDDI+Po7dbicajdJqtajVajQajX25TsW4HFBUKhWvvvoqf/Inf4LD4WBoaAiNRsO9e/fY2triyy+/5O/+7u9kLPbbGBoa4o033iAQCPCd73wHv9/P2toaW1tb3Lt3j7/9278lnU5Tq9VoNpvP+ep6m76+Pv7ZP/tnhMNhpqen5Y2v1Wql4VWr1bzyyiv09fVx8+ZNstksiUSCXC5HtVrd5ytQ2E8qlQo3btxga2uLmzdv4vP5CIVCfPe738XpdNJoNGg2m1y9epVf/vKXaDQazp07h9/v5/jx4xw7dkx6zyqVSh74TCYTQ0NDbGxsUK1WmZ+fJxqNEo/H9+U6FeNyQFGpVDidTkZGRvB4PIyPj6PVamk0Gmg0GpaXl1GrH7+NyWw2EwqFCIfDTE1NyQemyWQin89jMBjQaDRHWl9JpVKhUqkwm82Ew2GGhoYIhUIEAgH5GmFcxOejUqlIJBJYrVby+bzMySgcXURCvtPpUK1WKRaLNJtNUqmUDKk2m01isRgrKyuo1Wp8Ph/NZlMaIoPBgNVqRa1Wo1ar5X7T6XS0222cTidWq3VbKG2vUYzLAWe3wlQ+n4+zZ8/i9/uxWq3A/bCbwWAgnU7jdDopl8s0Go19c7P3G4fDgdPpZGhoiPHxcQYGBrDZbI98vQhf+P1+wu
EwAOVymVKptFdvWaEHEYYjk8mg1+vR6/UsLCywsrKC0WiUOZdoNMry8jIqlYpSqYTJZOLLL78kFAoxODjIe++9h8vlwuVyyb4Zi8WC1WrF6XTicDgwGo37dp2KcTkkiAoT8fWkcX273c7Y2Bhut1tuSLvdjt1uJxAIYLVaMRqNR/bkrVKpsFgs+P1+gsEg4XCYUCj0jT9jMBgwGAy4XC68Xi/VapX19fU9escKvYrIuTzI9evXH/kz0WgUAKvVisVi4ezZsxw7doxOpyPzMBqNBo1Gg8lkwmKxYLfb0Wr37xG/r8ZFnAQFwk3sPhmrVCp0Oh1qtRqbzYbNZkOlUqHRaGg2m6ysrJBKpfbh3e8PorvXYDDgcDiwWq2YTKYj+9DfC7RaLRqNhnA4zOnTpxkfH8dgMDzR7xAhtaOAVqtFrVZjtVpxu92y4ESv1+NyuXA6nbRaLarVKq1Wi3q9LhPY0WiUer0uQ0UK22k2mzKUlslk5IGnF9k346JSqejr6+PEiRPypmu1WrJ0s/t1VqsVg8HA+Pg44+Pjss+gXC7zV3/1V0fKuGi1WgKBAA6Hg1AohM/nw2w2o9FojnwV1/NArVZjNBoxGAycPn2aP/7jP8bn82GxWPb7rfUkKpUKvV6PwWBgcHCQc+fOYbVaCQaDWK1WTp06xYkTJ6hWqyQSCVlSWy6XuXnzJr/97W/JZrOsrKwoxmUHqtUqtVqNVCrF+vo6nU6HgYGB/X5bO7LnxkWtVmMymdDr9fj9fvr7+2XiudVqYTQapcwB3O8psFgs6PV6+vv7CQaDaLVadDodpVIJp9MpS/K6m4wOKyqVCoPBgMlkwmg0Sm0iUTWisLvodDoCgQA2m41gMIjH48Futz+RpyhCFWazWXrhh6WkW+w/tVotiz7sdjtGo5GBgQH6+/uxWCwEAgH5T5/PR61WQ6VSUavVMBgMsvFvYGAAi8VCLpeT3z+qOb5HIcLeIvHfq6Xte25cTCYTFy5cIBQK8corr/DGG29sO3XX63V5YhGdq8LNNpvNmM1mubjFYpGZmRkqlQqxWIzl5WVardZeX9KeotFocLlc8kEnKkTUanXPbrKDiEqlQq1W4/f7+bf/9t9y8uRJ+bD8JrHBnbDb7Zw5cwa/38/GxgapVIpGo3EoSpIDgQD9/f04nU4mJiawWq0MDg7i8XjkPhWejDA8cN8oeTwe2u02Ho+HVqtFX18fZ86cIRaL8ctf/pK1tTUWFxdZWVnZ34tUeCr23LgIyYyxsTGOHTvGyZMn5Umumwf/u/tkLlRCzWYzfr8fv99PqVQ6EjFt4fmJDnFx03ZzGE7EvYAoOz558iSvvfYaRqPxqapvDAYDHo+HZrMpq3oOw0FA5FWCwSA+n4/p6WmcTifj4+MEg0HMZjMOh0Pel933p9jH3XvV6XTS399PJBJhZmaGVqtFLBZTvPIDyr4Yl4GBASYmJvB6vahUKlqtFoVCgVqtxvr6+o5NP2Izigfr0NCQLL1zOByYzeYjYVwMBgNTU1NMT08zMDCw7ZpbrRaLi4tcuXKFxcXFZ45ZN5tNms2mbJwUFWlHhU6nQ7lcZmZmBoCJiQnGx8efeJ81Gg3y+Ty5XI5KpUKr1TrQHrYIUdvtdl555RVeffVVbDYboVAIo9GIx+ORPRZPslaiUMdqtUpPr1AosLGxQb1ep1KpHKn9d9DZN+Ny7NgxfD4farWaer1ONpsln8/zxRdfcPfu3R1/1uv14vF4ZHhCr9djtVqPnHE5ceIEr7zyCn6/f1ujZLPZZGFhgU8++YR0Ov3MxqVer0sJGCFJcVRubhF6LZVK3Lhxg2w2i8FgYGxs7KmMSzabJZPJUK1WezpO/jjo9XomJibo7+/nu9/9Lj/5yU9kyetOXsrjIirqbDYb586dY3JykqWlJa5duyZljQ6yUT5q7LlxEV5KOp2WN2+xWGR1dZV8Ps/i4iKbm5s7/qwov1Or1TQaDbRaLXq9XiZKjwKdTodarSZLtr
sVVUWljtlsplAoPNENLtRV4f68Cb1eT6vVolKpHEnjAttP0g6HA4PB8NCaihyhKDIRPwdfN7iKZHUikZCey0E0Ljq
2024-07-26 21:07:40 +08:00
"text/plain": [
2024-07-29 06:24:50 +08:00
"<Figure size 500x500 with 25 Axes>"
2024-07-26 21:07:40 +08:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Preview one shuffled batch of training images.\n",
"# Request exactly as many images as the grid has panels; with a smaller\n",
"# batch, X[i] would run past the end of the batch.\n",
"grid_side = 5\n",
"X = next(batch_iterate(grid_side * grid_side, train_images))\n",
"\n",
"fig, axes = plt.subplots(grid_side, grid_side, figsize=(5, 5))\n",
"for i, ax in enumerate(axes.flat):\n",
"    img = mx.array(X[i]).reshape(28, 28)\n",
"    ax.imshow(img, cmap='gray')\n",
"    ax.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 147,
2024-07-28 06:10:19 +08:00
"metadata": {},
"outputs": [],
"source": [
2024-07-28 22:22:40 +08:00
"def show_images(imgs, num_imgs: int = 25):\n",
"    \"\"\"Display up to `num_imgs` images from `imgs` in a square grid.\n",
"\n",
"    Args:\n",
"        imgs: batch of images (an mx.array or a list of arrays); each item is\n",
"            reshaped to 28 x 28, so it must hold 784 values.\n",
"        num_imgs: maximum number of images to draw; it also fixes the grid\n",
"            size (the default 25 gives a 5 x 5 grid).\n",
"\n",
"    Nothing is drawn when `imgs` is empty or `num_imgs` is below 1.\n",
"    \"\"\"\n",
"    n_shown = min(num_imgs, len(imgs))\n",
"    if n_shown < 1:\n",
"        return\n",
"\n",
"    side = int(np.ceil(np.sqrt(num_imgs)))\n",
"    # squeeze=False keeps `axes` a 2-D array even for a 1 x 1 grid.\n",
"    fig, axes = plt.subplots(side, side, figsize=(side, side), squeeze=False)\n",
"\n",
"    for i, ax in enumerate(axes.flat):\n",
"        if i < n_shown:\n",
"            # Index with i so that every panel shows a different image.\n",
"            img = mx.array(imgs[i]).reshape(28, 28)\n",
"            ax.imshow(img, cmap='gray')\n",
"        # Panels beyond the available images are left blank.\n",
"        ax.axis('off')\n",
"    plt.show()"
2024-07-28 06:10:19 +08:00
]
},
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 148,
2024-07-27 06:09:51 +08:00
"metadata": {},
2024-07-29 06:24:50 +08:00
"outputs": [],
2024-07-27 06:09:51 +08:00
"source": [
"# Re-create both networks and both optimizers so that training starts from\n",
"# freshly initialised weights. `lr` comes from the hyper-parameter cell.\n",
"z_dim = 64\n",
"gen = Generator(z_dim)\n",
"mx.eval(gen.parameters())\n",
"gen_opt = optim.Adam(learning_rate=lr)\n",
"\n",
"disc = Discriminator()\n",
"mx.eval(disc.parameters())\n",
2024-07-29 06:24:50 +08:00
"disc_opt = optim.Adam(learning_rate=lr)\n"
2024-07-27 06:09:51 +08:00
]
},
2024-07-29 06:30:08 +08:00
{
"cell_type": "code",
"execution_count": 162,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"array(0, dtype=float32)\n"
]
}
],
"source": [
"# Reload MNIST, this time keeping the label arrays as well.\n",
"train_images, train_labels, test_images, test_labels = (\n",
"    mx.array(split) for split in mnist.mnist()\n",
")\n",
"# Smallest pixel value of the first training image.\n",
"print(train_images[0].min())\n",
"# NOTE(review): rescaling is left disabled; the Generator ends in a Sigmoid,\n",
"# whose outputs lie in (0, 1) rather than (-1, 1).\n",
"# train_images = train_images * 2.0 - 1.0 # normalize the image"
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x33c4ed2d0>"
]
},
"execution_count": 156,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaI0lEQVR4nO3df2jU9x3H8dfVH1d1lytBk7vUmGVF202dpWrVYP3R1cxApf4oWMtGZEPa+YOJ/cGsDNNBjdgpRdI6V0amW239Y9a6KdUMTXRkijpdRYtYjDOdCcFM72LUSMxnf4hHz1j1e975vkueD/iCufu+vY/ffuvTby75xueccwIAwMBD1gsAAHRfRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjpab2AW3V0dOjcuXMKBALy+XzWywEAeOScU0tLi/Ly8vTQQ3e+1km7CJ07d075+fnWywAA3Kf6+noNHDjwjvuk3afjAoGA9RIAAElwL3+fpyxCH3zwgQoLC/Xwww9r5MiR2rdv3z3N8Sk4AOga7uXv85REaPPmzVq8eLGWLVumI0eO6JlnnlFJSYnOnj2bipcDAGQoXyruoj1mzBg99dRTWrduXeyx73//+5o+fbrKy8vvOBuNRhUMBpO9JADAAxaJRJSVlXXHfZJ+JXTt2jUdPnxYxcXFcY8XFxertra20/5tbW2KRqNxGwCge0h6hM6fP6/r168rNzc37vHc3Fw1NjZ22r+8vFzBYDC28ZVxANB9pOwLE259Q8o5d9s3qZYuXapIJBLb6uvrU7UkAECaSfr3CfXv3189evTodNXT1NTU6epIkvx+v/x+f7KXAQDIAEm/Eurdu7dGjhypqqqquMerqqpUVFSU7JcDAGSwlNwxYcmSJfrpT3+qUaNGady4cfr973+vs2fP6tVXX03FywEAMlRKIjR79mw1NzfrN7/5jRoaGjRs2DDt2LFDBQUFqXg5AECGSsn3Cd0Pvk8IALoGk+8TAgDgXhEhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmelovAEgnPXr08DwTDAZTsJLkWLhwYUJzffv29Tzz+OOPe55ZsGCB55nf/va3nmfmzJnjeUaSrl696nlm5cqVnmfefvttzzNdBVdCAAAzRAgAYCbpESorK5PP54vbQqFQsl8GANAFpOQ9oaFDh+rvf/977ONEPs8OAOj6UhKhnj17cvUDALirlLwndOrUKeXl5amwsFAvvfSSTp8+/a37trW1KRqNxm0AgO4h6REaM2aMNm7cqJ07d+rDDz9UY2OjioqK1NzcfNv9y8vLFQwGY1t+fn6ylwQASFNJj1BJSYlmzZql4cOH67nnntP27dslSRs2bLjt/kuXLlUkEolt9fX1yV4SACBNpfybVfv166fhw4fr1KlTt33e7/fL7/enehkAgDSU8u8Tamtr05dffqlwOJzqlwIAZJikR+j1119XTU2N6urqdODAAb344ouKRqMqLS1N9ksBADJc0j8d9/XXX2vOnDk6f/68BgwYoLFjx2r//v0qKChI9ksBADJc0iP0ySefJPu3RJoaNGiQ55nevXt7nikqKvI8M378eM8zkvTII494npk1a1ZCr9XVfP31155n1q5d63lmxowZnmdaWlo8z0jSv//9b88zNTU1Cb1Wd8W94wAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQ
DMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAMz7nnLNexDdFo1EFg0HrZXQrTz75ZEJzu3fv9jzDf9vM0NHR4XnmZz/7meeZS5cueZ5JRENDQ0JzFy5c8Dxz8uTJhF6rK4pEIsrKyrrjPlwJAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwExP6wXA3tmzZxOaa25u9jzDXbRvOHDggOeZixcvep6ZPHmy5xlJunbtmueZP/3pTwm9Fro3roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT63//+l9DcG2+84Xnm+eef9zxz5MgRzzNr1671PJOoo0ePep6ZMmWK55nW1lbPM0OHDvU8I0m//OUvE5oDvOJKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAw43POOetFfFM0GlUwGLReBlIkKyvL80xLS4vnmfXr13uekaSf//znnmd+8pOfeJ75+OOPPc8AmSYSidz1/3muhAAAZogQAMCM5wjt3btX06ZNU15ennw+n7Zu3Rr3vHNOZWVlysvLU58+fTRp0iQdP348WesFAHQhniPU2tqqESNGqKKi4rbPr1q1SmvWrFFFRYUOHjyoUCikKVOmJPR5fQBA1+b5J6uWlJSopKTkts855/Tee+9p2bJlmjlzpiRpw4YNys3N1aZNm/TKK6/c32oBAF1KUt8TqqurU2Njo4qLi2OP+f1+TZw4UbW1tbedaWtrUzQajdsAAN1DUiPU2NgoScrNzY17PDc3N/bcrcrLyxUMBmNbfn5+MpcEAEhjKfnqOJ/PF/exc67TYzctXbpUkUgkttXX16diSQCANOT5PaE7CYVCkm5cEYXD4djjTU1Nna6ObvL7/fL7/clcBgAgQyT1SqiwsFChUEhVVVWxx65du6aamhoVFRUl86UAAF2A5yuhS5cu6auvvop9XFdXp6NHjyo7O1uDBg3S4sWLtWLFCg0ePFiDBw/WihUr1LdvX7388stJXTgAIPN5jtChQ4c0efLk2MdLliyRJJWWluqPf/yj3nzzTV25ckXz58/XhQsXNGbMGO3atUuBQCB5qwYAdAncwBRd0rvvvpvQ3M1/VHlRU1Pjeea5557zPNPR0eF5BrDEDUwBAGmNCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZriLNrqkfv36JTT317/+1fPMxIkTPc+UlJR4ntm1a5fnGcASd9EGAKQ1IgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzAFvuGxxx7zPPOvf/3L88zFixc9z+zZs8fzzKFDhzzPSNL777/veSbN/ipBGuAGpgCAtEaEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmOEGpsB9mjFjhueZyspKzzOBQMDzTKLeeustzzMbN270PNPQ0OB5BpmDG5gCANIaEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5gCBoYNG+Z5Zs2aNZ5nfvSjH3meSdT69es9z7zzzjueZ/773/96noENbmAKAEhrRAgAYMZzhPbu3atp06YpLy9PPp9PW7dujXt+7ty58vl8cdvYsWOTtV4AQBfiOUKtra0aMWKEKioqvnWfqVOnqqGhIbbt2LHjvhYJAOiaenodKCkpUUlJyR338fv9CoVCCS8KANA9pOQ9oerqauXk5GjIkCGaN2+empqavnXftrY2RaPRuA0A0D0kPUIlJSX66KOPtHv3bq1evVoHDx7Us88+q7
a2ttvuX15ermAwGNvy8/OTvSQAQJry/Om4u5k9e3bs18OGDdOoUaNUUFCg7du3a+bMmZ32X7p0qZYsWRL7OBqNEiI
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Show the first training image as a 28 x 28 grayscale picture.\n",
"plt.imshow(train_images[0].reshape(28,28),cmap='gray')"
]
},
2024-07-27 06:09:51 +08:00
{
"cell_type": "code",
2024-07-29 06:24:50 +08:00
"execution_count": 151,
2024-07-27 06:09:51 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-29 06:24:50 +08:00
" 7%|▋ | 4/60 [08:34<2:00:03, 128.64s/it]\n"
2024-07-29 00:18:35 +08:00
]
},
{
2024-07-29 06:24:50 +08:00
"ename": "KeyboardInterrupt",
"evalue": "",
2024-07-29 00:18:35 +08:00
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
2024-07-29 06:24:50 +08:00
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[151], line 24\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;66;03m# Update gradients\u001b[39;00m\n\u001b[1;32m 22\u001b[0m mx\u001b[38;5;241m.\u001b[39meval(disc\u001b[38;5;241m.\u001b[39mparameters(), disc_opt\u001b[38;5;241m.\u001b[39mstate)\n\u001b[0;32m---> 24\u001b[0m G_loss,G_grads \u001b[38;5;241m=\u001b[39m \u001b[43mG_loss_grad\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgen\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdisc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcur_batch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mz_dim\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;66;03m# Update optimizer\u001b[39;00m\n\u001b[1;32m 27\u001b[0m gen_opt\u001b[38;5;241m.\u001b[39mupdate(gen, G_grads)\n",
"File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/utils.py:34\u001b[0m, in \u001b[0;36mvalue_and_grad.<locals>.wrapped_value_grad_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(fn)\n\u001b[1;32m 33\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapped_value_grad_fn\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 34\u001b[0m value, grad \u001b[38;5;241m=\u001b[39m value_grad_fn(\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainable_parameters\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 35\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m value, grad\n",
"File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/layers/base.py:310\u001b[0m, in \u001b[0;36mModule.trainable_parameters\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtrainable_parameters\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 308\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Recursively return all the non frozen :class:`mlx.core.array` members of\u001b[39;00m\n\u001b[1;32m 309\u001b[0m \u001b[38;5;124;03m this Module as a dict of dicts and lists.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 310\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfilter_and_map\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainable_parameter_filter\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/lib/python3.10/site-packages/mlx/nn/layers/base.py:298\u001b[0m, in \u001b[0;36mModule.filter_and_map\u001b[0;34m(self, filter_fn, map_fn, is_leaf_fn)\u001b[0m\n\u001b[1;32m 292\u001b[0m map_fn \u001b[38;5;241m=\u001b[39m map_fn \u001b[38;5;129;01mor\u001b[39;00m (\u001b[38;5;28;01mlambda\u001b[39;00m x: x)\n\u001b[1;32m 293\u001b[0m is_leaf_fn \u001b[38;5;241m=\u001b[39m is_leaf_fn \u001b[38;5;129;01mor\u001b[39;00m (\n\u001b[1;32m 294\u001b[0m \u001b[38;5;28;01mlambda\u001b[39;00m m, k, v: \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(v, (Module, \u001b[38;5;28mdict\u001b[39m, \u001b[38;5;28mlist\u001b[39m))\n\u001b[1;32m 295\u001b[0m )\n\u001b[1;32m 296\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[1;32m 297\u001b[0m k: _unwrap(\u001b[38;5;28mself\u001b[39m, k, v, filter_fn, map_fn, is_leaf_fn)\n\u001b[0;32m--> 298\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitems\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 299\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m filter_fn(\u001b[38;5;28mself\u001b[39m, k, v)\n\u001b[1;32m 300\u001b[0m }\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
2024-07-29 00:18:35 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"source": [
2024-07-28 22:35:36 +08:00
"# Train the GAN on a small subset of the MNIST training images.\n",
"batch_size = 128\n",
"display_step = 15  # report every `display_step` optimisation steps\n",
"num_epochs = 60\n",
"num_train_images = 2048\n",
"cur_step = 0\n",
"\n",
"# Each wrapper returns (loss, gradients w.r.t. that model's trainable parameters).\n",
"D_loss_grad = nn.value_and_grad(disc, disc_loss)\n",
"G_loss_grad = nn.value_and_grad(gen, gen_loss)\n",
"\n",
"for epoch in tqdm(range(num_epochs)):\n",
"\n",
"    for real in batch_iterate(batch_size, train_images[:num_train_images]):\n",
"        # batch_iterate may hand back a list of rows; the losses need one array.\n",
"        if isinstance(real, list):\n",
"            real = mx.stack(real)\n",
"        # The last batch of an epoch can be smaller than batch_size.\n",
"        cur_batch_size = real.shape[0]\n",
"\n",
"        # Discriminator step.\n",
"        D_loss, D_grads = D_loss_grad(gen, disc, real, cur_batch_size, z_dim)\n",
"        disc_opt.update(disc, D_grads)\n",
"        # MLX is lazy: force the parameter and optimizer-state update now.\n",
"        mx.eval(disc.parameters(), disc_opt.state)\n",
"\n",
"        # Generator step.\n",
"        G_loss, G_grads = G_loss_grad(gen, disc, cur_batch_size, z_dim)\n",
"        gen_opt.update(gen, G_grads)\n",
"        mx.eval(gen.parameters(), gen_opt.state)\n",
"\n",
"        if cur_step % display_step == 0 and cur_step > 0:\n",
"            print(f\"Epoch {epoch}, step {cur_step}: Generator loss: {G_loss.item():.4f}, discriminator loss: {D_loss.item():.4f}\")\n",
"            fake_noise = mx.array(get_noise(batch_size, z_dim))\n",
"            fake = gen(fake_noise)\n",
"            show_images(fake)\n",
"            show_images(real)\n",
"        cur_step += 1"
2024-07-27 05:19:08 +08:00
]
2024-07-26 21:07:40 +08:00
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}