mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 18:11:17 +08:00
Code Arrangement
This commit is contained in:
parent
f84b231cf2
commit
a5752be9d9
@ -10,7 +10,6 @@ from tqdm import tqdm
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
# Generator Block
|
# Generator Block
|
||||||
def GenBlock(in_dim:int,out_dim:int):
|
def GenBlock(in_dim:int,out_dim:int):
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
@ -111,6 +110,7 @@ def batch_iterate(batch_size: int, ipt: list[int])-> list[int]:
|
|||||||
ids = perm[s : s + batch_size]
|
ids = perm[s : s + batch_size]
|
||||||
yield ipt[ids]
|
yield ipt[ids]
|
||||||
|
|
||||||
|
# plot batch of images at epoch steps
|
||||||
def show_images(epoch_num:int,imgs:list[int],num_imgs:int = 25):
|
def show_images(epoch_num:int,imgs:list[int],num_imgs:int = 25):
|
||||||
if (imgs.shape[0] > 0):
|
if (imgs.shape[0] > 0):
|
||||||
fig,axes = plt.subplots(5, 5, figsize=(5, 5))
|
fig,axes = plt.subplots(5, 5, figsize=(5, 5))
|
||||||
@ -120,7 +120,7 @@ def show_images(epoch_num:int,imgs:list[int],num_imgs:int = 25):
|
|||||||
ax.imshow(img,cmap='gray')
|
ax.imshow(img,cmap='gray')
|
||||||
ax.axis('off')
|
ax.axis('off')
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig('tmp/img_{}.png'.format(epoch_num))
|
plt.savefig('gen_images/img_{}.png'.format(epoch_num))
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
def main(args:dict):
|
def main(args:dict):
|
||||||
|
Loading…
Reference in New Issue
Block a user