Added Keyword Transformer + SpeechCommands

Sarthak Yadav 2023-12-16 23:30:33 +01:00
parent 08e862336a
commit 3e24277ba3
4 changed files with 453 additions and 0 deletions

speechcommands/README.md Normal file

@@ -0,0 +1,60 @@
# Training a Vision Transformer on SpeechCommands
An example of training the [Keyword Spotting Transformer](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html), a variant of the Vision Transformer, on the [Speech Commands](https://arxiv.org/abs/1804.03209) (v0.02) dataset with MLX. All supervised-only configurations from the paper are available. The example also
illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to
load and process an audio dataset.
## Pre-requisites
Install `mlx`:
```
pip install mlx==0.0.5
```
At the time of writing, the SpeechCommands dataset is not yet part of an `mlx-data` release. Install `mlx-data` from source using this [commit](https://github.com/ml-explore/mlx-data/commit/ae3431648b8e1594d63175a8f121d9873aeb9daa).
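Once installed, a quick way to check that the dataset loads is the snippet below. It is illustrative only (not part of the example) and assumes the buffer returned by `load_speechcommands` supports `len()` and indexing; the first call downloads and caches the dataset:
```
from mlx.data.datasets import load_speechcommands

# First call downloads and caches Speech Commands v0.02.
data = load_speechcommands(split="validation")
print(len(data), data[0]["audio"].shape)
```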
## Running the example
Run the example with:
```
python main.py
```
By default the example runs on the GPU. To run on the CPU, use:
```
python main.py --cpu
```
For all available options, run:
```
python main.py --help
```
## Results
After training with the `kwt1` architecture for 100 epochs, you
should see the following results:
```
Epoch: 99 | avg. Train loss 0.581 | avg. Train acc 0.826 | Throughput: 677.37 samples/sec
Epoch: 99 | Val acc 0.710
Testing best model from Epoch 98
Test acc -> 0.687
```
For the `kwt2` model, you should see:
```
Epoch: 99 | avg. Train loss 0.137 | avg. Train acc 0.956 | Throughput: 401.47 samples/sec
Epoch: 99 | Val acc 0.739
Testing best model from Epoch 97
Test acc -> 0.718
```
Note that this was run on an M1 MacBook Pro with 16GB RAM.
At the time of writing, `mlx` doesn't have built-in `cosine` learning rate schedules, which are used along with the AdamW optimizer in the official implementation. We intend to update this example once these features
are added, as well as with appropriate data augmentations.
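In the meantime, a schedule can be applied by hand. The following sketch is illustrative (not part of the example); it assumes the optimizer's `learning_rate` is a plain attribute that can be reassigned between epochs, as is currently the case for MLX optimizers:
```
import math

def cosine_lr(base_lr, epoch, total_epochs, min_lr=0.0):
    # Cosine decay from base_lr down to min_lr over total_epochs.
    progress = epoch / max(1, total_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# e.g., inside the epoch loop in main.py:
# optimizer.learning_rate = cosine_lr(args.lr, epoch, args.epochs)
```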

speechcommands/kwt.py Normal file

@@ -0,0 +1,231 @@
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten
__all__ = ["KWT", "kwt1", "kwt2", "kwt3"]
STD = 0.02
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.0):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout) if dropout != 0.0 else Identity(),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout) if dropout != 0.0 else Identity(),
)
def __call__(self, x):
return self.net(x)
class Identity(nn.Module):
def __init__(self):
super().__init__()
def __call__(self, x):
return x
class Attention(nn.Module):
def __init__(self, dim, heads, dropout=0.0):
super().__init__()
self.heads = heads
self.scale = dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.out = nn.Sequential(
nn.Linear(dim, dim), nn.Dropout(dropout) if dropout != 0.0 else Identity()
)
def __call__(self, x):
b, n, _, h = *x.shape, self.heads
qkv = self.qkv(x)
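        # (b, n, 3 * dim) -> (3, b, heads, n, head_dim)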
qkv = qkv.reshape(b, n, 3, h, -1).transpose((2, 0, 3, 1, 4))
q, k, v = qkv
attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale
attn = mx.softmax(attn, axis=-1)
x = (attn @ v).transpose((0, 2, 1, 3)).reshape(b, n, -1)
x = self.out(x)
return x
class Block(nn.Module):
def __init__(self, dim, heads, mlp_dim, dropout=0.0):
super().__init__()
# self.attn = nn.MultiHeadAttention(dim, heads)
self.attn = Attention(dim, heads, dropout=dropout)
self.norm1 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, mlp_dim, dropout=dropout)
self.norm2 = nn.LayerNorm(dim)
    def __call__(self, x):
        # PostNorm residual connections: LayerNorm is applied to the sum of
        # each sublayer's input and output, per the KWT paper.
        x = self.norm1(x + self.attn(x))
        x = self.norm2(x + self.ff(x))
        return x
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, mlp_dim, dropout=0.0):
super().__init__()
self.layers = []
for _ in range(depth):
self.layers.append(Block(dim, heads, mlp_dim, dropout=dropout))
def __call__(self, x):
for layer in self.layers:
x = layer(x)
return x
class KWT(nn.Module):
"""
Implements the Keyword Transformer (KWT) [1] model.
KWT is essentially a vision transformer [2] with minor modifications:
- Instead of square patches, KWT uses rectangular patches -> a patch across frequency for every timestep
- KWT modules apply LayerNormalization after attention/feedforward layers, also referred to as PostNorm
[1] https://arxiv.org/abs/2104.11178
[2] https://arxiv.org/abs/2010.11929
Parameters
----------
input_res: tuple of ints
Input resolution (time, frequency)
patch_res: tuple of ints
Patch resolution (time, frequency)
num_classes: int
Number of classes
dim: int
Model Embedding dimension
depth: int
Number of transformer layers
heads: int
Number of attention heads
mlp_dim: int
Feedforward hidden dimension
pool: str
Pooling type, either "cls" or "mean"
in_channels: int, optional
Number of input channels
dropout: float, optional
Dropout rate
emb_dropout: float, optional
Embedding dropout rate
"""
def __init__(
self,
input_res,
patch_res,
num_classes,
dim,
depth,
heads,
mlp_dim,
pool="mean",
in_channels=1,
dropout=0.0,
emb_dropout=0.0,
):
super().__init__()
self.num_patches = int(
(input_res[0] / patch_res[0]) * (input_res[1] / patch_res[1])
)
self.dim = dim
self.patch_embedding = nn.Conv2d(
in_channels, dim, kernel_size=patch_res, stride=patch_res
)
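        # Learnable position embeddings (one per patch plus the CLS slot) and
        # the class token, drawn from a truncated normal on [-STD / 2, STD / 2].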
self.pos_embedding = mx.random.truncated_normal(
-1 * STD / 2, STD / 2, (1, self.num_patches + 1, dim)
)
self.cls_token = mx.random.truncated_normal(-1 * STD / 2, STD / 2, (1, 1, dim))
self.dropout = nn.Dropout(emb_dropout) if emb_dropout != 0.0 else Identity()
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
self.pool = pool
self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))
def num_params(self):
nparams = sum(x.size for k, x in tree_flatten(self.parameters()))
return nparams
def __call__(self, x):
if x.ndim != 4:
x = mx.expand_dims(x, axis=-1)
x = self.patch_embedding(x)
x = x.reshape(x.shape[0], -1, self.dim)
assert x.shape[1] == self.num_patches
# x = x + self.pos_embedding[:, 1:, :]
cls_tokens = mx.broadcast_to(self.cls_token, (x.shape[0], 1, self.dim))
x = mx.concatenate((cls_tokens, x), axis=1)
x = x + self.pos_embedding
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(axis=1) if self.pool == "mean" else x[:, 0]
x = self.mlp_head(x)
return x
def parse_kwt_args(**kwargs):
input_res = kwargs.pop("input_res", [98, 40])
patch_res = kwargs.pop("patch_res", [1, 40])
num_classes = kwargs.pop("num_classes", 35)
emb_dropout = kwargs.pop("emb_dropout", 0.1)
return input_res, patch_res, num_classes, emb_dropout, kwargs
def kwt1(**kwargs):
input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
return KWT(
input_res,
patch_res,
num_classes,
dim=64,
depth=12,
heads=1,
mlp_dim=256,
emb_dropout=emb_dropout,
**kwargs
)
def kwt2(**kwargs):
input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
return KWT(
input_res,
patch_res,
num_classes,
dim=128,
depth=12,
heads=2,
mlp_dim=512,
emb_dropout=emb_dropout,
**kwargs
)
def kwt3(**kwargs):
input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
return KWT(
input_res,
patch_res,
num_classes,
dim=192,
depth=12,
heads=3,
mlp_dim=768,
emb_dropout=emb_dropout,
**kwargs
)
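# Quick sanity check: build the smallest configuration and run a dummy batch
# at the default input resolution (98 time frames x 40 mel bins). The expected
# output shape is (batch, num_classes) = (2, 35).
if __name__ == "__main__":
    model = kwt1()
    print(f"params: {model.num_params() / 1e6:.3f} M")
    x = mx.random.normal((2, 98, 40, 1))  # (batch, time, freq, channels)
    print(model(x).shape)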

speechcommands/main.py Normal file

@@ -0,0 +1,160 @@
import argparse
import time
import kwt
import mlx.nn as nn
import mlx.data as dx
import mlx.core as mx
import mlx.optimizers as optim
from mlx.data.features import mfsc
from mlx.data.datasets import load_speechcommands
parser = argparse.ArgumentParser(add_help=True)
parser.add_argument(
"--arch",
type=str,
default="kwt1",
choices=[f"kwt{d}" for d in [1, 2, 3]],
help="model architecture",
)
parser.add_argument("--batch_size", type=int, default=256, help="batch size")
parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--cpu", action="store_true", help="use cpu only")
def prepare_dataset(batch_size, split, root=None):
def normalize(x):
return (x - x.mean()) / x.std()
data = load_speechcommands(split=split, root=root)
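    # Pipeline: squeeze the raw audio array, compute 40-band mel filterbank
    # (MFSC) features from 16 kHz audio, normalize each example, then shuffle
    # and batch.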
data_iter = (
data.squeeze("audio")
.key_transform(
"audio",
mfsc(
40,
16000,
frame_size_ms=30,
frame_stride_ms=10,
high_freq=7600,
low_freq=20,
),
)
.key_transform("audio", normalize)
.shuffle()
.batch(batch_size)
)
return data_iter
def eval_fn(model, inp, tgt):
return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
def train_epoch(model, train_iter, optimizer, epoch):
def train_step(model, inp, tgt):
output = model(inp)
loss = mx.mean(nn.losses.cross_entropy(output, tgt))
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
return loss, acc
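    # nn.value_and_grad wraps train_step so it returns both its outputs
    # (loss, acc) and the gradients of the loss w.r.t. the model parameters.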
train_step_fn = nn.value_and_grad(model, train_step)
losses = []
accs = []
samples_per_sec = []
for batch_counter, batch in enumerate(train_iter):
x = mx.array(batch["audio"])
y = mx.array(batch["label"])
tic = time.perf_counter()
(loss, acc), grads = train_step_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
toc = time.perf_counter()
loss = loss.item()
acc = acc.item()
losses.append(loss)
accs.append(acc)
throughput = x.shape[0] / (toc - tic)
samples_per_sec.append(throughput)
if batch_counter % 25 == 0:
print(
" | ".join(
(
f"Epoch {epoch:02d} [{batch_counter:03d}]",
f"Train loss {loss:.3f}",
f"Train acc {acc:.3f}",
f"Throughput: {throughput:.2f} samples/second",
)
)
)
mean_tr_loss = mx.mean(mx.array(losses))
mean_tr_acc = mx.mean(mx.array(accs))
samples_per_sec = mx.mean(mx.array(samples_per_sec))
return mean_tr_loss, mean_tr_acc, samples_per_sec
def test_epoch(model, test_iter):
accs = []
    for batch in test_iter:
x = mx.array(batch["audio"])
y = mx.array(batch["label"])
acc = eval_fn(model, x, y)
acc_value = acc.item()
accs.append(acc_value)
mean_acc = mx.mean(mx.array(accs))
return mean_acc
def main(args):
mx.random.seed(args.seed)
model = getattr(kwt, args.arch)()
print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
optimizer = optim.SGD(learning_rate=args.lr, momentum=0.9, weight_decay=1e-4)
train_data = prepare_dataset(args.batch_size, "train")
val_data = prepare_dataset(args.batch_size, "validation")
best_params = None
best_acc = 0.0
best_epoch = 0
for epoch in range(args.epochs):
tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
print(
" | ".join(
(
f"Epoch: {epoch}",
f"avg. Train loss {tr_loss.item():.3f}",
f"avg. Train acc {tr_acc.item():.3f}",
f"Throughput: {throughput.item():.2f} samples/sec",
)
)
)
val_acc = test_epoch(model, val_data)
print(f"Epoch: {epoch} | Val acc {val_acc.item():.3f}")
if val_acc >= best_acc:
best_acc = val_acc
best_epoch = epoch
best_params = model.parameters()
print(f"Testing best model from Epoch {best_epoch}")
model.update(best_params)
test_data = prepare_dataset(args.batch_size, "test")
test_acc = test_epoch(model, test_data)
print(f"Test acc -> {test_acc.item():.3f}")
if __name__ == "__main__":
args = parser.parse_args()
if args.cpu:
mx.set_default_device(mx.cpu)
main(args)

speechcommands/requirements.txt Normal file

@@ -0,0 +1,2 @@
mlx==0.0.5
mlx-data