Mirror of https://github.com/ml-explore/mlx-examples.git
Synced 2025-06-24 09:21:18 +08:00

Commit: Code Arrangement
parent 37bbf3ec54
commit 7e0bdacef3
gan/main.py (85 changes)
@@ -1,73 +1,108 @@
 import mnist
-from tqdm import tqdm

 import argparse

 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim

+from tqdm import tqdm
 import numpy as np
 import matplotlib.pyplot as plt


 # Generator Block
 def GenBlock(in_dim:int,out_dim:int):
     return nn.Sequential(
         nn.Linear(in_dim,out_dim),
-        nn.BatchNorm(out_dim),
-        nn.ReLU()
+        nn.BatchNorm(out_dim, 0.8),
+        nn.LeakyReLU(0.2)
     )

-# Generator Layer
+# Generator Model
 class Generator(nn.Module):

-    def __init__(self, z_dim:int = 10, im_dim:int = 784, hidden_dim: int =128):
+    def __init__(self, z_dim:int = 32, im_dim:int = 784, hidden_dim: int = 256):
         super(Generator, self).__init__()
         # Build the neural network

         self.gen = nn.Sequential(
             GenBlock(z_dim, hidden_dim),
             GenBlock(hidden_dim, hidden_dim * 2),
             GenBlock(hidden_dim * 2, hidden_dim * 4),
-            GenBlock(hidden_dim * 4, hidden_dim * 8),

-            nn.Linear(hidden_dim * 8,im_dim),
-            nn.Sigmoid()
+            nn.Linear(hidden_dim * 4,im_dim),
         )

-    def forward(self, noise):
+    def __call__(self, noise):
         x = self.gen(noise)
         return mx.tanh(x)
-
-        return self.gen(noise)


-# return random n,m normal distribution
-def get_noise(n_samples:int, z_dim:int)->list:
-    return np.random.randn(n_samples,z_dim)
+# make 2D noise with shape n_samples x z_dim
+def get_noise(n_samples:list[int], z_dim:int)->list[int]:
+    return mx.random.normal(shape=(n_samples, z_dim))

 #---------------------------------------------#

 # Discriminator Block
 def DisBlock(in_dim:int,out_dim:int):
     return nn.Sequential(
         nn.Linear(in_dim,out_dim),
-        nn.LeakyReLU(negative_slope=0.2)
+        nn.LeakyReLU(negative_slope=0.2),
+        nn.Dropout(0.3),
     )

-# Discriminator Layer
+# Discriminator Model
 class Discriminator(nn.Module):

-    def __init__(self,im_dim:int = 784, hidden_dim:int = 128):
+    def __init__(self,im_dim:int = 784, hidden_dim:int = 256):
         super(Discriminator, self).__init__()

         self.disc = nn.Sequential(
             DisBlock(im_dim, hidden_dim * 4),
             DisBlock(hidden_dim * 4, hidden_dim * 2),
             DisBlock(hidden_dim * 2, hidden_dim),

             nn.Linear(hidden_dim,1),
             nn.Sigmoid()
         )

-    def forward(self, noise):
+    def __call__(self, noise):
         return self.disc(noise)

+# Discriminator Loss
+def disc_loss(gen, disc, real, num_images, z_dim):
+
+    noise = mx.array(get_noise(num_images, z_dim))
+    fake_images = gen(noise)
+
+    fake_disc = disc(fake_images)
+
+    fake_labels = mx.zeros((fake_images.shape[0],1))
+
+    fake_loss = mx.mean(nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True))
+
+    real_disc = mx.array(disc(real))
+    real_labels = mx.ones((real.shape[0],1))
+
+    real_loss = mx.mean(nn.losses.binary_cross_entropy(real_disc,real_labels,with_logits=True))
+
+    disc_loss = (fake_loss + real_loss) / 2.0
+
+    return disc_loss
+
+# Genearator Loss
+def gen_loss(gen, disc, num_images, z_dim):
+
+    noise = mx.array(get_noise(num_images, z_dim))
+
+    fake_images = gen(noise)
+    fake_disc = mx.array(disc(fake_images))
+
+    fake_labels = mx.ones((fake_images.shape[0],1))
+
+    gen_loss = nn.losses.binary_cross_entropy(fake_disc,fake_labels,with_logits=True)
+
+    return mx.mean(gen_loss)

 def main(args:dict):
     seed = 42
@@ -78,7 +113,7 @@ def main(args:dict):
     batch_size = 128
     lr = 0.00001

     np.random.seed(seed)
     mx.random.seed(seed)

     # Load the data
     train_images, train_labels, test_images, test_labels = map(
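Note (not part of the commit): the two loss functions above take the generator and discriminator as plain arguments, so they can be differentiated directly with mlx.nn.value_and_grad. The sketch below shows one possible way to wire the pieces from this diff into a single training step. It reuses batch_size = 128, lr = 0.00001, and z_dim = 32 as they appear in the diff, but the choice of Adam optimizers and the value_and_grad wiring are assumptions, since the visible part of the diff does not include the training loop.

# Rough sketch only, assuming the Generator, Discriminator, disc_loss, and
# gen_loss definitions from gan/main.py shown above are in scope.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

z_dim, batch_size, lr = 32, 128, 0.00001   # values taken from the diff

gen = Generator(z_dim)        # classes defined in gan/main.py above
disc = Discriminator()

# Adam is an assumption; the diff does not show which optimizer main() uses.
gen_opt = optim.Adam(learning_rate=lr)
disc_opt = optim.Adam(learning_rate=lr)

# nn.value_and_grad differentiates a loss with respect to one module's
# trainable parameters; the loss functions already accept (gen, disc, ...).
d_loss_and_grad = nn.value_and_grad(disc, disc_loss)
g_loss_and_grad = nn.value_and_grad(gen, gen_loss)

def train_step(real_batch):
    # real_batch: (batch_size, 784) array of flattened MNIST images.
    d_loss, d_grads = d_loss_and_grad(gen, disc, real_batch, batch_size, z_dim)
    disc_opt.update(disc, d_grads)

    g_loss, g_grads = g_loss_and_grad(gen, disc, batch_size, z_dim)
    gen_opt.update(gen, g_grads)

    # Force evaluation of the lazily built graphs for both models.
    mx.eval(gen.parameters(), disc.parameters(), gen_opt.state, disc_opt.state)
    return d_loss, g_loss

Using two separate value_and_grad transforms keeps each optimizer's gradients scoped to its own model, which matches how the two losses in the diff are split between the discriminator and the generator.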