typo / nits

This commit is contained in:
Awni Hannun 2023-12-14 12:14:01 -08:00
parent b1b9b11801
commit b9439ce74e
2 changed files with 2 additions and 2 deletions

View File

@ -47,5 +47,5 @@ Epoch: 99 | Test acc 0.807
Note this was run on an M1 Macbook Pro with 16GB RAM. Note this was run on an M1 Macbook Pro with 16GB RAM.
At the time of writing, `mlx` doesn't have built-in learning rate schedules, At the time of writing, `mlx` doesn't have built-in learning rate schedules,
nor a `BatchNorm` layer. We intend to update this example once these features or a `BatchNorm` layer. We intend to update this example once these features
are added. are added.

View File

@ -65,7 +65,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
) )
) )
eean_tr_loss = mx.mean(mx.array(losses)) mean_tr_loss = mx.mean(mx.array(losses))
mean_tr_acc = mx.mean(mx.array(accs)) mean_tr_acc = mx.mean(mx.array(accs))
samples_per_sec = mx.mean(mx.array(samples_per_sec)) samples_per_sec = mx.mean(mx.array(samples_per_sec))
return mean_tr_loss, mean_tr_acc, samples_per_sec return mean_tr_loss, mean_tr_acc, samples_per_sec