fixed kwt skip connections

This commit is contained in:
Sarthak Yadav 2023-12-19 22:41:22 +01:00
parent d4f7ecd851
commit f59a36f94d
3 changed files with 21 additions and 19 deletions

View File

@ -40,29 +40,29 @@ python main.py --help
## Results
After training with the `kwt1` architecture for 100 epochs, you
After training with the `kwt1` architecture for 10 epochs, you
should see the following results:
```
Epoch: 99 | avg. Train loss 0.581 | avg. Train acc 0.826 | Throughput: 677.37 samples/sec
Epoch: 99 | Val acc 0.710
Testing best model from Epoch 98
Test acc -> 0.687
Epoch: 9 | avg. Train loss 0.519 | avg. Train acc 0.857 | Throughput: 661.28 samples/sec
Epoch: 9 | Val acc 0.861 | Throughput: 2976.54 samples/sec
Testing best model from epoch 9
Test acc -> 0.841
```
For the `kwt2` model, you should see:
```
Epoch: 99 | avg. Train loss 0.137 | avg. Train acc 0.956 | Throughput: 401.47 samples/sec
Epoch: 99 | Val acc 0.739
Testing best model from Epoch 97
Test acc -> 0.718
Epoch: 9 | avg. Train loss 0.374 | avg. Train acc 0.895 | Throughput: 395.26 samples/sec
Epoch: 9 | Val acc 0.879 | Throughput: 1542.44 samples/sec
Testing best model from epoch 9
Test acc -> 0.861
```
Note that this was run on an M1 Macbook Pro with 16GB RAM.
At the time of writing, `mlx` doesn't have built-in `cosine` learning rate
schedules, which is used along with the AdamW optimizer in the official
implementaiton. We intend to update this example once these features are added,
implementation. We intend to update this example once these features are added,
as well as with appropriate data augmentations.
[^1]: Based on the paper [Keyword Transformer: A Self-Attention Model for Keyword Spotting](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html)

View File

@ -47,10 +47,8 @@ class Block(nn.Module):
self.norm2 = nn.LayerNorm(dim)
def __call__(self, x):
x = self.attn(x)
x = self.norm1(x)
x = self.ff(x)
x = self.norm2(x)
x = self.norm1(self.attn(x)) + x
x = self.norm2(self.ff(x)) + x
return x

View File

@ -93,7 +93,6 @@ def train_epoch(model, train_iter, optimizer, epoch):
)
)
)
break
mean_tr_loss = mx.mean(mx.array(losses))
mean_tr_acc = mx.mean(mx.array(accs))
@ -104,13 +103,18 @@ def train_epoch(model, train_iter, optimizer, epoch):
def test_epoch(model, test_iter):
model.train(False)
accs = []
throughput = []
for batch_counter, batch in enumerate(test_iter):
x = mx.array(batch["audio"])
y = mx.array(batch["label"])
tic = time.perf_counter()
acc = eval_fn(model, x, y)
accs.append(acc.item())
toc = time.perf_counter()
throughput.append(x.shape[0] / (toc - tic))
mean_acc = mx.mean(mx.array(accs))
return mean_acc
mean_throughput = mx.mean(mx.array(throughput))
return mean_acc, mean_throughput
def main(args):
@ -141,8 +145,8 @@ def main(args):
)
)
val_acc = test_epoch(model, val_data)
print(f"Epoch: {epoch} | Val acc {val_acc.item():.3f}")
val_acc, val_throughput = test_epoch(model, val_data)
print(f"Epoch: {epoch} | Val acc {val_acc.item():.3f} | Throughput: {val_throughput.item():.2f} samples/sec")
if val_acc >= best_acc:
best_acc = val_acc
@ -151,7 +155,7 @@ def main(args):
print(f"Testing best model from epoch {best_epoch}")
model.update(best_params)
test_data = prepare_dataset(args.batch_size, "test")
test_acc = test_epoch(model, test_data)
test_acc, _ = test_epoch(model, test_data)
print(f"Test acc -> {test_acc.item():.3f}")