Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
Updating GAN Code...

parent 1ef3ad2c6c
commit 4d17f80efb

gan/main.py (38 lines added)
@@ -120,6 +120,44 @@ def main(args: dict):
        return gen_loss

    # Training loop
    # Set the training parameters (display_step and cur_step are
    # defined here but not used in this hunk)
    n_epochs = 500
    display_step = 5000
    cur_step = 0
    batch_size = 128

    # Wrap each loss so one call returns both the loss value and the
    # gradients with respect to that model's trainable parameters
    D_loss_grad = nn.value_and_grad(disc, disc_loss)
    G_loss_grad = nn.value_and_grad(gen, gen_loss)

    for epoch in tqdm(range(n_epochs)):
        for idx, real in enumerate(batch_iterate(batch_size, train_images)):
            # Train the discriminator
            D_loss, D_grads = D_loss_grad(gen, disc, mx.array(real), batch_size, z_dim)

            # Apply the discriminator gradients
            disc_opt.update(disc, D_grads)

            # Force evaluation of the lazily computed parameter update
            mx.eval(disc.parameters(), disc_opt.state)

            # Train the generator
            G_loss, G_grads = G_loss_grad(gen, disc, batch_size, z_dim)

            # Apply the generator gradients
            gen_opt.update(gen, G_grads)

            # Force evaluation of the lazily computed parameter update
            mx.eval(gen.parameters(), gen_opt.state)

        if epoch % 100 == 0:
            print(f"Epoch: {epoch}, iteration: {idx}, "
                  f"Discriminator loss: {D_loss}, Generator loss: {G_loss}")
            fake_noise = mx.array(get_noise(batch_size, z_dim))
            fake = gen(fake_noise)
            show_images(epoch, fake)


if __name__ == "__main__":
File diff suppressed because one or more lines are too long
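The hunk calls disc_loss and gen_loss without showing their bodies; they are defined earlier in gan/main.py. For readability, here is a minimal sketch of loss functions with matching signatures, assuming the discriminator outputs raw logits and the standard binary cross-entropy GAN objectives. The internals are assumptions for illustration, not the commit's actual code.

import mlx.core as mx
import mlx.nn as nn

def disc_loss(gen, disc, real, num_images, z_dim):
    # Assumed objective: real images scored toward 1, fakes toward 0
    real_logits = disc(real)
    real_loss = nn.losses.binary_cross_entropy(
        real_logits, mx.ones(real_logits.shape), with_logits=True, reduction="mean")
    # stop_gradient keeps the generator frozen during the discriminator step
    fake = mx.stop_gradient(gen(mx.random.normal(shape=(num_images, z_dim))))
    fake_logits = disc(fake)
    fake_loss = nn.losses.binary_cross_entropy(
        fake_logits, mx.zeros(fake_logits.shape), with_logits=True, reduction="mean")
    return (real_loss + fake_loss) / 2

def gen_loss(gen, disc, num_images, z_dim):
    # Assumed objective: the generator wants its fakes labeled as real
    fake_logits = disc(gen(mx.random.normal(shape=(num_images, z_dim))))
    return nn.losses.binary_cross_entropy(
        fake_logits, mx.ones(fake_logits.shape), with_logits=True, reduction="mean")

Note that nn.value_and_grad(disc, disc_loss) differentiates disc_loss with respect to disc's trainable parameters only, which is why the generator step needs its own wrapped gen_loss.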
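batch_iterate, get_noise, and show_images are likewise defined elsewhere in the file. The mx.array(get_noise(...)) call above suggests get_noise returns a NumPy array; a plausible sketch of the first two helpers, modeled on the batching pattern used in other mlx-examples scripts (again assumptions, not the commit's code):

import numpy as np

def get_noise(n_samples, z_dim):
    # Standard-normal latent vectors, returned as NumPy
    # (hence the mx.array(...) conversion at the call site)
    return np.random.randn(n_samples, z_dim).astype(np.float32)

def batch_iterate(batch_size, X):
    # Yield shuffled mini-batches over the first axis of X
    perm = np.random.permutation(X.shape[0])
    for s in range(0, X.shape[0], batch_size):
        yield X[perm[s : s + batch_size]]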