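"""Train a simple fully-connected GAN on MNIST or Fashion-MNIST with MLX."""
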
import argparse
import os

import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from tqdm import tqdm

import mnist

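# NOTE: `mnist` is assumed to be the local data-loading module from the
# mlx-examples MNIST example, exposing mnist() and fashion_mnist() functions
# that return the train/test images and labels as arrays.
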

# Generator block: linear layer -> batch norm -> LeakyReLU
def GenBlock(in_dim: int, out_dim: int) -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(in_dim, out_dim),
        # momentum is passed by keyword: BatchNorm's second positional
        # argument is eps, so a bare 0.8 would set eps rather than momentum
        nn.BatchNorm(out_dim, momentum=0.8),
        nn.LeakyReLU(negative_slope=0.2),
    )


# Generator model: maps a z_dim noise vector to a flattened 28x28 image
class Generator(nn.Module):
    def __init__(self, z_dim: int = 32, im_dim: int = 784, hidden_dim: int = 256):
        super().__init__()
        self.gen = nn.Sequential(
            GenBlock(z_dim, hidden_dim),
            GenBlock(hidden_dim, hidden_dim * 2),
            GenBlock(hidden_dim * 2, hidden_dim * 4),
            nn.Linear(hidden_dim * 4, im_dim),
        )

    def __call__(self, noise: mx.array) -> mx.array:
        x = self.gen(noise)
        # tanh keeps outputs in [-1, 1], matching the normalized images
        return mx.tanh(x)


# Sample standard-normal noise with shape (n_samples, z_dim)
def get_noise(n_samples: int, z_dim: int) -> mx.array:
    return mx.random.normal(shape=(n_samples, z_dim))
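

# Example: get_noise(25, 128) yields a (25, 128) array that the generator
# maps to 25 flattened 784-pixel images.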


#---------------------------------------------#


# Discriminator block: linear -> LeakyReLU -> dropout
def DisBlock(in_dim: int, out_dim: int) -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(in_dim, out_dim),
        nn.LeakyReLU(negative_slope=0.2),
        nn.Dropout(0.3),
    )


# Discriminator model: maps a flattened image to a single real/fake logit
class Discriminator(nn.Module):
    def __init__(self, im_dim: int = 784, hidden_dim: int = 256):
        super().__init__()
        self.disc = nn.Sequential(
            DisBlock(im_dim, hidden_dim * 4),
            DisBlock(hidden_dim * 4, hidden_dim * 2),
            DisBlock(hidden_dim * 2, hidden_dim),
            # Output a raw logit (no sigmoid): the losses below call
            # binary_cross_entropy with with_logits=True, which expects logits
            # and is numerically more stable than sigmoid followed by BCE
            nn.Linear(hidden_dim, 1),
        )

    def __call__(self, image: mx.array) -> mx.array:
        return self.disc(image)


# Discriminator loss: BCE pushing fake images toward label 0 and real
# images toward label 1, averaged over the two halves
def disc_loss(gen, disc, real, num_images, z_dim):
    noise = get_noise(num_images, z_dim)
    fake_images = gen(noise)

    fake_disc = disc(fake_images)
    fake_labels = mx.zeros((fake_images.shape[0], 1))
    fake_loss = mx.mean(nn.losses.binary_cross_entropy(fake_disc, fake_labels, with_logits=True))

    real_disc = disc(real)
    real_labels = mx.ones((real.shape[0], 1))
    real_loss = mx.mean(nn.losses.binary_cross_entropy(real_disc, real_labels, with_logits=True))

    return (fake_loss + real_loss) / 2.0


# Generator loss: BCE pushing fake images toward label 1 (the generator
# wants the discriminator to classify its samples as real)
def gen_loss(gen, disc, num_images, z_dim):
    noise = get_noise(num_images, z_dim)
    fake_images = gen(noise)
    fake_disc = disc(fake_images)
    fake_labels = mx.ones((fake_images.shape[0], 1))
    loss = nn.losses.binary_cross_entropy(fake_disc, fake_labels, with_logits=True)
    return mx.mean(loss)
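

# Together these losses implement the standard non-saturating GAN objective:
# the discriminator minimizes the average of -log D(x) on real images and
# -log(1 - D(G(z))) on fakes, while the generator minimizes -log D(G(z)),
# which gives stronger early-training gradients than maximizing log(1 - D(G(z))).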


# Yield shuffled mini-batches of images
def batch_iterate(batch_size: int, ipt: np.ndarray):
    perm = np.random.permutation(len(ipt))
    for s in range(0, len(ipt), batch_size):
        ids = perm[s : s + batch_size]
        yield ipt[ids]
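
# Note: the last batch may hold fewer than batch_size images; the losses
# size their label arrays from the inputs' leading dimension, so this is safe.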


# Plot a 5x5 grid of generated images and save it at the given epoch
def show_images(epoch_num: int, imgs: mx.array, num_imgs: int = 25):
    if imgs.shape[0] >= num_imgs:
        fig, axes = plt.subplots(5, 5, figsize=(5, 5))
        for i, ax in enumerate(axes.flat):
            # Convert to numpy for matplotlib and restore the 28x28 layout
            img = np.array(imgs[i]).reshape(28, 28)
            ax.imshow(img, cmap='gray')
            ax.axis('off')
        plt.tight_layout()
        # Create the output directory on first use so savefig does not fail
        os.makedirs('gen_images', exist_ok=True)
        plt.savefig('gen_images/img_{}.png'.format(epoch_num))
        plt.show()


def main(args: argparse.Namespace):
    seed = 42
    n_epochs = 500
    z_dim = 128
    batch_size = 128
    lr = 2e-5

    mx.random.seed(seed)

    # Load the training images of the selected dataset
    train_images, *_ = map(np.array, getattr(mnist, args.dataset)())

    # Rescale images from [0, 1] to [-1, 1] to match the generator's tanh output
    train_images = train_images * 2.0 - 1.0

    gen = Generator(z_dim)
    mx.eval(gen.parameters())
    gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])

    disc = Discriminator()
    mx.eval(disc.parameters())
    disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])

    # Each wrapper differentiates its loss with respect to the parameters of
    # the model passed as the first argument (disc or gen) only
    D_loss_grad = nn.value_and_grad(disc, disc_loss)
    G_loss_grad = nn.value_and_grad(gen, gen_loss)

    for epoch in tqdm(range(n_epochs)):
        for idx, real in enumerate(batch_iterate(batch_size, train_images)):
            # Train the discriminator
            D_loss, D_grads = D_loss_grad(gen, disc, mx.array(real), batch_size, z_dim)
            # Apply the gradients to the discriminator's parameters
            disc_opt.update(disc, D_grads)
            # Force evaluation of the lazily computed update
            mx.eval(disc.parameters(), disc_opt.state)

            # Train the generator
            G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)
            # Apply the gradients to the generator's parameters
            gen_opt.update(gen, G_grads)
            mx.eval(gen.parameters(), gen_opt.state)

        if epoch % 100 == 0:
            print("Epoch: {}, iteration: {}, Discriminator Loss: {}, Generator Loss: {}".format(epoch, idx, D_loss, G_loss))
            fake_noise = get_noise(batch_size, z_dim)
            fake = gen(fake_noise)
            show_images(epoch, fake)
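

# Example invocation (script name assumed):
#   python main.py --gpu --dataset fashion_mnist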
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a simple GAN on MNIST with MLX.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="mnist",
        choices=["mnist", "fashion_mnist"],
        help="The dataset to use.",
    )
    args = parser.parse_args()
    if not args.gpu:
        mx.set_default_device(mx.cpu)
    main(args)