diff --git a/speechcommands/README.md b/speechcommands/README.md index 9e482b94..e8b2243a 100644 --- a/speechcommands/README.md +++ b/speechcommands/README.md @@ -40,29 +40,29 @@ python main.py --help ## Results -After training with the `kwt1` architecture for 100 epochs, you +After training with the `kwt1` architecture for 10 epochs, you should see the following results: ``` -Epoch: 99 | avg. Train loss 0.581 | avg. Train acc 0.826 | Throughput: 677.37 samples/sec -Epoch: 99 | Val acc 0.710 -Testing best model from Epoch 98 -Test acc -> 0.687 +Epoch: 9 | avg. Train loss 0.519 | avg. Train acc 0.857 | Throughput: 661.28 samples/sec +Epoch: 9 | Val acc 0.861 | Throughput: 2976.54 samples/sec +Testing best model from epoch 9 +Test acc -> 0.841 ``` For the `kwt2` model, you should see: ``` -Epoch: 99 | avg. Train loss 0.137 | avg. Train acc 0.956 | Throughput: 401.47 samples/sec -Epoch: 99 | Val acc 0.739 -Testing best model from Epoch 97 -Test acc -> 0.718 +Epoch: 9 | avg. Train loss 0.374 | avg. Train acc 0.895 | Throughput: 395.26 samples/sec +Epoch: 9 | Val acc 0.879 | Throughput: 1542.44 samples/sec +Testing best model from epoch 9 +Test acc -> 0.861 ``` Note that this was run on an M1 Macbook Pro with 16GB RAM. At the time of writing, `mlx` doesn't have built-in `cosine` learning rate schedules, which is used along with the AdamW optimizer in the official -implementaiton. We intend to update this example once these features are added, +implementation. We intend to update this example once these features are added, as well as with appropriate data augmentations. [^1]: Based one the paper [Keyword Transformer: A Self-Attention Model for Keyword Spotting](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html) diff --git a/speechcommands/kwt.py b/speechcommands/kwt.py index f68fd632..63d4e074 100644 --- a/speechcommands/kwt.py +++ b/speechcommands/kwt.py @@ -47,10 +47,8 @@ class Block(nn.Module): self.norm2 = nn.LayerNorm(dim) def __call__(self, x): - x = self.attn(x) - x = self.norm1(x) - x = self.ff(x) - x = self.norm2(x) + x = self.norm1(self.attn(x)) + x + x = self.norm2(self.ff(x)) + x return x diff --git a/speechcommands/main.py b/speechcommands/main.py index a02cb089..3890baa4 100644 --- a/speechcommands/main.py +++ b/speechcommands/main.py @@ -93,7 +93,6 @@ def train_epoch(model, train_iter, optimizer, epoch): ) ) ) - break mean_tr_loss = mx.mean(mx.array(losses)) mean_tr_acc = mx.mean(mx.array(accs)) @@ -104,13 +103,18 @@ def train_epoch(model, train_iter, optimizer, epoch): def test_epoch(model, test_iter): model.train(False) accs = [] + throughput = [] for batch_counter, batch in enumerate(test_iter): x = mx.array(batch["audio"]) y = mx.array(batch["label"]) + tic = time.perf_counter() acc = eval_fn(model, x, y) accs.append(acc.item()) + toc = time.perf_counter() + throughput.append(x.shape[0] / (toc - tic)) mean_acc = mx.mean(mx.array(accs)) - return mean_acc + mean_throughput = mx.mean(mx.array(throughput)) + return mean_acc, mean_throughput def main(args): @@ -141,8 +145,8 @@ def main(args): ) ) - val_acc = test_epoch(model, val_data) - print(f"Epoch: {epoch} | Val acc {val_acc.item():.3f}") + val_acc, val_throughput = test_epoch(model, val_data) + print(f"Epoch: {epoch} | Val acc {val_acc.item():.3f} | Throughput: {val_throughput.item():.2f} samples/sec") if val_acc >= best_acc: best_acc = val_acc @@ -151,7 +155,7 @@ def main(args): print(f"Testing best model from epoch {best_epoch}") model.update(best_params) test_data = prepare_dataset(args.batch_size, "test") - test_acc = test_epoch(model, test_data) + test_acc, _ = test_epoch(model, test_data) print(f"Test acc -> {test_acc.item():.3f}")