mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-26 15:58:14 +08:00
Fix typo in average_gradients function call (#2594)
This commit is contained in:

committed by
GitHub

parent
6ccfa603cd
commit
8afb6d62f2
@@ -184,7 +184,7 @@ almost identical to the example above:
|
||||
|
||||
def step(model, x, y):
|
||||
loss, grads = loss_grad_fn(model, x, y)
|
||||
grads = mlx.nn.average_gradients(grads) # <---- This line was added
|
||||
grads = mx.nn.average_gradients(grads) # <---- This line was added
|
||||
optimizer.update(model, grads)
|
||||
return loss
|
||||
|
||||
|
Reference in New Issue
Block a user