mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
Fix gradient accumulation averaging
This commit is contained in:
parent
3c587ed618
commit
fc88e3b0d0
@@ -274,7 +274,11 @@ if __name__ == "__main__":
 loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
     x, t5_feat, clip_feat, guidance
 )
-grads = tree_map(lambda a, b: a + b, prev_grads, grads)
+grads = tree_map(
+    lambda a, b: (a + b) / args.grad_accumulate,
+    prev_grads,
+    grads,
+)
 grads = average_gradients(grads)
 optimizer.update(flux.flow, grads)
 
Loading…
Reference in New Issue
Block a user