Added Keyword Spotting Transformer + SpeechCommands example (#123)

* Added Keyword Transformer + SpeechCommands

* minor fixes in README

* some updates / simplifications

* nits

* fixed kwt skip connections

* readme + format

* updated acknowledgements

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Sarthak Yadav 2023-12-19 23:17:48 +01:00 committed by GitHub
parent ebbb7083cc
commit b6e62caf2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 453 additions and 0 deletions

View File

@ -8,3 +8,4 @@ with a short description of your contribution(s) below. For example:
MLX Examples was developed with contributions from the following individuals: MLX Examples was developed with contributions from the following individuals:
- Juarez Bochi: Added support for T5 models. - Juarez Bochi: Added support for T5 models.
- Sarthak Yadav: Added the `cifar` and `speechcommands` examples.

69
speechcommands/README.md Normal file
View File

@ -0,0 +1,69 @@
# Train a Keyword Spotting Transformer on Speech Commands
An example of training a Keyword Spotting Transformer[^1] on the Speech
Commands dataset[^2] with MLX. All supervised-only configurations from the
paper are available. The example also illustrates how to use [MLX
Data](https://github.com/ml-explore/mlx-data) to load and process an audio
dataset.
## Pre-requisites
Follow the [installation
instructions](https://ml-explore.github.io/mlx-data/build/html/install.html)
for MLX Data.
Install the remaining python requirements:
```
pip install -r requirements.txt
```
## Running the example
Run the example with:
```
python main.py
```
By default the example runs on the GPU. To run it on the CPU, use:
```
python main.py --cpu
```
For all available options, run:
```
python main.py --help
```
## Results
After training with the `kwt1` architecture for 10 epochs, you
should see the following results:
```
Epoch: 9 | avg. Train loss 0.519 | avg. Train acc 0.857 | Throughput: 661.28 samples/sec
Epoch: 9 | Val acc 0.861 | Throughput: 2976.54 samples/sec
Testing best model from epoch 9
Test acc -> 0.841
```
For the `kwt2` model, you should see:
```
Epoch: 9 | avg. Train loss 0.374 | avg. Train acc 0.895 | Throughput: 395.26 samples/sec
Epoch: 9 | Val acc 0.879 | Throughput: 1542.44 samples/sec
Testing best model from epoch 9
Test acc -> 0.861
```
Note that this was run on an M1 MacBook Pro with 16GB RAM.
At the time of writing, `mlx` doesn't have built-in `cosine` learning rate
schedules, which are used along with the AdamW optimizer in the official
implementation. We intend to update this example once these features are added,
as well as with appropriate data augmentations.
[^1]: Based on the paper [Keyword Transformer: A Self-Attention Model for Keyword Spotting](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html)
[^2]: We use version 0.02. See the [paper](https://arxiv.org/abs/1804.03209) for more details.

214
speechcommands/kwt.py Normal file
View File

@ -0,0 +1,214 @@
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten
__all__ = ["KWT", "kwt1", "kwt2", "kwt3"]
class FeedForward(nn.Sequential):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Parameters
    ----------
    dim: int
        Input and output feature dimension
    hidden_dim: int
        Width of the intermediate layer
    dropout: float, optional
        Dropout rate applied after the activation and after the projection
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        # Layers are created in the same order in which they are applied.
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        super().__init__(*layers)
class Attention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Parameters
    ----------
    dim: int
        Embedding dimension of the input and output
    heads: int
        Number of attention heads
    dropout: float, optional
        Dropout rate applied after the output projection
    """

    def __init__(self, dim, heads, dropout=0.0):
        super().__init__()
        self.heads = heads
        # NOTE(review): the scale is derived from the full embedding size,
        # not from the per-head size -- confirm this is intended.
        self.scale = dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout))

    def __call__(self, x):
        # The input must be rank 3; the last axis is the embedding.
        batch, seq_len, _ = x.shape
        num_heads = self.heads

        # Project once, then split into (3, batch, heads, seq_len, head_dim).
        fused = self.qkv(x).reshape(batch, seq_len, 3, num_heads, -1)
        q, k, v = fused.transpose(2, 0, 3, 1, 4)

        # Scaled dot-product scores, normalised over the key axis.
        scores = (q @ k.transpose(0, 1, 3, 2)) * self.scale
        weights = mx.softmax(scores, axis=-1)

        # Merge the heads back into a single embedding axis.
        merged = (weights @ v).transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.out(merged)
class Block(nn.Module):
    """Transformer block with layer normalization applied after each sub-layer.

    Each sub-layer output is normalized and then added to the sub-layer
    input (residual connection).

    Parameters
    ----------
    dim: int
        Embedding dimension
    heads: int
        Number of attention heads
    mlp_dim: int
        Hidden dimension of the feedforward sub-layer
    dropout: float, optional
        Dropout rate passed to both sub-layers
    """

    def __init__(self, dim, heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.attn = Attention(dim, heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)

    def __call__(self, x):
        # Attention sub-layer: normalize the branch, then add the skip path.
        attn_branch = self.norm1(self.attn(x))
        hidden = attn_branch + x

        # Feedforward sub-layer, same post-norm residual pattern.
        ff_branch = self.norm2(self.ff(hidden))
        return ff_branch + hidden
class Transformer(nn.Module):
    """A stack of `depth` identical `Block` layers applied in sequence.

    Parameters
    ----------
    dim: int
        Embedding dimension
    depth: int
        Number of blocks
    heads: int
        Number of attention heads per block
    mlp_dim: int
        Feedforward hidden dimension per block
    dropout: float, optional
        Dropout rate passed to every block
    """

    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.layers = [
            Block(dim, heads, mlp_dim, dropout=dropout) for _ in range(depth)
        ]

    def __call__(self, x):
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden
class KWT(nn.Module):
    """
    Implements the Keyword Transformer (KWT) [1] model.

    KWT is essentially a vision transformer [2] with minor modifications:
    - Instead of square patches, KWT uses rectangular patches -> a patch
      across frequency for every timestep
    - KWT modules apply layer normalization after attention/feedforward layers

    [1] https://arxiv.org/abs/2104.00769
    [2] https://arxiv.org/abs/2010.11929

    Parameters
    ----------
    input_res: tuple of ints
        Input resolution (time, frequency)
    patch_res: tuple of ints
        Patch resolution (time, frequency)
    num_classes: int
        Number of classes
    dim: int
        Model Embedding dimension
    depth: int
        Number of transformer layers
    heads: int
        Number of attention heads
    mlp_dim: int
        Feedforward hidden dimension
    pool: str
        Pooling type, either "cls" or "mean". Any value other than "mean"
        selects the class token.
    in_channels: int, optional
        Number of input channels
    dropout: float, optional
        Dropout rate
    emb_dropout: float, optional
        Embedding dropout rate
    """

    def __init__(
        self,
        input_res,
        patch_res,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        pool="mean",
        in_channels=1,
        dropout=0.0,
        emb_dropout=0.0,
    ):
        super().__init__()
        # Count whole patches per axis. Flooring each axis separately agrees
        # with the output size of a convolution whose stride equals its kernel
        # size (assuming no padding), including when the input is not an
        # exact multiple of the patch size.
        self.num_patches = (input_res[0] // patch_res[0]) * (
            input_res[1] // patch_res[1]
        )
        self.dim = dim

        # Non-overlapping patches are embedded with a strided convolution.
        self.patch_embedding = nn.Conv2d(
            in_channels, dim, kernel_size=patch_res, stride=patch_res
        )
        # One position per patch plus one for the class token.
        self.pos_embedding = mx.random.truncated_normal(
            -0.01,
            0.01,
            (self.num_patches + 1, dim),
        )
        self.cls_token = mx.random.truncated_normal(-0.01, 0.01, (dim,))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
        self.pool = pool
        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def num_params(self):
        """Return the total number of scalar entries across all parameters."""
        return sum(x.size for _, x in tree_flatten(self.parameters()))

    def __call__(self, x):
        """Compute class logits for a batch of inputs.

        Inputs that are not rank 4 get a trailing axis appended before the
        patch convolution (presumably the channel axis -- confirm against
        the caller).

        Raises
        ------
        ValueError
            If the number of patches produced by the convolution does not
            match the number the positional embedding was built for.
        """
        if x.ndim != 4:
            x = mx.expand_dims(x, axis=-1)
        x = self.patch_embedding(x)
        # Flatten the patch grid into a sequence of embeddings.
        x = x.reshape(x.shape[0], -1, self.dim)
        if x.shape[1] != self.num_patches:
            raise ValueError(
                f"Expected {self.num_patches} patches but got {x.shape[1]}; "
                "check that the input matches `input_res`."
            )

        # Prepend the class token to every sequence in the batch.
        cls_tokens = mx.broadcast_to(self.cls_token, (x.shape[0], 1, self.dim))
        x = mx.concatenate((cls_tokens, x), axis=1)

        x = x + self.pos_embedding
        x = self.dropout(x)
        x = self.transformer(x)

        # Pool either over all tokens or by taking the class token.
        x = x.mean(axis=1) if self.pool == "mean" else x[:, 0]
        x = self.mlp_head(x)
        return x
def parse_kwt_args(**kwargs):
    """Split the shared KWT options from the remaining keyword arguments.

    Returns a 5-tuple ``(input_res, patch_res, num_classes, emb_dropout,
    remaining)`` where each of the first four falls back to its default when
    not supplied and ``remaining`` holds every other keyword argument.
    """
    # Built inside the function so the list defaults are fresh on every call.
    defaults = (
        ("input_res", [98, 40]),
        ("patch_res", [1, 40]),
        ("num_classes", 35),
        ("emb_dropout", 0.1),
    )
    shared = [kwargs.pop(key, fallback) for key, fallback in defaults]
    return (*shared, kwargs)
def kwt1(**kwargs):
    """Build a KWT with dim=64, depth=12, heads=1 and mlp_dim=256.

    Shared options (``input_res``, ``patch_res``, ``num_classes``,
    ``emb_dropout``) are extracted by ``parse_kwt_args``; any other keyword
    argument is forwarded to ``KWT`` unchanged.
    """
    input_res, patch_res, num_classes, emb_dropout, extra = parse_kwt_args(**kwargs)
    arch = {"dim": 64, "depth": 12, "heads": 1, "mlp_dim": 256}
    return KWT(
        input_res,
        patch_res,
        num_classes,
        emb_dropout=emb_dropout,
        **arch,
        **extra,
    )
def kwt2(**kwargs):
    """Build a KWT with dim=128, depth=12, heads=2 and mlp_dim=512.

    Shared options (``input_res``, ``patch_res``, ``num_classes``,
    ``emb_dropout``) are extracted by ``parse_kwt_args``; any other keyword
    argument is forwarded to ``KWT`` unchanged.
    """
    input_res, patch_res, num_classes, emb_dropout, extra = parse_kwt_args(**kwargs)
    arch = {"dim": 128, "depth": 12, "heads": 2, "mlp_dim": 512}
    return KWT(
        input_res,
        patch_res,
        num_classes,
        emb_dropout=emb_dropout,
        **arch,
        **extra,
    )
def kwt3(**kwargs):
    """Build a KWT with dim=192, depth=12, heads=3 and mlp_dim=768.

    Shared options (``input_res``, ``patch_res``, ``num_classes``,
    ``emb_dropout``) are extracted by ``parse_kwt_args``; any other keyword
    argument is forwarded to ``KWT`` unchanged.
    """
    input_res, patch_res, num_classes, emb_dropout, extra = parse_kwt_args(**kwargs)
    arch = {"dim": 192, "depth": 12, "heads": 3, "mlp_dim": 768}
    return KWT(
        input_res,
        patch_res,
        num_classes,
        emb_dropout=emb_dropout,
        **arch,
        **extra,
    )

168
speechcommands/main.py Normal file
View File

@ -0,0 +1,168 @@
import argparse
import time
import kwt
import mlx.nn as nn
import mlx.data as dx
import mlx.core as mx
import mlx.optimizers as optim
from mlx.data.features import mfsc
from mlx.data.datasets import load_speechcommands
# Command-line options for the training script.
parser = argparse.ArgumentParser(add_help=True)

# Model selection: one of the three KWT factory names.
parser.add_argument(
    "--arch",
    type=str,
    default="kwt1",
    choices=["kwt1", "kwt2", "kwt3"],
    help="model architecture",
)

# Optimization hyper-parameters.
parser.add_argument(
    "--batch_size",
    type=int,
    default=256,
    help="batch size",
)
parser.add_argument(
    "--epochs",
    type=int,
    default=100,
    help="number of epochs",
)
parser.add_argument(
    "--lr",
    type=float,
    default=1e-3,
    help="learning rate",
)

# Reproducibility and device selection.
parser.add_argument(
    "--seed",
    type=int,
    default=0,
    help="random seed",
)
parser.add_argument(
    "--cpu",
    action="store_true",
    help="use cpu only",
)
def prepare_dataset(batch_size, split, root=None):
    """Load a Speech Commands split and return it featurized and batched.

    The "audio" entry of every sample is squeezed, converted with ``mfsc``,
    and standardized per sample; the samples are then shuffled and grouped
    into batches of ``batch_size``.

    Parameters
    ----------
    batch_size: int
        Number of samples per batch
    split: str
        Dataset split passed to ``load_speechcommands``
    root: optional
        Dataset location passed to ``load_speechcommands``
    """

    def normalize(x):
        # Standardize each sample with its own mean and standard deviation.
        centered = x - x.mean()
        return centered / x.std()

    dataset = load_speechcommands(split=split, root=root)

    # NOTE(review): the two positional arguments are presumably the number
    # of filters and the sample rate -- confirm against the mfsc docs.
    featurizer = mfsc(
        40,
        16000,
        frame_size_ms=30,
        frame_stride_ms=10,
        high_freq=7600,
        low_freq=20,
    )

    dataset = dataset.squeeze("audio")
    dataset = dataset.key_transform("audio", featurizer)
    dataset = dataset.key_transform("audio", normalize)
    dataset = dataset.shuffle()
    return dataset.batch(batch_size)
def eval_fn(model, inp, tgt):
    """Return the fraction of inputs whose arg-max prediction equals `tgt`."""
    predictions = mx.argmax(model(inp), axis=1)
    return mx.mean(predictions == tgt)
def train_epoch(model, train_iter, optimizer, epoch):
    """Train `model` for one pass over `train_iter`.

    Progress is printed every 25 batches. Returns the mean training loss,
    mean training accuracy and mean throughput (samples per second), each
    as an ``mx.array``.
    """

    def train_step(model, inp, tgt):
        logits = model(inp)
        batch_loss = mx.mean(nn.losses.cross_entropy(logits, tgt))
        batch_acc = mx.mean(mx.argmax(logits, axis=1) == tgt)
        return batch_loss, batch_acc

    train_step_fn = nn.value_and_grad(model, train_step)

    history = {"loss": [], "acc": [], "throughput": []}
    model.train(True)

    for batch_counter, batch in enumerate(train_iter):
        x = mx.array(batch["audio"])
        y = mx.array(batch["label"])

        # Time the forward/backward pass, the optimizer step and the forced
        # evaluation of the updated parameters and optimizer state.
        tic = time.perf_counter()
        (loss, acc), grads = train_step_fn(model, x, y)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)
        toc = time.perf_counter()

        loss, acc = loss.item(), acc.item()
        throughput = x.shape[0] / (toc - tic)
        history["loss"].append(loss)
        history["acc"].append(acc)
        history["throughput"].append(throughput)

        if batch_counter % 25 != 0:
            continue
        report = [
            f"Epoch {epoch:02d} [{batch_counter:03d}]",
            f"Train loss {loss:.3f}",
            f"Train acc {acc:.3f}",
            f"Throughput: {throughput:.2f} samples/second",
        ]
        print(" | ".join(report))

    return tuple(
        mx.mean(mx.array(history[key])) for key in ("loss", "acc", "throughput")
    )
def test_epoch(model, test_iter):
    """Evaluate `model` over `test_iter` with training mode switched off.

    Returns the mean per-batch accuracy and the mean throughput (samples
    per second), each as an ``mx.array``.
    """
    model.train(False)
    batch_accs = []
    batch_rates = []
    for batch in test_iter:
        x = mx.array(batch["audio"])
        y = mx.array(batch["label"])

        # `.item()` sits inside the timed region so the measurement covers
        # the actual evaluation of the result.
        tic = time.perf_counter()
        batch_accs.append(eval_fn(model, x, y).item())
        toc = time.perf_counter()

        batch_rates.append(x.shape[0] / (toc - tic))

    return mx.mean(mx.array(batch_accs)), mx.mean(mx.array(batch_rates))
def main(args):
    """Train the selected KWT model, then test the best checkpoint.

    The model is trained for ``args.epochs`` epochs with SGD. After every
    epoch it is evaluated on the validation split, and the parameters with
    the highest validation accuracy so far are kept. Those parameters are
    restored before the final evaluation on the test split.

    Parameters
    ----------
    args: argparse.Namespace
        Parsed command-line options; reads ``seed``, ``arch``, ``lr``,
        ``batch_size`` and ``epochs``.
    """
    mx.random.seed(args.seed)

    # Look up the model factory (kwt1 / kwt2 / kwt3) by name.
    model = getattr(kwt, args.arch)()

    print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))

    optimizer = optim.SGD(learning_rate=args.lr, momentum=0.9, weight_decay=1e-4)

    train_data = prepare_dataset(args.batch_size, "train")
    val_data = prepare_dataset(args.batch_size, "validation")

    best_params = None
    best_acc = 0.0
    best_epoch = 0
    for epoch in range(args.epochs):
        tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
        print(
            " | ".join(
                (
                    f"Epoch: {epoch}",
                    f"avg. Train loss {tr_loss.item():.3f}",
                    f"avg. Train acc {tr_acc.item():.3f}",
                    f"Throughput: {throughput.item():.2f} samples/sec",
                )
            )
        )

        val_acc, val_throughput = test_epoch(model, val_data)
        print(
            f"Epoch: {epoch} | Val acc {val_acc.item():.3f} | Throughput: {val_throughput.item():.2f} samples/sec"
        )

        # `>=` means a later epoch wins ties.
        if val_acc >= best_acc:
            best_acc = val_acc
            best_epoch = epoch
            best_params = model.parameters()

    if best_params is None:
        # No epoch ran (e.g. --epochs 0), so there is no checkpoint to restore.
        print("No training epochs were run; testing the untrained model")
    else:
        print(f"Testing best model from epoch {best_epoch}")
        model.update(best_params)

    test_data = prepare_dataset(args.batch_size, "test")
    test_acc, _ = test_epoch(model, test_data)
    print(f"Test acc -> {test_acc.item():.3f}")
# Script entry point: parse the CLI options, pick the device, then train.
if __name__ == "__main__":
    args = parser.parse_args()

    # The default device is left untouched unless --cpu is given.
    if args.cpu:
        mx.set_default_device(mx.cpu)

    main(args)

View File

@ -0,0 +1 @@
mlx>=0.0.5