Code Arrangement

This commit is contained in:
Shubbair 2024-08-01 15:29:43 +03:00
parent 7e0bdacef3
commit f84b231cf2

View File

@ -104,68 +104,53 @@ def gen_loss(gen, disc, num_images, z_dim):
return mx.mean(gen_loss) return mx.mean(gen_loss)
# make batch of images
def batch_iterate(batch_size: int, ipt: list[int])-> list[int]:
perm = np.random.permutation(len(ipt))
for s in range(0, len(ipt), batch_size):
ids = perm[s : s + batch_size]
yield ipt[ids]
def show_images(epoch_num:int,imgs:list[int],num_imgs:int = 25):
if (imgs.shape[0] > 0):
fig,axes = plt.subplots(5, 5, figsize=(5, 5))
for i, ax in enumerate(axes.flat):
img = mx.array(imgs[i]).reshape(28,28)
ax.imshow(img,cmap='gray')
ax.axis('off')
plt.tight_layout()
plt.savefig('tmp/img_{}.png'.format(epoch_num))
plt.show()
def main(args:dict): def main(args:dict):
seed = 42 seed = 42
criterion = nn.losses.binary_cross_entropy n_epochs = 500
n_epochs = 200 z_dim = 128
z_dim = 64
display_step = 500
batch_size = 128 batch_size = 128
lr = 0.00001 lr = 2e-5
mx.random.seed(seed) mx.random.seed(seed)
# Load the data # Load the data
train_images, train_labels, test_images, test_labels = map( train_images,*_ = map(np.array, getattr(mnist,'mnist')())
mx.array, getattr(mnist, args.dataset)()
) # Normalization images => [-1,1]
train_images = train_images * 2.0 - 1.0
gen = Generator(z_dim) gen = Generator(z_dim)
gen_opt = optim.Adam(learning_rate=lr) mx.eval(gen.parameters())
gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])
disc = Discriminator() disc = Discriminator()
disc_opt = optim.Adam(learning_rate=lr) mx.eval(disc.parameters())
disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])
# use partial function
def disc_loss(gen, disc, criterion, real, num_images, z_dim):
noise = get_noise(num_images, z_dim,device)
fake_images = gen(noise)
fake_disc = disc(fake_images.detach())
fake_labels = mx.zeros(fake_images.size(0),1)
fake_loss = criterion(fake_disc,fake_labels)
real_disc = disc(real)
real_labels = mx.ones(real.size(0),1)
real_loss = criterion(real_disc,real_labels)
disc_loss = (fake_loss + real_loss) / 2
return disc_loss
def gen_loss(gen, disc, criterion, num_images, z_dim):
noise = get_noise(num_images, z_dim,device)
fake_images = gen(noise)
fake_disc = disc(fake_images)
fake_labels = mx.ones(fake_images.size(0),1)
gen_loss = criterion(fake_disc,fake_labels)
return gen_loss
# TODO training... # TODO training...
# Set your parameters
n_epochs = 500
display_step = 5000
cur_step = 0
batch_size = 128 # 128
D_loss_grad = nn.value_and_grad(disc, disc_loss) D_loss_grad = nn.value_and_grad(disc, disc_loss)
G_loss_grad = nn.value_and_grad(gen, gen_loss) G_loss_grad = nn.value_and_grad(gen, gen_loss)
for epoch in tqdm(range(n_epochs)): for epoch in tqdm(range(n_epochs)):
for idx,real in enumerate(batch_iterate(batch_size, train_images)): for idx,real in enumerate(batch_iterate(batch_size, train_images)):
@ -193,7 +178,6 @@ def main(args:dict):
fake_noise = mx.array(get_noise(batch_size, z_dim)) fake_noise = mx.array(get_noise(batch_size, z_dim))
fake = gen(fake_noise) fake = gen(fake_noise)
show_images(epoch,fake) show_images(epoch,fake)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser("Train a simple GAN on MNIST with MLX.") parser = argparse.ArgumentParser("Train a simple GAN on MNIST with MLX.")