mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Add GAN model 25/7
This commit is contained in:
parent
cd8efc7fbc
commit
5e7ce1048c
138
gan/main.py
Normal file
@ -0,0 +1,138 @@
import argparse

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from tqdm import tqdm

import mnist


# Generator Block
def GenBlock(in_dim: int, out_dim: int):
    return nn.Sequential(
        nn.Linear(in_dim, out_dim),
        nn.BatchNorm(out_dim),
        nn.ReLU(),
    )

# Generator
class Generator(nn.Module):

    def __init__(self, z_dim: int = 10, im_dim: int = 784, hidden_dim: int = 128):
        super().__init__()
        # Build the neural network
        self.gen = nn.Sequential(
            GenBlock(z_dim, hidden_dim),
            GenBlock(hidden_dim, hidden_dim * 2),
            GenBlock(hidden_dim * 2, hidden_dim * 4),
            GenBlock(hidden_dim * 4, hidden_dim * 8),
            nn.Linear(hidden_dim * 8, im_dim),
            nn.Sigmoid(),
        )

    # MLX modules are called directly, so define __call__ rather than forward
    def __call__(self, noise):
        return self.gen(noise)


# Return an (n_samples, z_dim) array of standard-normal noise
def get_noise(n_samples: int, z_dim: int) -> mx.array:
    return mx.array(np.random.randn(n_samples, z_dim).astype(np.float32))


# ---------------------------------------------#

# Discriminator Block
def DisBlock(in_dim: int, out_dim: int):
    return nn.Sequential(
        nn.Linear(in_dim, out_dim),
        nn.LeakyReLU(negative_slope=0.2),
    )

# Discriminator
class Discriminator(nn.Module):

    def __init__(self, im_dim: int = 784, hidden_dim: int = 128):
        super().__init__()
        self.disc = nn.Sequential(
            DisBlock(im_dim, hidden_dim * 4),
            DisBlock(hidden_dim * 4, hidden_dim * 2),
            DisBlock(hidden_dim * 2, hidden_dim),
            # Final linear layer outputs one logit per image
            nn.Linear(hidden_dim, 1),
        )

    # MLX modules are called directly, so define __call__ rather than forward
    def __call__(self, image):
        return self.disc(image)


def main(args):
    seed = 42
    criterion = nn.losses.binary_cross_entropy
    n_epochs = 200
    z_dim = 64
    display_step = 500
    batch_size = 128
    lr = 0.00001

    np.random.seed(seed)

    # Load the data
    train_images, train_labels, test_images, test_labels = map(
        mx.array, getattr(mnist, args.dataset)()
    )

    gen = Generator(z_dim)
    gen_opt = optim.Adam(learning_rate=lr)
    disc = Discriminator()
    disc_opt = optim.Adam(learning_rate=lr)

    # Loss functions for the discriminator and the generator
    def disc_loss(gen, disc, criterion, real, num_images, z_dim):
        noise = get_noise(num_images, z_dim)
        fake_images = gen(noise)

        # Stop the gradient so the discriminator update does not reach the generator
        fake_disc = disc(mx.stop_gradient(fake_images))
        fake_labels = mx.zeros((fake_images.shape[0], 1))
        fake_loss = criterion(fake_disc, fake_labels, reduction="mean")

        real_disc = disc(real)
        real_labels = mx.ones((real.shape[0], 1))
        real_loss = criterion(real_disc, real_labels, reduction="mean")

        return (fake_loss + real_loss) / 2

    def gen_loss(gen, disc, criterion, num_images, z_dim):
        noise = get_noise(num_images, z_dim)
        fake_images = gen(noise)

        fake_disc = disc(fake_images)
        fake_labels = mx.ones((fake_images.shape[0], 1))

        return criterion(fake_disc, fake_labels, reduction="mean")

    # training
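    # The commit leaves the training loop unimplemented. The block below is a
    # minimal sketch (not part of the original commit) of one way to alternate
    # discriminator and generator updates with nn.value_and_grad; the step
    # functions and variable names here are illustrative assumptions.
    def disc_step(disc, real):
        return disc_loss(gen, disc, criterion, real, real.shape[0], z_dim)

    def gen_step(gen, num_images):
        return gen_loss(gen, disc, criterion, num_images, z_dim)

    disc_loss_and_grad = nn.value_and_grad(disc, disc_step)
    gen_loss_and_grad = nn.value_and_grad(gen, gen_step)

    step = 0
    num_train = train_images.shape[0]
    for epoch in range(n_epochs):
        perm = np.random.permutation(num_train)
        for s in tqdm(range(0, num_train - batch_size + 1, batch_size)):
            real = train_images[mx.array(perm[s : s + batch_size])]

            # Update the discriminator on a batch of real and fake images
            d_loss, d_grads = disc_loss_and_grad(disc, real)
            disc_opt.update(disc, d_grads)

            # Update the generator to fool the current discriminator
            g_loss, g_grads = gen_loss_and_grad(gen, batch_size)
            gen_opt.update(gen, g_grads)

            mx.eval(disc.parameters(), gen.parameters(), disc_opt.state, gen_opt.state)

            if step % display_step == 0:
                print(f"step {step}: disc loss {d_loss.item():.4f}, gen loss {g_loss.item():.4f}")
            step += 1
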
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Train a simple GAN on MNIST with MLX.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="mnist",
        choices=["mnist", "fashion_mnist"],
        help="The dataset to use.",
    )
    args = parser.parse_args()
    if not args.gpu:
        mx.set_default_device(mx.cpu)
    main(args)
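With the training sketch above in place, the script would be invoked from the repository root as, e.g., `python gan/main.py --gpu` or `python gan/main.py --dataset fashion_mnist`; the flag names come from the argparse setup above, and the exact invocation path is an assumption about the repo layout.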
83
gan/mnist.py
Normal file
@ -0,0 +1,83 @@
# Copyright © 2023 Apple Inc.

import gzip
import os
import pickle
from urllib import request

import numpy as np


def mnist(
    save_dir="/tmp",
    base_url="https://raw.githubusercontent.com/fgnt/mnist/master/",
    filename="mnist.pkl",
):
    """
    Load the MNIST dataset in 4 tensors: train images, train labels,
    test images, and test labels.

    Checks `save_dir` for already downloaded data otherwise downloads.

    Download code modified from:
    https://github.com/hsjeong5/MNIST-for-Numpy
    """

    def download_and_save(save_file):
        filename = [
            ["training_images", "train-images-idx3-ubyte.gz"],
            ["test_images", "t10k-images-idx3-ubyte.gz"],
            ["training_labels", "train-labels-idx1-ubyte.gz"],
            ["test_labels", "t10k-labels-idx1-ubyte.gz"],
        ]

        mnist = {}
        for name in filename:
            out_file = os.path.join("/tmp", name[1])
            request.urlretrieve(base_url + name[1], out_file)
        for name in filename[:2]:
            out_file = os.path.join("/tmp", name[1])
            with gzip.open(out_file, "rb") as f:
                mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(
                    -1, 28 * 28
                )
        for name in filename[-2:]:
            out_file = os.path.join("/tmp", name[1])
            with gzip.open(out_file, "rb") as f:
                mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
        with open(save_file, "wb") as f:
            pickle.dump(mnist, f)

    save_file = os.path.join(save_dir, filename)
    if not os.path.exists(save_file):
        download_and_save(save_file)
    with open(save_file, "rb") as f:
        mnist = pickle.load(f)

    def preproc(x):
        return x.astype(np.float32) / 255.0

    mnist["training_images"] = preproc(mnist["training_images"])
    mnist["test_images"] = preproc(mnist["test_images"])
    return (
        mnist["training_images"],
        mnist["training_labels"].astype(np.uint32),
        mnist["test_images"],
        mnist["test_labels"].astype(np.uint32),
    )


def fashion_mnist(save_dir="/tmp"):
    return mnist(
        save_dir,
        base_url="http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
        filename="fashion_mnist.pkl",
    )


if __name__ == "__main__":
    train_x, train_y, test_x, test_y = mnist()
    assert train_x.shape == (60000, 28 * 28), "Wrong training set size"
    assert train_y.shape == (60000,), "Wrong training set size"
    assert test_x.shape == (10000, 28 * 28), "Wrong test set size"
    assert test_y.shape == (10000,), "Wrong test set size"