diff --git a/gan/main.py b/gan/main.py
new file mode 100644
index 00000000..7192a3bc
--- /dev/null
+++ b/gan/main.py
@@ -0,0 +1,165 @@
+import argparse
+
+import mlx.core as mx
+import mlx.nn as nn
+import mlx.optimizers as optim
+import numpy as np
+from tqdm import tqdm
+
+import mnist
+
+
+# Generator block: Linear -> BatchNorm -> ReLU
+def GenBlock(in_dim: int, out_dim: int):
+    return nn.Sequential(
+        nn.Linear(in_dim, out_dim),
+        nn.BatchNorm(out_dim),
+        nn.ReLU(),
+    )
+
+
+class Generator(nn.Module):
+    """Maps a latent vector to a flattened 28x28 image in [0, 1]."""
+
+    def __init__(self, z_dim: int = 10, im_dim: int = 784, hidden_dim: int = 128):
+        super().__init__()
+        self.gen = nn.Sequential(
+            GenBlock(z_dim, hidden_dim),
+            GenBlock(hidden_dim, hidden_dim * 2),
+            GenBlock(hidden_dim * 2, hidden_dim * 4),
+            GenBlock(hidden_dim * 4, hidden_dim * 8),
+            nn.Linear(hidden_dim * 8, im_dim),
+            nn.Sigmoid(),
+        )
+
+    # MLX modules are called directly instead of defining forward().
+    def __call__(self, noise):
+        return self.gen(noise)
+
+
+def get_noise(n_samples: int, z_dim: int) -> mx.array:
+    """Sample an (n_samples, z_dim) array from a standard normal."""
+    return mx.random.normal((n_samples, z_dim))
+
+
+# Discriminator block: Linear -> LeakyReLU
+def DisBlock(in_dim: int, out_dim: int):
+    return nn.Sequential(
+        nn.Linear(in_dim, out_dim),
+        nn.LeakyReLU(0.2),
+    )
+
+
+class Discriminator(nn.Module):
+    """Maps a flattened image to a single real/fake logit."""
+
+    def __init__(self, im_dim: int = 784, hidden_dim: int = 128):
+        super().__init__()
+        self.disc = nn.Sequential(
+            DisBlock(im_dim, hidden_dim * 4),
+            DisBlock(hidden_dim * 4, hidden_dim * 2),
+            DisBlock(hidden_dim * 2, hidden_dim),
+            nn.Linear(hidden_dim, 1),
+        )
+
+    def __call__(self, image):
+        return self.disc(image)
+
+
+def main(args: argparse.Namespace):
+    seed = 42
+    # The discriminator outputs raw logits, matching the default
+    # with_logits=True of binary_cross_entropy.
+    criterion = nn.losses.binary_cross_entropy
+    n_epochs = 200
+    z_dim = 64
+    display_step = 500
+    batch_size = 128
+    lr = 0.00001
+
+    np.random.seed(seed)
+    mx.random.seed(seed)
+
+    # Load the data
+    train_images, train_labels, test_images, test_labels = map(
+        mx.array, getattr(mnist, args.dataset)()
+    )
+
+    gen = Generator(z_dim)
+    gen_opt = optim.Adam(learning_rate=lr)
+    disc = Discriminator()
+    disc_opt = optim.Adam(learning_rate=lr)
+
+    # The models are explicit arguments so the losses can be differentiated
+    # with respect to their parameters via nn.value_and_grad.
+    def disc_loss(gen, disc, criterion, real, num_images, z_dim):
+        noise = get_noise(num_images, z_dim)
+        fake_images = gen(noise)
+
+        # stop_gradient keeps the generator fixed while the discriminator
+        # trains (the MLX analogue of detaching the fake images).
+        fake_disc = disc(mx.stop_gradient(fake_images))
+        fake_labels = mx.zeros((num_images, 1))
+        fake_loss = criterion(fake_disc, fake_labels, reduction="mean")
+
+        real_disc = disc(real)
+        real_labels = mx.ones((real.shape[0], 1))
+        real_loss = criterion(real_disc, real_labels, reduction="mean")
+
+        return (fake_loss + real_loss) / 2
+
+    def gen_loss(gen, disc, criterion, num_images, z_dim):
+        noise = get_noise(num_images, z_dim)
+        fake_images = gen(noise)
+
+        # The generator wins when the discriminator calls its fakes real.
+        fake_disc = disc(fake_images)
+        fake_labels = mx.ones((num_images, 1))
+
+        return criterion(fake_disc, fake_labels, reduction="mean")
+
+    # Training
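+    # NOTE: The original patch stops at the loss functions above; the loop
+    # below is a minimal sketch of how training could proceed with MLX's
+    # standard nn.value_and_grad / optimizer.update pattern. The shuffling
+    # and logging details are illustrative assumptions, not part of the patch.
+    disc_loss_and_grad = nn.value_and_grad(disc, disc_loss)
+    gen_loss_and_grad = nn.value_and_grad(gen, gen_loss)
+
+    step = 0
+    for epoch in range(n_epochs):
+        perm = np.random.permutation(train_images.shape[0])
+        for i in tqdm(range(0, train_images.shape[0] - batch_size + 1, batch_size)):
+            real = train_images[mx.array(perm[i : i + batch_size])]
+
+            # Update the discriminator on real and (detached) fake images.
+            d_loss, d_grads = disc_loss_and_grad(
+                gen, disc, criterion, real, batch_size, z_dim
+            )
+            disc_opt.update(disc, d_grads)
+
+            # Update the generator to better fool the discriminator.
+            g_loss, g_grads = gen_loss_and_grad(gen, disc, criterion, batch_size, z_dim)
+            gen_opt.update(gen, g_grads)
+
+            # Force evaluation of the lazily computed updates.
+            mx.eval(gen.parameters(), disc.parameters())
+            if step % display_step == 0:
+                print(f"Step {step}: gen loss {g_loss.item():.4f}, disc loss {d_loss.item():.4f}")
+            step += 1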
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("Train a simple GAN on MNIST with MLX.")
+    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="mnist",
+        choices=["mnist", "fashion_mnist"],
+        help="The dataset to use.",
+    )
+    args = parser.parse_args()
+    if not args.gpu:
+        mx.set_default_device(mx.cpu)
+    main(args)
diff --git a/gan/mnist.py b/gan/mnist.py
new file mode 100644
index 00000000..c5f920e6
--- /dev/null
+++ b/gan/mnist.py
@@ -0,0 +1,83 @@
+# Copyright © 2023 Apple Inc.
+
+import gzip
+import os
+import pickle
+from urllib import request
+
+import numpy as np
+
+
+def mnist(
+    save_dir="/tmp",
+    base_url="https://raw.githubusercontent.com/fgnt/mnist/master/",
+    filename="mnist.pkl",
+):
+    """
+    Load the MNIST dataset in 4 tensors: train images, train labels,
+    test images, and test labels.
+
+    Checks `save_dir` for already downloaded data otherwise downloads.
+
+    Download code modified from:
+    https://github.com/hsjeong5/MNIST-for-Numpy
+    """
+
+    def download_and_save(save_file):
+        filename = [
+            ["training_images", "train-images-idx3-ubyte.gz"],
+            ["test_images", "t10k-images-idx3-ubyte.gz"],
+            ["training_labels", "train-labels-idx1-ubyte.gz"],
+            ["test_labels", "t10k-labels-idx1-ubyte.gz"],
+        ]
+
+        mnist = {}
+        for name in filename:
+            out_file = os.path.join("/tmp", name[1])
+            request.urlretrieve(base_url + name[1], out_file)
+        for name in filename[:2]:
+            out_file = os.path.join("/tmp", name[1])
+            with gzip.open(out_file, "rb") as f:
+                mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(
+                    -1, 28 * 28
+                )
+        for name in filename[-2:]:
+            out_file = os.path.join("/tmp", name[1])
+            with gzip.open(out_file, "rb") as f:
+                mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
+        with open(save_file, "wb") as f:
+            pickle.dump(mnist, f)
+
+    save_file = os.path.join(save_dir, filename)
+    if not os.path.exists(save_file):
+        download_and_save(save_file)
+    with open(save_file, "rb") as f:
+        mnist = pickle.load(f)
+
+    def preproc(x):
+        return x.astype(np.float32) / 255.0
+
+    mnist["training_images"] = preproc(mnist["training_images"])
+    mnist["test_images"] = preproc(mnist["test_images"])
+    return (
+        mnist["training_images"],
+        mnist["training_labels"].astype(np.uint32),
+        mnist["test_images"],
+        mnist["test_labels"].astype(np.uint32),
+    )
+
+
+def fashion_mnist(save_dir="/tmp"):
+    return mnist(
+        save_dir,
+        base_url="http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
+        filename="fashion_mnist.pkl",
+    )
+
+
+if __name__ == "__main__":
+    train_x, train_y, test_x, test_y = mnist()
+    assert train_x.shape == (60000, 28 * 28), "Wrong training set size"
+    assert train_y.shape == (60000,), "Wrong training set size"
+    assert test_x.shape == (10000, 28 * 28), "Wrong test set size"
+    assert test_y.shape == (10000,), "Wrong test set size"