Fix gradient accumulation averaging

This commit is contained in:
Angelos Katharopoulos 2024-10-10 02:45:26 -07:00
parent 3c587ed618
commit fc88e3b0d0

View File

@ -274,7 +274,11 @@ if __name__ == "__main__":
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
grads = tree_map(
lambda a, b: (a + b) / args.grad_accumulate,
prev_grads,
grads,
)
grads = average_gradients(grads)
optimizer.update(flux.flow, grads)