diff --git a/gan/main.py b/gan/main.py index 1ae1abde..dd794289 100644 --- a/gan/main.py +++ b/gan/main.py @@ -1,73 +1,108 @@ import mnist -from tqdm import tqdm import argparse import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim + +from tqdm import tqdm import numpy as np +import matplotlib.pyplot as plt + # Generator Block def GenBlock(in_dim:int,out_dim:int): return nn.Sequential( nn.Linear(in_dim,out_dim), - nn.BatchNorm(out_dim), - nn.ReLU() + nn.BatchNorm(out_dim, 0.8), + nn.LeakyReLU(0.2) ) -# Generator Layer +# Generator Model class Generator(nn.Module): - def __init__(self, z_dim:int = 10, im_dim:int = 784, hidden_dim: int =128): + def __init__(self, z_dim:int = 32, im_dim:int = 784, hidden_dim: int = 256): super(Generator, self).__init__() - # Build the neural network + self.gen = nn.Sequential( GenBlock(z_dim, hidden_dim), GenBlock(hidden_dim, hidden_dim * 2), GenBlock(hidden_dim * 2, hidden_dim * 4), - GenBlock(hidden_dim * 4, hidden_dim * 8), - - nn.Linear(hidden_dim * 8,im_dim), - nn.Sigmoid() + nn.Linear(hidden_dim * 4,im_dim), ) - def forward(self, noise): + def __call__(self, noise): + x = self.gen(noise) + return mx.tanh(x) - return self.gen(noise) - - -# return random n,m normal distribution -def get_noise(n_samples:int, z_dim:int)->list: - return np.random.randn(n_samples,z_dim) +# make 2D noise with shape n_samples x z_dim +def get_noise(n_samples:list[int], z_dim:int)->list[int]: + return mx.random.normal(shape=(n_samples, z_dim)) #---------------------------------------------# # Discriminator Block -def DisBlock(in_dim:int,out_dim:int): +def DisBlock(in_dim:int,out_dim:int): return nn.Sequential( nn.Linear(in_dim,out_dim), - nn.LeakyReLU(negative_slope=0.2) + nn.LeakyReLU(negative_slope=0.2), + nn.Dropout(0.3), ) -# Discriminator Layer +# Discriminator Model class Discriminator(nn.Module): - def __init__(self,im_dim:int = 784, hidden_dim:int = 128): + def __init__(self,im_dim:int = 784, hidden_dim:int = 256): 
super(Discriminator, self).__init__() self.disc = nn.Sequential( DisBlock(im_dim, hidden_dim * 4), DisBlock(hidden_dim * 4, hidden_dim * 2), DisBlock(hidden_dim * 2, hidden_dim), - + nn.Linear(hidden_dim,1), + nn.Sigmoid() ) - def forward(self, noise): - + def __call__(self, noise): return self.disc(noise) + +# Discriminator Loss +def disc_loss(gen, disc, real, num_images, z_dim): + + noise = mx.array(get_noise(num_images, z_dim)) + fake_images = gen(noise) + + fake_disc = disc(fake_images) + + fake_labels = mx.zeros((fake_images.shape[0],1)) + + fake_loss = mx.mean(nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True)) + + real_disc = mx.array(disc(real)) + real_labels = mx.ones((real.shape[0],1)) + + real_loss = mx.mean(nn.losses.binary_cross_entropy(real_disc,real_labels,with_logits=True)) + + disc_loss = (fake_loss + real_loss) / 2.0 + + return disc_loss + +# Generator Loss +def gen_loss(gen, disc, num_images, z_dim): + + noise = mx.array(get_noise(num_images, z_dim)) + + fake_images = gen(noise) + fake_disc = mx.array(disc(fake_images)) + + fake_labels = mx.ones((fake_images.shape[0],1)) + + gen_loss = nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True) + + return mx.mean(gen_loss) def main(args:dict): seed = 42 @@ -78,7 +113,7 @@ def main(args:dict): batch_size = 128 lr = 0.00001 - np.random.seed(seed) + mx.random.seed(seed) # Load the data train_images, train_labels, test_images, test_labels = map(