fixed kwt skip connections

This commit is contained in:
Sarthak Yadav 2023-12-19 22:41:22 +01:00
parent d4f7ecd851
commit f59a36f94d
3 changed files with 21 additions and 19 deletions

View File

@@ -40,29 +40,29 @@ python main.py --help
## Results ## Results
After training with the `kwt1` architecture for 100 epochs, you After training with the `kwt1` architecture for 10 epochs, you
should see the following results: should see the following results:
``` ```
Epoch: 99 | avg. Train loss 0.581 | avg. Train acc 0.826 | Throughput: 677.37 samples/sec Epoch: 9 | avg. Train loss 0.519 | avg. Train acc 0.857 | Throughput: 661.28 samples/sec
Epoch: 99 | Val acc 0.710 Epoch: 9 | Val acc 0.861 | Throughput: 2976.54 samples/sec
Testing best model from Epoch 98 Testing best model from epoch 9
Test acc -> 0.687 Test acc -> 0.841
``` ```
For the `kwt2` model, you should see: For the `kwt2` model, you should see:
``` ```
Epoch: 99 | avg. Train loss 0.137 | avg. Train acc 0.956 | Throughput: 401.47 samples/sec Epoch: 9 | avg. Train loss 0.374 | avg. Train acc 0.895 | Throughput: 395.26 samples/sec
Epoch: 99 | Val acc 0.739 Epoch: 9 | Val acc 0.879 | Throughput: 1542.44 samples/sec
Testing best model from Epoch 97 Testing best model from epoch 9
Test acc -> 0.718 Test acc -> 0.861
``` ```
Note that this was run on an M1 Macbook Pro with 16GB RAM. Note that this was run on an M1 Macbook Pro with 16GB RAM.
At the time of writing, `mlx` doesn't have built-in `cosine` learning rate At the time of writing, `mlx` doesn't have built-in `cosine` learning rate
schedules, which is used along with the AdamW optimizer in the official schedules, which is used along with the AdamW optimizer in the official
implementaiton. We intend to update this example once these features are added, implementation. We intend to update this example once these features are added,
as well as with appropriate data augmentations. as well as with appropriate data augmentations.
[^1]: Based on the paper [Keyword Transformer: A Self-Attention Model for Keyword Spotting](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html) [^1]: Based on the paper [Keyword Transformer: A Self-Attention Model for Keyword Spotting](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html)

View File

@@ -47,10 +47,8 @@ class Block(nn.Module):
self.norm2 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim)
def __call__(self, x): def __call__(self, x):
x = self.attn(x) x = self.norm1(self.attn(x)) + x
x = self.norm1(x) x = self.norm2(self.ff(x)) + x
x = self.ff(x)
x = self.norm2(x)
return x return x

View File

@@ -93,7 +93,6 @@ def train_epoch(model, train_iter, optimizer, epoch):
) )
) )
) )
break
mean_tr_loss = mx.mean(mx.array(losses)) mean_tr_loss = mx.mean(mx.array(losses))
mean_tr_acc = mx.mean(mx.array(accs)) mean_tr_acc = mx.mean(mx.array(accs))
@@ -104,13 +103,18 @@ def train_epoch(model, train_iter, optimizer, epoch):
def test_epoch(model, test_iter): def test_epoch(model, test_iter):
model.train(False) model.train(False)
accs = [] accs = []
throughput = []
for batch_counter, batch in enumerate(test_iter): for batch_counter, batch in enumerate(test_iter):
x = mx.array(batch["audio"]) x = mx.array(batch["audio"])
y = mx.array(batch["label"]) y = mx.array(batch["label"])
tic = time.perf_counter()
acc = eval_fn(model, x, y) acc = eval_fn(model, x, y)
accs.append(acc.item()) accs.append(acc.item())
toc = time.perf_counter()
throughput.append(x.shape[0] / (toc - tic))
mean_acc = mx.mean(mx.array(accs)) mean_acc = mx.mean(mx.array(accs))
return mean_acc mean_throughput = mx.mean(mx.array(throughput))
return mean_acc, mean_throughput
def main(args): def main(args):
@@ -141,8 +145,8 @@ def main(args):
) )
) )
val_acc = test_epoch(model, val_data) val_acc, val_throughput = test_epoch(model, val_data)
print(f"Epoch: {epoch} | Val acc {val_acc.item():.3f}") print(f"Epoch: {epoch} | Val acc {val_acc.item():.3f} | Throughput: {val_throughput.item():.2f} samples/sec")
if val_acc >= best_acc: if val_acc >= best_acc:
best_acc = val_acc best_acc = val_acc
@@ -151,7 +155,7 @@ def main(args):
print(f"Testing best model from epoch {best_epoch}") print(f"Testing best model from epoch {best_epoch}")
model.update(best_params) model.update(best_params)
test_data = prepare_dataset(args.batch_size, "test") test_data = prepare_dataset(args.batch_size, "test")
test_acc = test_epoch(model, test_data) test_acc, _ = test_epoch(model, test_data)
print(f"Test acc -> {test_acc.item():.3f}") print(f"Test acc -> {test_acc.item():.3f}")