mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Code Arrangement
This commit is contained in:
parent
f84b231cf2
commit
a5752be9d9
@ -10,7 +10,6 @@ from tqdm import tqdm
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
# Generator Block
|
||||
def GenBlock(in_dim:int,out_dim:int):
|
||||
return nn.Sequential(
|
||||
@ -111,6 +110,7 @@ def batch_iterate(batch_size: int, ipt: list[int])-> list[int]:
|
||||
ids = perm[s : s + batch_size]
|
||||
yield ipt[ids]
|
||||
|
||||
# plot batch of images at epoch steps
|
||||
def show_images(epoch_num:int,imgs:list[int],num_imgs:int = 25):
|
||||
if (imgs.shape[0] > 0):
|
||||
fig,axes = plt.subplots(5, 5, figsize=(5, 5))
|
||||
@ -120,7 +120,7 @@ def show_images(epoch_num:int,imgs:list[int],num_imgs:int = 25):
|
||||
ax.imshow(img,cmap='gray')
|
||||
ax.axis('off')
|
||||
plt.tight_layout()
|
||||
plt.savefig('tmp/img_{}.png'.format(epoch_num))
|
||||
plt.savefig('gen_images/img_{}.png'.format(epoch_num))
|
||||
plt.show()
|
||||
|
||||
def main(args:dict):
|
||||
|
Loading…
Reference in New Issue
Block a user