diff --git a/gan/main.py b/gan/main.py
index dd794289..2ee2b38a 100644
--- a/gan/main.py
+++ b/gan/main.py
@@ -104,68 +104,53 @@ def gen_loss(gen, disc, num_images, z_dim):
     return mx.mean(gen_loss)
 
+# Yield the training images in shuffled batches.
+def batch_iterate(batch_size: int, ipt: np.ndarray):
+    perm = np.random.permutation(len(ipt))
+    for s in range(0, len(ipt), batch_size):
+        ids = perm[s : s + batch_size]
+        yield ipt[ids]
+
+# Save and display a 5x5 grid of generated digits for this epoch.
+def show_images(epoch_num: int, imgs: mx.array, num_imgs: int = 25):
+    if imgs.shape[0] > 0:
+        fig, axes = plt.subplots(5, 5, figsize=(5, 5))
+
+        for i, ax in enumerate(axes.flat):
+            img = mx.array(imgs[i]).reshape(28, 28)
+            ax.imshow(img, cmap='gray')
+            ax.axis('off')
+        plt.tight_layout()
+        plt.savefig('tmp/img_{}.png'.format(epoch_num))
+        plt.show()
+
 def main(args:dict):
     seed = 42
-    criterion = nn.losses.binary_cross_entropy
-    n_epochs = 200
-    z_dim = 64
-    display_step = 500
+    n_epochs = 500
+    z_dim = 128
     batch_size = 128
-    lr = 0.00001
+    lr = 2e-5
 
     mx.random.seed(seed)
 
     # Load the data
-    train_images, train_labels, test_images, test_labels = map(
-        mx.array, getattr(mnist, args.dataset)()
-    )
+    train_images, *_ = map(np.array, getattr(mnist, 'mnist')())
+
+    # Normalize the images to [-1, 1].
+    train_images = train_images * 2.0 - 1.0
 
     gen = Generator(z_dim)
-    gen_opt = optim.Adam(learning_rate=lr)
+    mx.eval(gen.parameters())
+    gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])
+
     disc = Discriminator()
-    disc_opt = optim.Adam(learning_rate=lr)
-
-    # use partial function
-    def disc_loss(gen, disc, criterion, real, num_images, z_dim):
-        noise = get_noise(num_images, z_dim,device)
-        fake_images = gen(noise)
-
-        fake_disc = disc(fake_images.detach())
-        fake_labels = mx.zeros(fake_images.size(0),1)
-        fake_loss = criterion(fake_disc,fake_labels)
-
-        real_disc = disc(real)
-        real_labels = mx.ones(real.size(0),1)
-        real_loss = criterion(real_disc,real_labels)
-
-        disc_loss = (fake_loss + real_loss) / 2
-
-        return disc_loss
-
-    def gen_loss(gen, disc, criterion, num_images, z_dim):
-
-        noise = get_noise(num_images, z_dim,device)
-        fake_images = gen(noise)
-
-        fake_disc = disc(fake_images)
-        fake_labels = mx.ones(fake_images.size(0),1)
-
-        gen_loss = criterion(fake_disc,fake_labels)
-
-        return gen_loss
+    mx.eval(disc.parameters())
+    disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])
 
     # TODO training...
-    # Set your parameters
-    n_epochs = 500
-    display_step = 5000
-    cur_step = 0
-
-    batch_size = 128 # 128
 
     D_loss_grad = nn.value_and_grad(disc, disc_loss)
     G_loss_grad = nn.value_and_grad(gen, gen_loss)
-
     for epoch in tqdm(range(n_epochs)):
         for idx,real in enumerate(batch_iterate(batch_size, train_images)):
@@ -193,7 +178,6 @@ def main(args:dict):
         fake_noise = mx.array(get_noise(batch_size, z_dim))
         fake = gen(fake_noise)
         show_images(epoch,fake)
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("Train a simple GAN on MNIST with MLX.")
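
Note: `nn.value_and_grad(disc, disc_loss)` and `nn.value_and_grad(gen, gen_loss)` now bind to module-level loss functions defined above this hunk; the hunk header only shows that `def gen_loss(gen, disc, num_images, z_dim)` exists and ends in `return mx.mean(gen_loss)`. For reference, a minimal MLX sketch consistent with that signature and with the logic of the deleted closures might look like the following. `get_noise` and the exact argument order of `disc_loss` are assumptions, not code taken from this patch:

```python
import mlx.core as mx
import mlx.nn as nn

def get_noise(num_images, z_dim):
    # Latent vectors drawn from a standard normal distribution.
    return mx.random.normal(shape=(num_images, z_dim))

def disc_loss(gen, disc, real, num_images, z_dim):
    # Freeze the generator output for the discriminator update
    # (stop_gradient plays the role of .detach() in the deleted closure).
    fake = mx.stop_gradient(gen(get_noise(num_images, z_dim)))
    fake_logits = disc(fake)
    real_logits = disc(real)
    fake_loss = nn.losses.binary_cross_entropy(
        fake_logits, mx.zeros(fake_logits.shape), with_logits=True, reduction="none"
    )
    real_loss = nn.losses.binary_cross_entropy(
        real_logits, mx.ones(real_logits.shape), with_logits=True, reduction="none"
    )
    return mx.mean((fake_loss + real_loss) / 2)

def gen_loss(gen, disc, num_images, z_dim):
    # The generator wins when the discriminator scores fakes as real (label 1).
    fake_logits = disc(gen(get_noise(num_images, z_dim)))
    gen_loss = nn.losses.binary_cross_entropy(
        fake_logits, mx.ones(fake_logits.shape), with_logits=True, reduction="none"
    )
    return mx.mean(gen_loss)
```

The `betas=[0.5, 0.999]` passed to both Adam optimizers follows the usual DCGAN recommendation: lowering β1 from 0.9 to 0.5 damps momentum, which tends to stabilize adversarial training.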
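
The first hunk ends at the top of the double loop and the second hunk starts after it, so the actual per-batch update step is not visible in this diff. Purely as an illustration of the alternating-update pattern the surrounding calls imply (hypothetical sketch, not the patch's elided code):

```python
# Hypothetical inner-loop body, assuming batch_iterate yields NumPy batches.
real = mx.array(real)
num = real.shape[0]  # the last batch may be shorter than batch_size

# Discriminator step: gradients w.r.t. disc's trainable parameters only.
d_loss, d_grads = D_loss_grad(gen, disc, real, num, z_dim)
disc_opt.update(disc, d_grads)

# Generator step: gradients w.r.t. gen's trainable parameters only.
g_loss, g_grads = G_loss_grad(gen, disc, num, z_dim)
gen_opt.update(gen, g_grads)

# MLX is lazy: force evaluation of the updated parameters and optimizer state.
mx.eval(gen.parameters(), disc.parameters(), gen_opt.state, disc_opt.state)
```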