mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Code Arrangement
This commit is contained in:
parent
7e0bdacef3
commit
f84b231cf2
78
gan/main.py
78
gan/main.py
@ -104,68 +104,53 @@ def gen_loss(gen, disc, num_images, z_dim):
|
||||
|
||||
return mx.mean(gen_loss)
|
||||
|
||||
# make batch of images
|
||||
def batch_iterate(batch_size: int, ipt: list[int])-> list[int]:
|
||||
perm = np.random.permutation(len(ipt))
|
||||
for s in range(0, len(ipt), batch_size):
|
||||
ids = perm[s : s + batch_size]
|
||||
yield ipt[ids]
|
||||
|
||||
def show_images(epoch_num:int,imgs:list[int],num_imgs:int = 25):
|
||||
if (imgs.shape[0] > 0):
|
||||
fig,axes = plt.subplots(5, 5, figsize=(5, 5))
|
||||
|
||||
for i, ax in enumerate(axes.flat):
|
||||
img = mx.array(imgs[i]).reshape(28,28)
|
||||
ax.imshow(img,cmap='gray')
|
||||
ax.axis('off')
|
||||
plt.tight_layout()
|
||||
plt.savefig('tmp/img_{}.png'.format(epoch_num))
|
||||
plt.show()
|
||||
|
||||
def main(args:dict):
|
||||
seed = 42
|
||||
criterion = nn.losses.binary_cross_entropy
|
||||
n_epochs = 200
|
||||
z_dim = 64
|
||||
display_step = 500
|
||||
n_epochs = 500
|
||||
z_dim = 128
|
||||
batch_size = 128
|
||||
lr = 0.00001
|
||||
lr = 2e-5
|
||||
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Load the data
|
||||
train_images, train_labels, test_images, test_labels = map(
|
||||
mx.array, getattr(mnist, args.dataset)()
|
||||
)
|
||||
train_images,*_ = map(np.array, getattr(mnist,'mnist')())
|
||||
|
||||
# Normalization images => [-1,1]
|
||||
train_images = train_images * 2.0 - 1.0
|
||||
|
||||
gen = Generator(z_dim)
|
||||
gen_opt = optim.Adam(learning_rate=lr)
|
||||
mx.eval(gen.parameters())
|
||||
gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])
|
||||
|
||||
disc = Discriminator()
|
||||
disc_opt = optim.Adam(learning_rate=lr)
|
||||
|
||||
# use partial function
|
||||
def disc_loss(gen, disc, criterion, real, num_images, z_dim):
|
||||
noise = get_noise(num_images, z_dim,device)
|
||||
fake_images = gen(noise)
|
||||
|
||||
fake_disc = disc(fake_images.detach())
|
||||
fake_labels = mx.zeros(fake_images.size(0),1)
|
||||
fake_loss = criterion(fake_disc,fake_labels)
|
||||
|
||||
real_disc = disc(real)
|
||||
real_labels = mx.ones(real.size(0),1)
|
||||
real_loss = criterion(real_disc,real_labels)
|
||||
|
||||
disc_loss = (fake_loss + real_loss) / 2
|
||||
|
||||
return disc_loss
|
||||
|
||||
def gen_loss(gen, disc, criterion, num_images, z_dim):
|
||||
|
||||
noise = get_noise(num_images, z_dim,device)
|
||||
fake_images = gen(noise)
|
||||
|
||||
fake_disc = disc(fake_images)
|
||||
fake_labels = mx.ones(fake_images.size(0),1)
|
||||
|
||||
gen_loss = criterion(fake_disc,fake_labels)
|
||||
|
||||
return gen_loss
|
||||
mx.eval(disc.parameters())
|
||||
disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])
|
||||
|
||||
# TODO training...
|
||||
# Set your parameters
|
||||
n_epochs = 500
|
||||
display_step = 5000
|
||||
cur_step = 0
|
||||
|
||||
batch_size = 128 # 128
|
||||
|
||||
D_loss_grad = nn.value_and_grad(disc, disc_loss)
|
||||
G_loss_grad = nn.value_and_grad(gen, gen_loss)
|
||||
|
||||
|
||||
for epoch in tqdm(range(n_epochs)):
|
||||
|
||||
for idx,real in enumerate(batch_iterate(batch_size, train_images)):
|
||||
@ -194,7 +179,6 @@ def main(args:dict):
|
||||
fake = gen(fake_noise)
|
||||
show_images(epoch,fake)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Train a simple GAN on MNIST with MLX.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
|
Loading…
Reference in New Issue
Block a user