Fix typo in average_gradients function call (#2594)

This commit is contained in:
Umberto Mignozzetti
2025-09-15 11:29:21 -07:00
committed by GitHub
parent 6ccfa603cd
commit 8afb6d62f2

View File

@@ -184,7 +184,7 @@ almost identical to the example above:
def step(model, x, y): def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y) loss, grads = loss_grad_fn(model, x, y)
grads = mlx.nn.average_gradients(grads) # <---- This line was added grads = mx.nn.average_gradients(grads) # <---- This line was added
optimizer.update(model, grads) optimizer.update(model, grads)
return loss return loss