diff --git a/gan/main.py b/gan/main.py index 2ee2b38a..5aea6254 100644 --- a/gan/main.py +++ b/gan/main.py @@ -10,7 +10,6 @@ from tqdm import tqdm import numpy as np import matplotlib.pyplot as plt - # Generator Block def GenBlock(in_dim:int,out_dim:int): return nn.Sequential( @@ -111,6 +110,7 @@ def batch_iterate(batch_size: int, ipt: list[int])-> list[int]: ids = perm[s : s + batch_size] yield ipt[ids] +# plot batch of images at epoch steps def show_images(epoch_num:int,imgs:list[int],num_imgs:int = 25): if (imgs.shape[0] > 0): fig,axes = plt.subplots(5, 5, figsize=(5, 5)) @@ -120,7 +120,7 @@ def show_images(epoch_num:int,imgs:list[int],num_imgs:int = 25): ax.imshow(img,cmap='gray') ax.axis('off') plt.tight_layout() - plt.savefig('tmp/img_{}.png'.format(epoch_num)) + plt.savefig('gen_images/img_{}.png'.format(epoch_num)) plt.show() def main(args:dict):