diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 4a9de6b1..5af20cd8 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -274,7 +274,11 @@ if __name__ == "__main__": loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( x, t5_feat, clip_feat, guidance ) - grads = tree_map(lambda a, b: a + b, prev_grads, grads) + grads = tree_map( + lambda a, b: (a + b) / args.grad_accumulate, + prev_grads, + grads, + ) grads = average_gradients(grads) optimizer.update(flux.flow, grads)