Code Arrangement

This commit is contained in:
Shubbair 2024-08-01 15:22:19 +03:00
parent 37bbf3ec54
commit 7e0bdacef3

View File

@ -1,46 +1,45 @@
import mnist
from tqdm import tqdm
import argparse
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
# Generator Block
def GenBlock(in_dim:int,out_dim:int):
return nn.Sequential(
nn.Linear(in_dim,out_dim),
nn.BatchNorm(out_dim),
nn.ReLU()
nn.BatchNorm(out_dim, 0.8),
nn.LeakyReLU(0.2)
)
# Generator Layer
# Generator Model
class Generator(nn.Module):
def __init__(self, z_dim:int = 10, im_dim:int = 784, hidden_dim: int =128):
def __init__(self, z_dim:int = 32, im_dim:int = 784, hidden_dim: int = 256):
super(Generator, self).__init__()
# Build the neural network
self.gen = nn.Sequential(
GenBlock(z_dim, hidden_dim),
GenBlock(hidden_dim, hidden_dim * 2),
GenBlock(hidden_dim * 2, hidden_dim * 4),
GenBlock(hidden_dim * 4, hidden_dim * 8),
nn.Linear(hidden_dim * 8,im_dim),
nn.Sigmoid()
nn.Linear(hidden_dim * 4,im_dim),
)
def forward(self, noise):
def __call__(self, noise):
x = self.gen(noise)
return mx.tanh(x)
return self.gen(noise)
# return random n,m normal distribution
def get_noise(n_samples:int, z_dim:int)->list:
return np.random.randn(n_samples,z_dim)
# make 2D noise with shape n_samples x z_dim
def get_noise(n_samples:list[int], z_dim:int)->list[int]:
return mx.random.normal(shape=(n_samples, z_dim))
#---------------------------------------------#
@ -48,13 +47,14 @@ def get_noise(n_samples:int, z_dim:int)->list:
def DisBlock(in_dim:int,out_dim:int):
return nn.Sequential(
nn.Linear(in_dim,out_dim),
nn.LeakyReLU(negative_slope=0.2)
nn.LeakyReLU(negative_slope=0.2),
nn.Dropout(0.3),
)
# Discriminator Layer
# Discriminator Model
class Discriminator(nn.Module):
def __init__(self,im_dim:int = 784, hidden_dim:int = 128):
def __init__(self,im_dim:int = 784, hidden_dim:int = 256):
super(Discriminator, self).__init__()
self.disc = nn.Sequential(
@ -63,12 +63,47 @@ class Discriminator(nn.Module):
DisBlock(hidden_dim * 2, hidden_dim),
nn.Linear(hidden_dim,1),
nn.Sigmoid()
)
def forward(self, noise):
def __call__(self, noise):
return self.disc(noise)
# Discriminator Loss
def disc_loss(gen, disc, real, num_images, z_dim):
noise = mx.array(get_noise(num_images, z_dim))
fake_images = gen(noise)
fake_disc = disc(fake_images)
fake_labels = mx.zeros((fake_images.shape[0],1))
fake_loss = mx.mean(nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True))
real_disc = mx.array(disc(real))
real_labels = mx.ones((real.shape[0],1))
real_loss = mx.mean(nn.losses.binary_cross_entropy(real_disc,real_labels,with_logits=True))
disc_loss = (fake_loss + real_loss) / 2.0
return disc_loss
# Genearator Loss
def gen_loss(gen, disc, num_images, z_dim):
noise = mx.array(get_noise(num_images, z_dim))
fake_images = gen(noise)
fake_disc = mx.array(disc(fake_images))
fake_labels = mx.ones((fake_images.shape[0],1))
gen_loss = nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True)
return mx.mean(gen_loss)
def main(args:dict):
seed = 42
criterion = nn.losses.binary_cross_entropy
@ -78,7 +113,7 @@ def main(args:dict):
batch_size = 128
lr = 0.00001
np.random.seed(seed)
mx.random.seed(seed)
# Load the data
train_images, train_labels, test_images, test_labels = map(