mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00

Code Arrangement

This commit is contained in:
parent 7e0bdacef3
commit f84b231cf2

gan/main.py (78 lines changed)
@@ -104,68 +104,53 @@ def gen_loss(gen, disc, num_images, z_dim):
     return mx.mean(gen_loss)
 
+
+# make batch of images
+def batch_iterate(batch_size: int, ipt: list[int])-> list[int]:
+    perm = np.random.permutation(len(ipt))
+    for s in range(0, len(ipt), batch_size):
+        ids = perm[s : s + batch_size]
+        yield ipt[ids]
+
+
+def show_images(epoch_num:int,imgs:list[int],num_imgs:int = 25):
+    if (imgs.shape[0] > 0):
+        fig,axes = plt.subplots(5, 5, figsize=(5, 5))
+        for i, ax in enumerate(axes.flat):
+            img = mx.array(imgs[i]).reshape(28,28)
+            ax.imshow(img,cmap='gray')
+            ax.axis('off')
+        plt.tight_layout()
+        plt.savefig('tmp/img_{}.png'.format(epoch_num))
+        plt.show()
+
+
 def main(args:dict):
     seed = 42
-    criterion = nn.losses.binary_cross_entropy
-    n_epochs = 200
-    z_dim = 64
-    display_step = 500
+    n_epochs = 500
+    z_dim = 128
     batch_size = 128
-    lr = 0.00001
+    lr = 2e-5
 
     mx.random.seed(seed)
 
     # Load the data
-    train_images, train_labels, test_images, test_labels = map(
-        mx.array, getattr(mnist, args.dataset)()
-    )
+    train_images,*_ = map(np.array, getattr(mnist,'mnist')())
+    # Normalization images => [-1,1]
+    train_images = train_images * 2.0 - 1.0
 
     gen = Generator(z_dim)
-    gen_opt = optim.Adam(learning_rate=lr)
+    mx.eval(gen.parameters())
+    gen_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])
 
     disc = Discriminator()
-    disc_opt = optim.Adam(learning_rate=lr)
-    # use partial function
-    def disc_loss(gen, disc, criterion, real, num_images, z_dim):
-        noise = get_noise(num_images, z_dim,device)
-        fake_images = gen(noise)
-
-        fake_disc = disc(fake_images.detach())
-        fake_labels = mx.zeros(fake_images.size(0),1)
-        fake_loss = criterion(fake_disc,fake_labels)
-
-        real_disc = disc(real)
-        real_labels = mx.ones(real.size(0),1)
-        real_loss = criterion(real_disc,real_labels)
-
-        disc_loss = (fake_loss + real_loss) / 2
-
-        return disc_loss
-
-    def gen_loss(gen, disc, criterion, num_images, z_dim):
-        noise = get_noise(num_images, z_dim,device)
-        fake_images = gen(noise)
-
-        fake_disc = disc(fake_images)
-        fake_labels = mx.ones(fake_images.size(0),1)
-
-        gen_loss = criterion(fake_disc,fake_labels)
-
-        return gen_loss
-
+    mx.eval(disc.parameters())
+    disc_opt = optim.Adam(learning_rate=lr, betas=[0.5, 0.999])
+
     # TODO training...
-    # Set your parameters
-    n_epochs = 500
-    display_step = 5000
-    cur_step = 0
-
-    batch_size = 128 # 128
-
     D_loss_grad = nn.value_and_grad(disc, disc_loss)
     G_loss_grad = nn.value_and_grad(gen, gen_loss)
 
     for epoch in tqdm(range(n_epochs)):
         for idx,real in enumerate(batch_iterate(batch_size, train_images)):
@@ -193,7 +178,6 @@ def main(args:dict):
         fake_noise = mx.array(get_noise(batch_size, z_dim))
         fake = gen(fake_noise)
         show_images(epoch,fake)
-
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("Train a simple GAN on MNIST with MLX.")
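The module-level disc_loss and gen_loss that nn.value_and_grad binds to are not shown in this diff; only the gen_loss(gen, disc, num_images, z_dim) signature and its closing return mx.mean(gen_loss) appear as hunk context. A minimal sketch of what those losses could look like in MLX, assuming a get_noise helper built on mx.random.normal, a discriminator that returns one logit per image, and nn.losses.binary_cross_entropy as the criterion (the disc_loss signature below is a guess):

import mlx.core as mx
import mlx.nn as nn

def get_noise(num_images, z_dim):
    # hypothetical stand-in for the helper defined elsewhere in gan/main.py
    return mx.random.normal(shape=(num_images, z_dim))

def disc_loss(gen, disc, real, num_images, z_dim):
    # real images are labeled 1, generated images 0; disc is assumed to output logits
    # num_images is assumed to equal the real batch size
    fake = gen(get_noise(num_images, z_dim))
    fake_loss = nn.losses.binary_cross_entropy(disc(fake), mx.zeros(num_images))
    real_loss = nn.losses.binary_cross_entropy(disc(real), mx.ones(real.shape[0]))
    return mx.mean((fake_loss + real_loss) / 2)

def gen_loss(gen, disc, num_images, z_dim):
    # the generator wants its samples classified as real (label 1)
    fake = gen(get_noise(num_images, z_dim))
    gen_loss = nn.losses.binary_cross_entropy(disc(fake), mx.ones(num_images))
    return mx.mean(gen_loss)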
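The per-batch update that consumes D_loss_grad and G_loss_grad also falls outside the hunk. A sketch of how such a step is commonly wired in MLX, reusing the names from the diff and assuming the loss signatures sketched above:

for epoch in tqdm(range(n_epochs)):
    for real in batch_iterate(batch_size, train_images):
        real = mx.array(real)
        n = real.shape[0]

        # discriminator step: gradients are taken w.r.t. disc's parameters
        d_loss, d_grads = D_loss_grad(gen, disc, real, n, z_dim)
        disc_opt.update(disc, d_grads)

        # generator step: gradients are taken w.r.t. gen's parameters
        g_loss, g_grads = G_loss_grad(gen, disc, n, z_dim)
        gen_opt.update(gen, g_grads)

        # force evaluation of the lazy parameter and optimizer-state updates
        mx.eval(gen.parameters(), disc.parameters(),
                gen_opt.state, disc_opt.state)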
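The added mx.eval(gen.parameters()) and mx.eval(disc.parameters()) calls follow the usual MLX pattern: arrays are computed lazily, so evaluating the parameters once right after construction materializes the initial weights instead of deferring that work to the first training step. For example:

gen = Generator(z_dim)        # weights exist only as a lazy computation here
mx.eval(gen.parameters())     # force the initial weights to be computed now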