Code Arrangement

This commit is contained in:
Shubbair 2024-08-01 15:22:19 +03:00
parent 37bbf3ec54
commit 7e0bdacef3

View File

@ -1,46 +1,45 @@
import mnist import mnist
from tqdm import tqdm
import argparse import argparse
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import mlx.optimizers as optim import mlx.optimizers as optim
from tqdm import tqdm
import numpy as np import numpy as np
import matplotlib.pyplot as plt
# Generator Block # Generator Block
def GenBlock(in_dim:int,out_dim:int): def GenBlock(in_dim:int,out_dim:int):
return nn.Sequential( return nn.Sequential(
nn.Linear(in_dim,out_dim), nn.Linear(in_dim,out_dim),
nn.BatchNorm(out_dim), nn.BatchNorm(out_dim, 0.8),
nn.ReLU() nn.LeakyReLU(0.2)
) )
# Generator Layer # Generator Model
class Generator(nn.Module): class Generator(nn.Module):
def __init__(self, z_dim:int = 10, im_dim:int = 784, hidden_dim: int =128): def __init__(self, z_dim:int = 32, im_dim:int = 784, hidden_dim: int = 256):
super(Generator, self).__init__() super(Generator, self).__init__()
# Build the neural network
self.gen = nn.Sequential( self.gen = nn.Sequential(
GenBlock(z_dim, hidden_dim), GenBlock(z_dim, hidden_dim),
GenBlock(hidden_dim, hidden_dim * 2), GenBlock(hidden_dim, hidden_dim * 2),
GenBlock(hidden_dim * 2, hidden_dim * 4), GenBlock(hidden_dim * 2, hidden_dim * 4),
GenBlock(hidden_dim * 4, hidden_dim * 8),
nn.Linear(hidden_dim * 4,im_dim),
nn.Linear(hidden_dim * 8,im_dim),
nn.Sigmoid()
) )
def forward(self, noise): def __call__(self, noise):
x = self.gen(noise)
return mx.tanh(x)
return self.gen(noise) # make 2D noise with shape n_samples x z_dim
def get_noise(n_samples:list[int], z_dim:int)->list[int]:
return mx.random.normal(shape=(n_samples, z_dim))
# return random n,m normal distribution
def get_noise(n_samples:int, z_dim:int)->list:
return np.random.randn(n_samples,z_dim)
#---------------------------------------------# #---------------------------------------------#
@ -48,13 +47,14 @@ def get_noise(n_samples:int, z_dim:int)->list:
def DisBlock(in_dim:int,out_dim:int): def DisBlock(in_dim:int,out_dim:int):
return nn.Sequential( return nn.Sequential(
nn.Linear(in_dim,out_dim), nn.Linear(in_dim,out_dim),
nn.LeakyReLU(negative_slope=0.2) nn.LeakyReLU(negative_slope=0.2),
nn.Dropout(0.3),
) )
# Discriminator Layer # Discriminator Model
class Discriminator(nn.Module): class Discriminator(nn.Module):
def __init__(self,im_dim:int = 784, hidden_dim:int = 128): def __init__(self,im_dim:int = 784, hidden_dim:int = 256):
super(Discriminator, self).__init__() super(Discriminator, self).__init__()
self.disc = nn.Sequential( self.disc = nn.Sequential(
@ -63,12 +63,47 @@ class Discriminator(nn.Module):
DisBlock(hidden_dim * 2, hidden_dim), DisBlock(hidden_dim * 2, hidden_dim),
nn.Linear(hidden_dim,1), nn.Linear(hidden_dim,1),
nn.Sigmoid()
) )
def forward(self, noise): def __call__(self, noise):
return self.disc(noise) return self.disc(noise)
# Discriminator Loss
def disc_loss(gen, disc, real, num_images, z_dim):
noise = mx.array(get_noise(num_images, z_dim))
fake_images = gen(noise)
fake_disc = disc(fake_images)
fake_labels = mx.zeros((fake_images.shape[0],1))
fake_loss = mx.mean(nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True))
real_disc = mx.array(disc(real))
real_labels = mx.ones((real.shape[0],1))
real_loss = mx.mean(nn.losses.binary_cross_entropy(real_disc,real_labels,with_logits=True))
disc_loss = (fake_loss + real_loss) / 2.0
return disc_loss
# Genearator Loss
def gen_loss(gen, disc, num_images, z_dim):
noise = mx.array(get_noise(num_images, z_dim))
fake_images = gen(noise)
fake_disc = mx.array(disc(fake_images))
fake_labels = mx.ones((fake_images.shape[0],1))
gen_loss = nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True)
return mx.mean(gen_loss)
def main(args:dict): def main(args:dict):
seed = 42 seed = 42
criterion = nn.losses.binary_cross_entropy criterion = nn.losses.binary_cross_entropy
@ -78,7 +113,7 @@ def main(args:dict):
batch_size = 128 batch_size = 128
lr = 0.00001 lr = 0.00001
np.random.seed(seed) mx.random.seed(seed)
# Load the data # Load the data
train_images, train_labels, test_images, test_labels = map( train_images, train_labels, test_images, test_labels = map(