mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-27 00:08:09 +08:00
Fix typo in average_gradients function call (#2594)
This commit is contained in:

committed by
GitHub

parent
6ccfa603cd
commit
8afb6d62f2
@@ -184,7 +184,7 @@ almost identical to the example above:
|
|||||||
|
|
||||||
def step(model, x, y):
|
def step(model, x, y):
|
||||||
loss, grads = loss_grad_fn(model, x, y)
|
loss, grads = loss_grad_fn(model, x, y)
|
||||||
grads = mlx.nn.average_gradients(grads) # <---- This line was added
|
grads = mx.nn.average_gradients(grads) # <---- This line was added
|
||||||
optimizer.update(model, grads)
|
optimizer.update(model, grads)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user