mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
fixed kwt skip connections
This commit is contained in:
parent
d4f7ecd851
commit
f59a36f94d
@ -40,29 +40,29 @@ python main.py --help
|
||||
|
||||
## Results
|
||||
|
||||
After training with the `kwt1` architecture for 100 epochs, you
|
||||
After training with the `kwt1` architecture for 10 epochs, you
|
||||
should see the following results:
|
||||
|
||||
```
|
||||
Epoch: 99 | avg. Train loss 0.581 | avg. Train acc 0.826 | Throughput: 677.37 samples/sec
|
||||
Epoch: 99 | Val acc 0.710
|
||||
Testing best model from Epoch 98
|
||||
Test acc -> 0.687
|
||||
Epoch: 9 | avg. Train loss 0.519 | avg. Train acc 0.857 | Throughput: 661.28 samples/sec
|
||||
Epoch: 9 | Val acc 0.861 | Throughput: 2976.54 samples/sec
|
||||
Testing best model from epoch 9
|
||||
Test acc -> 0.841
|
||||
```
|
||||
|
||||
For the `kwt2` model, you should see:
|
||||
```
|
||||
Epoch: 99 | avg. Train loss 0.137 | avg. Train acc 0.956 | Throughput: 401.47 samples/sec
|
||||
Epoch: 99 | Val acc 0.739
|
||||
Testing best model from Epoch 97
|
||||
Test acc -> 0.718
|
||||
Epoch: 9 | avg. Train loss 0.374 | avg. Train acc 0.895 | Throughput: 395.26 samples/sec
|
||||
Epoch: 9 | Val acc 0.879 | Throughput: 1542.44 samples/sec
|
||||
Testing best model from epoch 9
|
||||
Test acc -> 0.861
|
||||
```
|
||||
|
||||
Note that this was run on an M1 MacBook Pro with 16GB RAM.
|
||||
|
||||
At the time of writing, `mlx` doesn't have built-in `cosine` learning rate
|
||||
schedules, which is used along with the AdamW optimizer in the official
|
||||
implementaiton. We intend to update this example once these features are added,
|
||||
implementation. We intend to update this example once these features are added,
|
||||
as well as with appropriate data augmentations.
|
||||
|
||||
[^1]: Based on the paper [Keyword Transformer: A Self-Attention Model for Keyword Spotting](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html)
|
||||
|
@ -47,10 +47,8 @@ class Block(nn.Module):
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.attn(x)
|
||||
x = self.norm1(x)
|
||||
x = self.ff(x)
|
||||
x = self.norm2(x)
|
||||
x = self.norm1(self.attn(x)) + x
|
||||
x = self.norm2(self.ff(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
|
@ -93,7 +93,6 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
)
|
||||
)
|
||||
)
|
||||
break
|
||||
|
||||
mean_tr_loss = mx.mean(mx.array(losses))
|
||||
mean_tr_acc = mx.mean(mx.array(accs))
|
||||
@ -104,13 +103,18 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
def test_epoch(model, test_iter):
|
||||
model.train(False)
|
||||
accs = []
|
||||
throughput = []
|
||||
for batch_counter, batch in enumerate(test_iter):
|
||||
x = mx.array(batch["audio"])
|
||||
y = mx.array(batch["label"])
|
||||
tic = time.perf_counter()
|
||||
acc = eval_fn(model, x, y)
|
||||
accs.append(acc.item())
|
||||
toc = time.perf_counter()
|
||||
throughput.append(x.shape[0] / (toc - tic))
|
||||
mean_acc = mx.mean(mx.array(accs))
|
||||
return mean_acc
|
||||
mean_throughput = mx.mean(mx.array(throughput))
|
||||
return mean_acc, mean_throughput
|
||||
|
||||
|
||||
def main(args):
|
||||
@ -141,8 +145,8 @@ def main(args):
|
||||
)
|
||||
)
|
||||
|
||||
val_acc = test_epoch(model, val_data)
|
||||
print(f"Epoch: {epoch} | Val acc {val_acc.item():.3f}")
|
||||
val_acc, val_throughput = test_epoch(model, val_data)
|
||||
print(f"Epoch: {epoch} | Val acc {val_acc.item():.3f} | Throughput: {val_throughput.item():.2f} samples/sec")
|
||||
|
||||
if val_acc >= best_acc:
|
||||
best_acc = val_acc
|
||||
@ -151,7 +155,7 @@ def main(args):
|
||||
print(f"Testing best model from epoch {best_epoch}")
|
||||
model.update(best_params)
|
||||
test_data = prepare_dataset(args.batch_size, "test")
|
||||
test_acc = test_epoch(model, test_data)
|
||||
test_acc, _ = test_epoch(model, test_data)
|
||||
print(f"Test acc -> {test_acc.item():.3f}")
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user