mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
fixed kwt skip connections
This commit is contained in:
parent
d4f7ecd851
commit
f59a36f94d
@ -40,29 +40,29 @@ python main.py --help
|
|||||||
|
|
||||||
## Results
|
## Results
|
||||||
|
|
||||||
After training with the `kwt1` architecture for 100 epochs, you
|
After training with the `kwt1` architecture for 10 epochs, you
|
||||||
should see the following results:
|
should see the following results:
|
||||||
|
|
||||||
```
|
```
|
||||||
Epoch: 99 | avg. Train loss 0.581 | avg. Train acc 0.826 | Throughput: 677.37 samples/sec
|
Epoch: 9 | avg. Train loss 0.519 | avg. Train acc 0.857 | Throughput: 661.28 samples/sec
|
||||||
Epoch: 99 | Val acc 0.710
|
Epoch: 9 | Val acc 0.861 | Throughput: 2976.54 samples/sec
|
||||||
Testing best model from Epoch 98
|
Testing best model from epoch 9
|
||||||
Test acc -> 0.687
|
Test acc -> 0.841
|
||||||
```
|
```
|
||||||
|
|
||||||
For the `kwt2` model, you should see:
|
For the `kwt2` model, you should see:
|
||||||
```
|
```
|
||||||
Epoch: 99 | avg. Train loss 0.137 | avg. Train acc 0.956 | Throughput: 401.47 samples/sec
|
Epoch: 9 | avg. Train loss 0.374 | avg. Train acc 0.895 | Throughput: 395.26 samples/sec
|
||||||
Epoch: 99 | Val acc 0.739
|
Epoch: 9 | Val acc 0.879 | Throughput: 1542.44 samples/sec
|
||||||
Testing best model from Epoch 97
|
Testing best model from epoch 9
|
||||||
Test acc -> 0.718
|
Test acc -> 0.861
|
||||||
```
|
```
|
||||||
|
|
||||||
Note that this was run on an M1 Macbook Pro with 16GB RAM.
|
Note that this was run on an M1 Macbook Pro with 16GB RAM.
|
||||||
|
|
||||||
At the time of writing, `mlx` doesn't have built-in `cosine` learning rate
|
At the time of writing, `mlx` doesn't have built-in `cosine` learning rate
|
||||||
schedules, which is used along with the AdamW optimizer in the official
|
schedules, which is used along with the AdamW optimizer in the official
|
||||||
implementaiton. We intend to update this example once these features are added,
|
implementation. We intend to update this example once these features are added,
|
||||||
as well as with appropriate data augmentations.
|
as well as with appropriate data augmentations.
|
||||||
|
|
||||||
[^1]: Based on the paper [Keyword Transformer: A Self-Attention Model for Keyword Spotting](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html)
|
[^1]: Based on the paper [Keyword Transformer: A Self-Attention Model for Keyword Spotting](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html)
|
||||||
|
@ -47,10 +47,8 @@ class Block(nn.Module):
|
|||||||
self.norm2 = nn.LayerNorm(dim)
|
self.norm2 = nn.LayerNorm(dim)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
x = self.attn(x)
|
x = self.norm1(self.attn(x)) + x
|
||||||
x = self.norm1(x)
|
x = self.norm2(self.ff(x)) + x
|
||||||
x = self.ff(x)
|
|
||||||
x = self.norm2(x)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@ -93,7 +93,6 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
break
|
|
||||||
|
|
||||||
mean_tr_loss = mx.mean(mx.array(losses))
|
mean_tr_loss = mx.mean(mx.array(losses))
|
||||||
mean_tr_acc = mx.mean(mx.array(accs))
|
mean_tr_acc = mx.mean(mx.array(accs))
|
||||||
@ -104,13 +103,18 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
|||||||
def test_epoch(model, test_iter):
|
def test_epoch(model, test_iter):
|
||||||
model.train(False)
|
model.train(False)
|
||||||
accs = []
|
accs = []
|
||||||
|
throughput = []
|
||||||
for batch_counter, batch in enumerate(test_iter):
|
for batch_counter, batch in enumerate(test_iter):
|
||||||
x = mx.array(batch["audio"])
|
x = mx.array(batch["audio"])
|
||||||
y = mx.array(batch["label"])
|
y = mx.array(batch["label"])
|
||||||
|
tic = time.perf_counter()
|
||||||
acc = eval_fn(model, x, y)
|
acc = eval_fn(model, x, y)
|
||||||
accs.append(acc.item())
|
accs.append(acc.item())
|
||||||
|
toc = time.perf_counter()
|
||||||
|
throughput.append(x.shape[0] / (toc - tic))
|
||||||
mean_acc = mx.mean(mx.array(accs))
|
mean_acc = mx.mean(mx.array(accs))
|
||||||
return mean_acc
|
mean_throughput = mx.mean(mx.array(throughput))
|
||||||
|
return mean_acc, mean_throughput
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
@ -141,8 +145,8 @@ def main(args):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
val_acc = test_epoch(model, val_data)
|
val_acc, val_throughput = test_epoch(model, val_data)
|
||||||
print(f"Epoch: {epoch} | Val acc {val_acc.item():.3f}")
|
print(f"Epoch: {epoch} | Val acc {val_acc.item():.3f} | Throughput: {val_throughput.item():.2f} samples/sec")
|
||||||
|
|
||||||
if val_acc >= best_acc:
|
if val_acc >= best_acc:
|
||||||
best_acc = val_acc
|
best_acc = val_acc
|
||||||
@ -151,7 +155,7 @@ def main(args):
|
|||||||
print(f"Testing best model from epoch {best_epoch}")
|
print(f"Testing best model from epoch {best_epoch}")
|
||||||
model.update(best_params)
|
model.update(best_params)
|
||||||
test_data = prepare_dataset(args.batch_size, "test")
|
test_data = prepare_dataset(args.batch_size, "test")
|
||||||
test_acc = test_epoch(model, test_data)
|
test_acc, _ = test_epoch(model, test_data)
|
||||||
print(f"Test acc -> {test_acc.item():.3f}")
|
print(f"Test acc -> {test_acc.item():.3f}")
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user