mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Fix gradient accumulation averaging
This commit is contained in:
@@ -274,7 +274,11 @@ if __name__ == "__main__":
 loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
     x, t5_feat, clip_feat, guidance
 )
-grads = tree_map(lambda a, b: a + b, prev_grads, grads)
+grads = tree_map(
+    lambda a, b: (a + b) / args.grad_accumulate,
+    prev_grads,
+    grads,
+)
 grads = average_gradients(grads)
 optimizer.update(flux.flow, grads)
 
Reference in New Issue
Block a user