Added Keyword Transformer + SpeechCommands

This commit is contained in:
  parent 08e862336a
  commit 3e24277ba3

speechcommands/README.md (new file, 60 lines)

@@ -0,0 +1,60 @@

# Training a Keyword Transformer on SpeechCommands

An example of training the [Keyword Spotting Transformer](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html), a variant of the Vision Transformer, on the [Speech Commands](https://arxiv.org/abs/1804.03209) (v0.02) dataset with MLX. All supervised-only configurations from the paper are available. The example also illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to load and process an audio dataset.

## Pre-requisites

Install the `mlx` package:

```
pip install mlx==0.0.5
```

At the time of writing, the SpeechCommands dataset is not yet part of an `mlx-data` release. Install `mlx-data` from source using this [commit](https://github.com/ml-explore/mlx-data/commit/ae3431648b8e1594d63175a8f121d9873aeb9daa).

## Running the example

Run the example with:

```
python main.py
```

By default the example runs on the GPU. To run on the CPU, use:

```
python main.py --cpu
```

For all available options, run:

```
python main.py --help
```

## Results

After training with the `kwt1` architecture for 100 epochs, you should see the following results:

```
Epoch: 99 | avg. Train loss 0.581 | avg. Train acc 0.826 | Throughput: 677.37 samples/sec
Epoch: 99 | Val acc 0.710
Testing best model from Epoch 98
Test acc -> 0.687
```

For the `kwt2` model, you should see:

```
Epoch: 99 | avg. Train loss 0.137 | avg. Train acc 0.956 | Throughput: 401.47 samples/sec
Epoch: 99 | Val acc 0.739
Testing best model from Epoch 97
Test acc -> 0.718
```

Note that these results were obtained on an M1 MacBook Pro with 16GB of RAM.

At the time of writing, `mlx` does not yet have built-in cosine learning rate schedules, which are used along with the AdamW optimizer in the official implementation. We intend to update this example once these features are added, as well as to add appropriate data augmentations.

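In the meantime, a cosine schedule can be applied by hand. Below is a minimal sketch, assuming the optimizer's `learning_rate` attribute can be reassigned between epochs; the `cosine_lr` helper is illustrative, not an `mlx` API:

```
import math

def cosine_lr(base_lr, epoch, total_epochs, min_lr=0.0):
    # cosine decay: base_lr at epoch 0 down to min_lr at the final epoch
    t = epoch / max(1, total_epochs - 1)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))

# e.g., at the top of the epoch loop in main.py:
# optimizer.learning_rate = cosine_lr(args.lr, epoch, args.epochs)
```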

speechcommands/kwt.py (new file, 231 lines)

@@ -0,0 +1,231 @@

from typing import Any

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten


__all__ = ["KWT", "kwt1", "kwt2", "kwt3"]

STD = 0.02


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout) if dropout != 0.0 else Identity(),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout) if dropout != 0.0 else Identity(),
        )

    def __call__(self, x):
        return self.net(x)


class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def __call__(self, x):
        return x


class Attention(nn.Module):
    def __init__(self, dim, heads, dropout=0.0):
        super().__init__()
        self.heads = heads
        self.scale = dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.out = nn.Sequential(
            nn.Linear(dim, dim), nn.Dropout(dropout) if dropout != 0.0 else Identity()
        )

    def __call__(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.qkv(x)
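        # split the fused projection into per-head q, k, v, each of shape
        # (batch, heads, seq, head_dim), then apply scaled dot-product attention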
        qkv = qkv.reshape(b, n, 3, h, -1).transpose((2, 0, 3, 1, 4))
        q, k, v = qkv
        attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale
        attn = mx.softmax(attn, axis=-1)
        x = (attn @ v).transpose((0, 2, 1, 3)).reshape(b, n, -1)
        x = self.out(x)
        return x


class Block(nn.Module):
    def __init__(self, dim, heads, mlp_dim, dropout=0.0):
        super().__init__()
        # self.attn = nn.MultiHeadAttention(dim, heads)
        self.attn = Attention(dim, heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)

    def __call__(self, x):
        # PostNorm: residual connection followed by LayerNorm
        x = self.norm1(self.attn(x) + x)
        x = self.norm2(self.ff(x) + x)
        return x


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.0):
        super().__init__()

        self.layers = []
        for _ in range(depth):
            self.layers.append(Block(dim, heads, mlp_dim, dropout=dropout))

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class KWT(nn.Module):
    """
    Implements the Keyword Transformer (KWT) [1] model.

    KWT is essentially a vision transformer [2] with minor modifications:
    - Instead of square patches, KWT uses rectangular patches: one patch
      spanning all frequencies at each timestep
    - KWT blocks apply LayerNormalization after the attention/feedforward
      layers (with residual connections), also referred to as PostNorm

    [1] https://arxiv.org/abs/2104.11178
    [2] https://arxiv.org/abs/2010.11929

    Parameters
    ----------
    input_res: tuple of ints
        Input resolution (time, frequency)
    patch_res: tuple of ints
        Patch resolution (time, frequency)
    num_classes: int
        Number of classes
    dim: int
        Model embedding dimension
    depth: int
        Number of transformer layers
    heads: int
        Number of attention heads
    mlp_dim: int
        Feedforward hidden dimension
    pool: str
        Pooling type, either "cls" or "mean"
    in_channels: int, optional
        Number of input channels
    dropout: float, optional
        Dropout rate
    emb_dropout: float, optional
        Embedding dropout rate
    """

    def __init__(
        self,
        input_res,
        patch_res,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        pool="mean",
        in_channels=1,
        dropout=0.0,
        emb_dropout=0.0,
    ):
        super().__init__()
        self.num_patches = int(
            (input_res[0] / patch_res[0]) * (input_res[1] / patch_res[1])
        )
        self.dim = dim

        self.patch_embedding = nn.Conv2d(
            in_channels, dim, kernel_size=patch_res, stride=patch_res
        )
        self.pos_embedding = mx.random.truncated_normal(
            -1 * STD / 2, STD / 2, (1, self.num_patches + 1, dim)
        )
        self.cls_token = mx.random.truncated_normal(-1 * STD / 2, STD / 2, (1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout) if emb_dropout != 0.0 else Identity()
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
        self.pool = pool
        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def num_params(self):
        nparams = sum(x.size for k, x in tree_flatten(self.parameters()))
        return nparams

    def __call__(self, x):
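        # treat the (batch, time, freq) features as a single-channel,
        # channels-last image: (B, T, F) -> (B, T, F, 1)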
        if x.ndim != 4:
            x = mx.expand_dims(x, axis=-1)
        x = self.patch_embedding(x)
        x = x.reshape(x.shape[0], -1, self.dim)
        assert x.shape[1] == self.num_patches

        # x = x + self.pos_embedding[:, 1:, :]
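
        # prepend a learnable class token, then add learned positional
        # embeddings (index 0 corresponds to the class token)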
        cls_tokens = mx.broadcast_to(self.cls_token, (x.shape[0], 1, self.dim))
        x = mx.concatenate((cls_tokens, x), axis=1)

        x = x + self.pos_embedding

        x = self.dropout(x)
        x = self.transformer(x)
        x = x.mean(axis=1) if self.pool == "mean" else x[:, 0]
        x = self.mlp_head(x)
        return x

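
# Defaults: 98 x 40 inputs (one 1 x 40 patch per timestep across all mel
# bins) and the 35 Speech Commands classes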
def parse_kwt_args(**kwargs):
    input_res = kwargs.pop("input_res", [98, 40])
    patch_res = kwargs.pop("patch_res", [1, 40])
    num_classes = kwargs.pop("num_classes", 35)
    emb_dropout = kwargs.pop("emb_dropout", 0.1)
    return input_res, patch_res, num_classes, emb_dropout, kwargs


def kwt1(**kwargs):
    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
    return KWT(
        input_res,
        patch_res,
        num_classes,
        dim=64,
        depth=12,
        heads=1,
        mlp_dim=256,
        emb_dropout=emb_dropout,
        **kwargs
    )


def kwt2(**kwargs):
    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
    return KWT(
        input_res,
        patch_res,
        num_classes,
        dim=128,
        depth=12,
        heads=2,
        mlp_dim=512,
        emb_dropout=emb_dropout,
        **kwargs
    )


def kwt3(**kwargs):
    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
    return KWT(
        input_res,
        patch_res,
        num_classes,
        dim=192,
        depth=12,
        heads=3,
        mlp_dim=768,
        emb_dropout=emb_dropout,
        **kwargs
    )

speechcommands/main.py (new file, 160 lines)

@@ -0,0 +1,160 @@

import argparse
import time

import mlx.core as mx
import mlx.data as dx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.data.datasets import load_speechcommands
from mlx.data.features import mfsc

import kwt


parser = argparse.ArgumentParser(add_help=True)
parser.add_argument(
    "--arch",
    type=str,
    default="kwt1",
    choices=[f"kwt{d}" for d in [1, 2, 3]],
    help="model architecture",
)
parser.add_argument("--batch_size", type=int, default=256, help="batch size")
parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--cpu", action="store_true", help="use cpu only")


def prepare_dataset(batch_size, split, root=None):
    def normalize(x):
        return (x - x.mean()) / x.std()

    data = load_speechcommands(split=split, root=root)
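
    # 40-band mel filterbank features over 30 ms windows with a 10 ms stride:
    # 1 s of 16 kHz audio becomes a (98, 40) array, matching the model's
    # default input_res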
    data_iter = (
        data.squeeze("audio")
        .key_transform(
            "audio",
            mfsc(
                40,
                16000,
                frame_size_ms=30,
                frame_stride_ms=10,
                high_freq=7600,
                low_freq=20,
            ),
        )
        .key_transform("audio", normalize)
        .shuffle()
        .batch(batch_size)
    )
    return data_iter


def eval_fn(model, inp, tgt):
    return mx.mean(mx.argmax(model(inp), axis=1) == tgt)


def train_epoch(model, train_iter, optimizer, epoch):
    def train_step(model, inp, tgt):
        output = model(inp)
        loss = mx.mean(nn.losses.cross_entropy(output, tgt))
        acc = mx.mean(mx.argmax(output, axis=1) == tgt)
        return loss, acc
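
    # differentiate train_step with respect to the model's trainable
    # parameters; returns the (loss, acc) outputs together with the gradients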
    train_step_fn = nn.value_and_grad(model, train_step)

    losses = []
    accs = []
    samples_per_sec = []

    for batch_counter, batch in enumerate(train_iter):
        x = mx.array(batch["audio"])
        y = mx.array(batch["label"])
        tic = time.perf_counter()
        (loss, acc), grads = train_step_fn(model, x, y)
        optimizer.update(model, grads)
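        # computation in mlx is lazy: force evaluation so the parameter and
        # optimizer-state updates are materialized (and timed) here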
        mx.eval(model.parameters(), optimizer.state)
        toc = time.perf_counter()
        loss = loss.item()
        acc = acc.item()
        losses.append(loss)
        accs.append(acc)
        throughput = x.shape[0] / (toc - tic)
        samples_per_sec.append(throughput)
        if batch_counter % 25 == 0:
            print(
                " | ".join(
                    (
                        f"Epoch {epoch:02d} [{batch_counter:03d}]",
                        f"Train loss {loss:.3f}",
                        f"Train acc {acc:.3f}",
                        f"Throughput: {throughput:.2f} samples/second",
                    )
                )
            )

    mean_tr_loss = mx.mean(mx.array(losses))
    mean_tr_acc = mx.mean(mx.array(accs))
    samples_per_sec = mx.mean(mx.array(samples_per_sec))
    return mean_tr_loss, mean_tr_acc, samples_per_sec


def test_epoch(model, test_iter):
    accs = []
    for batch_counter, batch in enumerate(test_iter):
        x = mx.array(batch["audio"])
        y = mx.array(batch["label"])
        acc = eval_fn(model, x, y)
        acc_value = acc.item()
        accs.append(acc_value)
    mean_acc = mx.mean(mx.array(accs))
    return mean_acc


def main(args):
    mx.random.seed(args.seed)

    model = getattr(kwt, args.arch)()

    print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
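
    # note: the official implementation uses AdamW with a cosine learning
    # rate schedule (see README); plain SGD with momentum is used here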
    optimizer = optim.SGD(learning_rate=args.lr, momentum=0.9, weight_decay=1e-4)

    train_data = prepare_dataset(args.batch_size, "train")
    val_data = prepare_dataset(args.batch_size, "validation")

    best_params = None
    best_acc = 0.0
    best_epoch = 0
    for epoch in range(args.epochs):
        tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
        print(
            " | ".join(
                (
                    f"Epoch: {epoch}",
                    f"avg. Train loss {tr_loss.item():.3f}",
                    f"avg. Train acc {tr_acc.item():.3f}",
                    f"Throughput: {throughput.item():.2f} samples/sec",
                )
            )
        )

        val_acc = test_epoch(model, val_data)
        print(f"Epoch: {epoch} | Val acc {val_acc.item():.3f}")

        if val_acc >= best_acc:
            best_acc = val_acc
            best_epoch = epoch
            best_params = model.parameters()
    print(f"Testing best model from Epoch {best_epoch}")
    model.update(best_params)
    test_data = prepare_dataset(args.batch_size, "test")
    test_acc = test_epoch(model, test_data)
    print(f"Test acc -> {test_acc.item():.3f}")


if __name__ == "__main__":
    args = parser.parse_args()
    if args.cpu:
        mx.set_default_device(mx.cpu)
    main(args)

speechcommands/requirements.txt (new file, 2 lines)

@@ -0,0 +1,2 @@

mlx==0.0.5
mlx-data