Updating GAN Code...

This commit is contained in:
Shubbair 2024-07-31 20:23:57 +03:00
parent 1ef3ad2c6c
commit 4d17f80efb
2 changed files with 69 additions and 48 deletions

View File

@ -120,6 +120,44 @@ def main(args:dict):
return gen_loss return gen_loss
# TODO training... # TODO training...
# Set your parameters
n_epochs = 500
display_step = 5000
cur_step = 0
batch_size = 128 # 128
D_loss_grad = nn.value_and_grad(disc, disc_loss)
G_loss_grad = nn.value_and_grad(gen, gen_loss)
for epoch in tqdm(range(n_epochs)):
for idx,real in enumerate(batch_iterate(batch_size, train_images)):
# TODO Train Discriminator
D_loss,D_grads = D_loss_grad(gen, disc,mx.array(real), batch_size, z_dim)
# Update optimizer
disc_opt.update(disc, D_grads)
# Update gradients
mx.eval(disc.parameters(), disc_opt.state)
# TODO Train Generator
G_loss,G_grads = G_loss_grad(gen, disc, batch_size, z_dim)
# Update optimizer
gen_opt.update(gen, G_grads)
# Update gradients
mx.eval(gen.parameters(), gen_opt.state)
if epoch%100==0:
print("Epoch: {}, iteration: {}, Discriminator Loss:{}, Generator Loss: {}".format(epoch,idx,D_loss,G_loss))
fake_noise = mx.array(get_noise(batch_size, z_dim))
fake = gen(fake_noise)
show_images(epoch,fake)
if __name__ == "__main__": if __name__ == "__main__":

File diff suppressed because one or more lines are too long