mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Added Keyword Spotting Transformer + SpeechCommands example (#123)
* Added Keyword Transformer + SpeechCommands
* minor fixes in README
* some updates / simplifications
* nits
* fixed kwt skip connections
* readme + format
* updated acknowledgements

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
ebbb7083cc
commit
b6e62caf2e
@ -8,3 +8,4 @@ with a short description of your contribution(s) below. For example:
MLX Examples was developed with contributions from the following individuals:

- Juarez Bochi: Added support for T5 models.
- Sarthak Yadav: Added the `cifar` and `speechcommands` examples.
69
speechcommands/README.md
Normal file
@ -0,0 +1,69 @@
# Train a Keyword Spotting Transformer on Speech Commands

An example of training a Keyword Spotting Transformer[^1] on the Speech
Commands dataset[^2] with MLX. All supervised-only configurations from the
paper are available. The example also illustrates how to use [MLX
Data](https://github.com/ml-explore/mlx-data) to load and process an audio
dataset.
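
As a minimal sketch of the data pipeline, mirroring what `main.py` below does
(the loader and the `mfsc` feature transform are the ones that script uses;
the shapes assume the 1-second, 16 kHz clips in Speech Commands):

```
from mlx.data.datasets import load_speechcommands
from mlx.data.features import mfsc

# Load a split and turn the raw waveforms into 40-bin mel features.
data = load_speechcommands(split="validation")
data = data.squeeze("audio").key_transform(
    "audio", mfsc(40, 16000, frame_size_ms=30, frame_stride_ms=10)
)
for batch in data.batch(8):
    print(batch["audio"].shape)  # e.g. (8, 98, 40): 98 frames x 40 mel bins
    break
```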
## Pre-requisites

Follow the [installation
instructions](https://ml-explore.github.io/mlx-data/build/html/install.html)
for MLX Data.

Install the remaining Python requirements:

```
pip install -r requirements.txt
```

## Running the example

Run the example with:

```
python main.py
```

By default the example runs on the GPU. To run it on the CPU, use:

```
python main.py --cpu
```

For all available options, run:

```
python main.py --help
```
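
For example, to train the larger `kwt2` variant for the 10 epochs used in the
results below (both flags are defined in `main.py`):

```
python main.py --arch kwt2 --epochs 10
```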
## Results

After training with the `kwt1` architecture for 10 epochs, you
should see the following results:

```
Epoch: 9 | avg. Train loss 0.519 | avg. Train acc 0.857 | Throughput: 661.28 samples/sec
Epoch: 9 | Val acc 0.861 | Throughput: 2976.54 samples/sec
Testing best model from epoch 9
Test acc -> 0.841
```

For the `kwt2` model, you should see:

```
Epoch: 9 | avg. Train loss 0.374 | avg. Train acc 0.895 | Throughput: 395.26 samples/sec
Epoch: 9 | Val acc 0.879 | Throughput: 1542.44 samples/sec
Testing best model from epoch 9
Test acc -> 0.861
```

Note that this was run on an M1 MacBook Pro with 16GB RAM.

At the time of writing, `mlx` doesn't have a built-in cosine learning rate
schedule, which the official implementation uses along with the AdamW
optimizer. We intend to update this example once these features are added,
as well as with appropriate data augmentations.
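
In the meantime, a cosine schedule can be approximated by hand. The sketch
below is a minimal example, assuming the optimizer's `learning_rate` attribute
can be reassigned between epochs; the `cosine_lr` helper is hypothetical and
not part of this example:

```
import math

def cosine_lr(base_lr, epoch, total_epochs, min_lr=0.0):
    # Cosine-annealed learning rate for the given epoch
    t = epoch / max(1, total_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))

# Hypothetical use inside the training loop in main.py:
#   for epoch in range(args.epochs):
#       optimizer.learning_rate = cosine_lr(args.lr, epoch, args.epochs)
#       ...
```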

[^1]: Based on the paper [Keyword Transformer: A Self-Attention Model for Keyword Spotting](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html)

[^2]: We use version 0.02. See the [paper](https://arxiv.org/abs/1804.03209) for more details.
214
speechcommands/kwt.py
Normal file
@ -0,0 +1,214 @@

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten


__all__ = ["KWT", "kwt1", "kwt2", "kwt3"]


class FeedForward(nn.Sequential):
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )


class Attention(nn.Module):
    def __init__(self, dim, heads, dropout=0.0):
        super().__init__()
        self.heads = heads
        self.scale = dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout))

    def __call__(self, x):
        b, n, _, h = *x.shape, self.heads
        # Project to q, k, v and split heads: (3, batch, heads, seq, head_dim)
        qkv = self.qkv(x)
        qkv = qkv.reshape(b, n, 3, h, -1).transpose(2, 0, 3, 1, 4)
        q, k, v = qkv
        # Scaled dot-product attention over the sequence axis
        attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale
        attn = mx.softmax(attn, axis=-1)
        x = (attn @ v).transpose(0, 2, 1, 3).reshape(b, n, -1)
        x = self.out(x)
        return x


class Block(nn.Module):
    def __init__(self, dim, heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.attn = Attention(dim, heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)

    def __call__(self, x):
        # KWT applies LayerNorm *after* the attention/feedforward sub-blocks,
        # with a residual connection around each
        x = self.norm1(self.attn(x)) + x
        x = self.norm2(self.ff(x)) + x
        return x


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.0):
        super().__init__()

        self.layers = [
            Block(dim, heads, mlp_dim, dropout=dropout) for _ in range(depth)
        ]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class KWT(nn.Module):
    """
    Implements the Keyword Transformer (KWT) [1] model.

    KWT is essentially a vision transformer [2] with minor modifications:
    - Instead of square patches, KWT uses rectangular patches: one patch
      across frequency for every time step
    - KWT modules apply layer normalization after attention/feedforward layers

    [1] https://arxiv.org/abs/2104.11178
    [2] https://arxiv.org/abs/2010.11929

    Parameters
    ----------
    input_res: tuple of ints
        Input resolution (time, frequency)
    patch_res: tuple of ints
        Patch resolution (time, frequency)
    num_classes: int
        Number of classes
    dim: int
        Model embedding dimension
    depth: int
        Number of transformer layers
    heads: int
        Number of attention heads
    mlp_dim: int
        Feedforward hidden dimension
    pool: str
        Pooling type, either "cls" or "mean"
    in_channels: int, optional
        Number of input channels
    dropout: float, optional
        Dropout rate
    emb_dropout: float, optional
        Embedding dropout rate
    """

    def __init__(
        self,
        input_res,
        patch_res,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        pool="mean",
        in_channels=1,
        dropout=0.0,
        emb_dropout=0.0,
    ):
        super().__init__()
        self.num_patches = int(
            (input_res[0] / patch_res[0]) * (input_res[1] / patch_res[1])
        )
        self.dim = dim

        # Non-overlapping patches via a strided convolution
        self.patch_embedding = nn.Conv2d(
            in_channels, dim, kernel_size=patch_res, stride=patch_res
        )
        self.pos_embedding = mx.random.truncated_normal(
            -0.01,
            0.01,
            (self.num_patches + 1, dim),
        )
        self.cls_token = mx.random.truncated_normal(-0.01, 0.01, (dim,))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
        self.pool = pool
        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def num_params(self):
        nparams = sum(x.size for _, x in tree_flatten(self.parameters()))
        return nparams

    def __call__(self, x):
        # Add a trailing channel dimension for 3D (batch, time, freq) inputs
        if x.ndim != 4:
            x = mx.expand_dims(x, axis=-1)
        x = self.patch_embedding(x)
        x = x.reshape(x.shape[0], -1, self.dim)
        assert x.shape[1] == self.num_patches

        # Prepend the learnable class token, then add positional embeddings
        cls_tokens = mx.broadcast_to(self.cls_token, (x.shape[0], 1, self.dim))
        x = mx.concatenate((cls_tokens, x), axis=1)

        x = x + self.pos_embedding

        x = self.dropout(x)
        x = self.transformer(x)
        # "mean": average over all tokens; otherwise take the class token
        x = x.mean(axis=1) if self.pool == "mean" else x[:, 0]
        x = self.mlp_head(x)
        return x
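

# A quick shape sanity check (hypothetical usage, not part of the original
# file): with the default 98x40 mel-spectrogram input and 1x40 patches, each
# patch spans the full frequency axis at a single time step, giving 98 patches:
#
#     model = kwt1()
#     x = mx.zeros((1, 98, 40))  # (batch, time, frequency)
#     logits = model(x)          # -> shape (1, 35)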


def parse_kwt_args(**kwargs):
    # Defaults match the Speech Commands setup used in main.py: 98x40 mel
    # features, one patch per time step, 35 keyword classes
    input_res = kwargs.pop("input_res", [98, 40])
    patch_res = kwargs.pop("patch_res", [1, 40])
    num_classes = kwargs.pop("num_classes", 35)
    emb_dropout = kwargs.pop("emb_dropout", 0.1)
    return input_res, patch_res, num_classes, emb_dropout, kwargs


def kwt1(**kwargs):
    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
    return KWT(
        input_res,
        patch_res,
        num_classes,
        dim=64,
        depth=12,
        heads=1,
        mlp_dim=256,
        emb_dropout=emb_dropout,
        **kwargs,
    )


def kwt2(**kwargs):
    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
    return KWT(
        input_res,
        patch_res,
        num_classes,
        dim=128,
        depth=12,
        heads=2,
        mlp_dim=512,
        emb_dropout=emb_dropout,
        **kwargs,
    )


def kwt3(**kwargs):
    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
    return KWT(
        input_res,
        patch_res,
        num_classes,
        dim=192,
        depth=12,
        heads=3,
        mlp_dim=768,
        emb_dropout=emb_dropout,
        **kwargs,
    )
168
speechcommands/main.py
Normal file
@ -0,0 +1,168 @@

import argparse
import time

import kwt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.data.datasets import load_speechcommands
from mlx.data.features import mfsc


parser = argparse.ArgumentParser(add_help=True)
parser.add_argument(
    "--arch",
    type=str,
    default="kwt1",
    choices=[f"kwt{d}" for d in [1, 2, 3]],
    help="model architecture",
)
parser.add_argument("--batch_size", type=int, default=256, help="batch size")
parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--cpu", action="store_true", help="use cpu only")


def prepare_dataset(batch_size, split, root=None):
    def normalize(x):
        return (x - x.mean()) / x.std()

    data = load_speechcommands(split=split, root=root)

    data_iter = (
        data.squeeze("audio")
        .key_transform(
            "audio",
            # 40-bin mel features; with 30 ms windows and a 10 ms stride,
            # each 1-second, 16 kHz clip yields a 98x40 feature matrix
            mfsc(
                40,
                16000,
                frame_size_ms=30,
                frame_stride_ms=10,
                high_freq=7600,
                low_freq=20,
            ),
        )
        .key_transform("audio", normalize)
        .shuffle()
        .batch(batch_size)
    )
    return data_iter
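

# A hypothetical peek at one processed element (not part of the original
# script); the shapes assume the 1-second, 16 kHz Speech Commands clips:
#
#     it = prepare_dataset(batch_size=1, split="validation")
#     sample = next(iter(it))
#     sample["audio"].shape  # -> (1, 98, 40)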


def eval_fn(model, inp, tgt):
    # Accuracy: fraction of predictions matching the target labels
    return mx.mean(mx.argmax(model(inp), axis=1) == tgt)


def train_epoch(model, train_iter, optimizer, epoch):
    def train_step(model, inp, tgt):
        output = model(inp)
        loss = mx.mean(nn.losses.cross_entropy(output, tgt))
        acc = mx.mean(mx.argmax(output, axis=1) == tgt)
        return loss, acc

    # Returns (loss, acc) together with gradients w.r.t. the model parameters
    train_step_fn = nn.value_and_grad(model, train_step)

    losses = []
    accs = []
    samples_per_sec = []

    model.train(True)
    for batch_counter, batch in enumerate(train_iter):
        x = mx.array(batch["audio"])
        y = mx.array(batch["label"])
        tic = time.perf_counter()
        (loss, acc), grads = train_step_fn(model, x, y)
        optimizer.update(model, grads)
        # Force evaluation of the lazy graph before timing the step
        mx.eval(model.parameters(), optimizer.state)
        toc = time.perf_counter()
        loss = loss.item()
        acc = acc.item()
        losses.append(loss)
        accs.append(acc)
        throughput = x.shape[0] / (toc - tic)
        samples_per_sec.append(throughput)
        if batch_counter % 25 == 0:
            print(
                " | ".join(
                    (
                        f"Epoch {epoch:02d} [{batch_counter:03d}]",
                        f"Train loss {loss:.3f}",
                        f"Train acc {acc:.3f}",
                        f"Throughput: {throughput:.2f} samples/second",
                    )
                )
            )

    mean_tr_loss = mx.mean(mx.array(losses))
    mean_tr_acc = mx.mean(mx.array(accs))
    samples_per_sec = mx.mean(mx.array(samples_per_sec))
    return mean_tr_loss, mean_tr_acc, samples_per_sec


def test_epoch(model, test_iter):
    model.train(False)
    accs = []
    throughput = []
    for batch in test_iter:
        x = mx.array(batch["audio"])
        y = mx.array(batch["label"])
        tic = time.perf_counter()
        acc = eval_fn(model, x, y)
        accs.append(acc.item())
        toc = time.perf_counter()
        throughput.append(x.shape[0] / (toc - tic))
    mean_acc = mx.mean(mx.array(accs))
    mean_throughput = mx.mean(mx.array(throughput))
    return mean_acc, mean_throughput


def main(args):
    mx.random.seed(args.seed)

    # Build the selected architecture (kwt1/kwt2/kwt3) by name
    model = getattr(kwt, args.arch)()

    print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))

    optimizer = optim.SGD(learning_rate=args.lr, momentum=0.9, weight_decay=1e-4)

    train_data = prepare_dataset(args.batch_size, "train")
    val_data = prepare_dataset(args.batch_size, "validation")

    best_params = None
    best_acc = 0.0
    best_epoch = 0
    for epoch in range(args.epochs):
        tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
        print(
            " | ".join(
                (
                    f"Epoch: {epoch}",
                    f"avg. Train loss {tr_loss.item():.3f}",
                    f"avg. Train acc {tr_acc.item():.3f}",
                    f"Throughput: {throughput.item():.2f} samples/sec",
                )
            )
        )

        val_acc, val_throughput = test_epoch(model, val_data)
        print(
            f"Epoch: {epoch} | Val acc {val_acc.item():.3f} | Throughput: {val_throughput.item():.2f} samples/sec"
        )

        # Keep the parameters of the best validation epoch for final testing
        if val_acc >= best_acc:
            best_acc = val_acc
            best_epoch = epoch
            best_params = model.parameters()
    print(f"Testing best model from epoch {best_epoch}")
    model.update(best_params)
    test_data = prepare_dataset(args.batch_size, "test")
    test_acc, _ = test_epoch(model, test_data)
    print(f"Test acc -> {test_acc.item():.3f}")


if __name__ == "__main__":
    args = parser.parse_args()
    if args.cpu:
        mx.set_default_device(mx.cpu)
    main(args)
1
speechcommands/requirements.txt
Normal file
@ -0,0 +1 @@
mlx>=0.0.5