From b6e62caf2e13e0492a6d84aa02166795e434bf78 Mon Sep 17 00:00:00 2001
From: Sarthak Yadav
Date: Tue, 19 Dec 2023 23:17:48 +0100
Subject: [PATCH] Added Keyword Spotting Transformer + SpeechCommands example (#123)

* Added Keyword Transformer + SpeechCommands

* minor fixes in README

* some updates / simplifications

* nits

* fixed kwt skip connections

* readme + format

* updated acknowledgements

---------

Co-authored-by: Awni Hannun
---
 ACKNOWLEDGMENTS.md              |   1 +
 speechcommands/README.md        |  69 ++++++++++
 speechcommands/kwt.py           | 214 ++++++++++++++++++++++++++++++++
 speechcommands/main.py          | 168 +++++++++++++++++++++++++
 speechcommands/requirements.txt |   1 +
 5 files changed, 453 insertions(+)
 create mode 100644 speechcommands/README.md
 create mode 100644 speechcommands/kwt.py
 create mode 100644 speechcommands/main.py
 create mode 100644 speechcommands/requirements.txt

diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 3c368003..b46a8283 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -8,3 +8,4 @@ with a short description of your contribution(s) below. For example:
 MLX Examples was developed with contributions from the following individuals:
 
 - Juarez Bochi: Added support for T5 models.
+- Sarthak Yadav: Added the `cifar` and `speechcommands` examples.
\ No newline at end of file
diff --git a/speechcommands/README.md b/speechcommands/README.md
new file mode 100644
index 00000000..bcd3a325
--- /dev/null
+++ b/speechcommands/README.md
@@ -0,0 +1,69 @@
+# Train a Keyword Spotting Transformer on Speech Commands
+
+An example of training a Keyword Spotting Transformer[^1] on the Speech
+Commands dataset[^2] with MLX. All supervised-only configurations from the
+paper are available. The example also illustrates how to use [MLX
+Data](https://github.com/ml-explore/mlx-data) to load and process an audio
+dataset.
+
+## Pre-requisites
+
+Follow the [installation
+instructions](https://ml-explore.github.io/mlx-data/build/html/install.html)
+for MLX Data.
+
+Install the remaining Python requirements:
+
+```
+pip install -r requirements.txt
+```
+
+## Running the example
+
+Run the example with:
+
+```
+python main.py
+```
+
+By default the example runs on the GPU. To run it on the CPU, use:
+
+```
+python main.py --cpu
+```
+
+For all available options, run:
+
+```
+python main.py --help
+```
+
+## Results
+
+After training with the `kwt1` architecture for 10 epochs, you
+should see the following results:
+
+```
+Epoch: 9 | avg. Train loss 0.519 | avg. Train acc 0.857 | Throughput: 661.28 samples/sec
+Epoch: 9 | Val acc 0.861 | Throughput: 2976.54 samples/sec
+Testing best model from epoch 9
+Test acc -> 0.841
+```
+
+For the `kwt2` model, you should see:
+```
+Epoch: 9 | avg. Train loss 0.374 | avg. Train acc 0.895 | Throughput: 395.26 samples/sec
+Epoch: 9 | Val acc 0.879 | Throughput: 1542.44 samples/sec
+Testing best model from epoch 9
+Test acc -> 0.861
+```
+
+Note that this was run on an M1 MacBook Pro with 16GB RAM.
+
+At the time of writing, `mlx` doesn't have built-in `cosine` learning rate
+schedules, which are used along with the AdamW optimizer in the official
+implementation. We intend to update this example once these features are added,
+as well as with appropriate data augmentations.
+
+[^1]: Based on the paper [Keyword Transformer: A Self-Attention Model for Keyword Spotting](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html)
+[^2]: We use version 0.02. See the [paper](https://arxiv.org/abs/1804.03209) for more details.
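
Aside (not part of the patch): the README above notes that `mlx` lacked a built-in cosine learning rate schedule at the time. A minimal sketch of one way to approximate it is to recompute the rate once per epoch and assign it to the optimizer; this assumes the optimizer's `learning_rate` attribute can simply be reassigned, and the helper names `cosine_lr` and `set_epoch_lr` are invented for illustration.

```python
import math


def cosine_lr(base_lr, epoch, total_epochs, min_lr=0.0):
    # Cosine-decayed learning rate for the given (0-indexed) epoch.
    progress = epoch / max(1, total_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def set_epoch_lr(optimizer, base_lr, epoch, total_epochs):
    # Call once at the start of each epoch, e.g. inside the training loop in main.py.
    optimizer.learning_rate = cosine_lr(base_lr, epoch, total_epochs)
```
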
diff --git a/speechcommands/kwt.py b/speechcommands/kwt.py
new file mode 100644
index 00000000..63d4e074
--- /dev/null
+++ b/speechcommands/kwt.py
@@ -0,0 +1,214 @@
+from typing import Any
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_flatten
+
+
+__all__ = ["KWT", "kwt1", "kwt2", "kwt3"]
+
+
+class FeedForward(nn.Sequential):
+    def __init__(self, dim, hidden_dim, dropout=0.0):
+        super().__init__(
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(dropout),
+        )
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads, dropout=0.0):
+        super().__init__()
+        self.heads = heads
+        self.scale = dim**-0.5
+        self.qkv = nn.Linear(dim, dim * 3, bias=False)
+        self.out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout))
+
+    def __call__(self, x):
+        b, n, _, h = *x.shape, self.heads
+        qkv = self.qkv(x)
+        qkv = qkv.reshape(b, n, 3, h, -1).transpose(2, 0, 3, 1, 4)
+        q, k, v = qkv
+        attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale
+        attn = mx.softmax(attn, axis=-1)
+        x = (attn @ v).transpose(0, 2, 1, 3).reshape(b, n, -1)
+        x = self.out(x)
+        return x
+
+
+class Block(nn.Module):
+    def __init__(self, dim, heads, mlp_dim, dropout=0.0):
+        super().__init__()
+        self.attn = Attention(dim, heads, dropout=dropout)
+        self.norm1 = nn.LayerNorm(dim)
+        self.ff = FeedForward(dim, mlp_dim, dropout=dropout)
+        self.norm2 = nn.LayerNorm(dim)
+
+    def __call__(self, x):
+        x = self.norm1(self.attn(x)) + x
+        x = self.norm2(self.ff(x)) + x
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.0):
+        super().__init__()
+
+        self.layers = []
+        for _ in range(depth):
+            self.layers.append(Block(dim, heads, mlp_dim, dropout=dropout))
+
+    def __call__(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+class KWT(nn.Module):
+    """
+    Implements the Keyword Transformer (KWT) [1] model.
+
+    KWT is essentially a vision transformer [2] with minor modifications:
+    - Instead of square patches, KWT uses rectangular patches -> a patch
+      across frequency for every timestep
+    - KWT modules apply layer normalization after attention/feedforward layers
+
+    [1] https://arxiv.org/abs/2104.11178
+    [2] https://arxiv.org/abs/2010.11929
+
+    Parameters
+    ----------
+    input_res: tuple of ints
+        Input resolution (time, frequency)
+    patch_res: tuple of ints
+        Patch resolution (time, frequency)
+    num_classes: int
+        Number of classes
+    dim: int
+        Model Embedding dimension
+    depth: int
+        Number of transformer layers
+    heads: int
+        Number of attention heads
+    mlp_dim: int
+        Feedforward hidden dimension
+    pool: str
+        Pooling type, either "cls" or "mean"
+    in_channels: int, optional
+        Number of input channels
+    dropout: float, optional
+        Dropout rate
+    emb_dropout: float, optional
+        Embedding dropout rate
+    """
+
+    def __init__(
+        self,
+        input_res,
+        patch_res,
+        num_classes,
+        dim,
+        depth,
+        heads,
+        mlp_dim,
+        pool="mean",
+        in_channels=1,
+        dropout=0.0,
+        emb_dropout=0.0,
+    ):
+        super().__init__()
+        self.num_patches = int(
+            (input_res[0] / patch_res[0]) * (input_res[1] / patch_res[1])
+        )
+        self.dim = dim
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels, dim, kernel_size=patch_res, stride=patch_res
+        )
+        self.pos_embedding = mx.random.truncated_normal(
+            -0.01,
+            0.01,
+            (self.num_patches + 1, dim),
+        )
+        self.cls_token = mx.random.truncated_normal(-0.01, 0.01, (dim,))
+        self.dropout = nn.Dropout(emb_dropout)
+        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
+        self.pool = pool
+        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))
+
+    def num_params(self):
+        nparams = sum(x.size for k, x in tree_flatten(self.parameters()))
+        return nparams
+
+    def __call__(self, x):
+        if x.ndim != 4:
+            x = mx.expand_dims(x, axis=-1)
+        x = self.patch_embedding(x)
+        x = x.reshape(x.shape[0], -1, self.dim)
+        assert x.shape[1] == self.num_patches
+
+        cls_tokens = mx.broadcast_to(self.cls_token, (x.shape[0], 1, self.dim))
+        x = mx.concatenate((cls_tokens, x), axis=1)
+
+        x = x + self.pos_embedding
+
+        x = self.dropout(x)
+        x = self.transformer(x)
+        x = x.mean(axis=1) if self.pool == "mean" else x[:, 0]
+        x = self.mlp_head(x)
+        return x
+
+
+def parse_kwt_args(**kwargs):
+    input_res = kwargs.pop("input_res", [98, 40])
+    patch_res = kwargs.pop("patch_res", [1, 40])
+    num_classes = kwargs.pop("num_classes", 35)
+    emb_dropout = kwargs.pop("emb_dropout", 0.1)
+    return input_res, patch_res, num_classes, emb_dropout, kwargs
+
+
+def kwt1(**kwargs):
+    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
+    return KWT(
+        input_res,
+        patch_res,
+        num_classes,
+        dim=64,
+        depth=12,
+        heads=1,
+        mlp_dim=256,
+        emb_dropout=emb_dropout,
+        **kwargs
+    )
+
+
+def kwt2(**kwargs):
+    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
+    return KWT(
+        input_res,
+        patch_res,
+        num_classes,
+        dim=128,
+        depth=12,
+        heads=2,
+        mlp_dim=512,
+        emb_dropout=emb_dropout,
+        **kwargs
+    )
+
+
+def kwt3(**kwargs):
+    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
+    return KWT(
+        input_res,
+        patch_res,
+        num_classes,
+        dim=192,
+        depth=12,
+        heads=3,
+        mlp_dim=768,
+        emb_dropout=emb_dropout,
+        **kwargs
+    )
diff --git a/speechcommands/main.py b/speechcommands/main.py
new file mode 100644
index 00000000..492b8159
--- /dev/null
+++ b/speechcommands/main.py
@@ -0,0 +1,168 @@
+import argparse
+import time
+import kwt
+import mlx.nn as nn
+import mlx.data as dx
+import mlx.core as mx
+import mlx.optimizers as optim
+from mlx.data.features import mfsc
+from mlx.data.datasets import load_speechcommands
+
+
+parser = argparse.ArgumentParser(add_help=True)
+parser.add_argument(
+    "--arch",
+    type=str,
+    default="kwt1",
+    choices=[f"kwt{d}" for d in [1, 2, 3]],
+    help="model architecture",
+)
+parser.add_argument("--batch_size", type=int, default=256, help="batch size")
+parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
+parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
+parser.add_argument("--seed", type=int, default=0, help="random seed")
+parser.add_argument("--cpu", action="store_true", help="use cpu only")
+
+
+def prepare_dataset(batch_size, split, root=None):
+    def normalize(x):
+        return (x - x.mean()) / x.std()
+
+    data = load_speechcommands(split=split, root=root)
+
+    data_iter = (
+        data.squeeze("audio")
+        .key_transform(
+            "audio",
+            mfsc(
+                40,
+                16000,
+                frame_size_ms=30,
+                frame_stride_ms=10,
+                high_freq=7600,
+                low_freq=20,
+            ),
+        )
+        .key_transform("audio", normalize)
+        .shuffle()
+        .batch(batch_size)
+    )
+    return data_iter
+
+
+def eval_fn(model, inp, tgt):
+    return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
+
+
+def train_epoch(model, train_iter, optimizer, epoch):
+    def train_step(model, inp, tgt):
+        output = model(inp)
+        loss = mx.mean(nn.losses.cross_entropy(output, tgt))
+        acc = mx.mean(mx.argmax(output, axis=1) == tgt)
+        return loss, acc
+
+    train_step_fn = nn.value_and_grad(model, train_step)
+
+    losses = []
+    accs = []
+    samples_per_sec = []
+
+    model.train(True)
+    for batch_counter, batch in enumerate(train_iter):
+        x = mx.array(batch["audio"])
+        y = mx.array(batch["label"])
+        tic = time.perf_counter()
+        (loss, acc), grads = train_step_fn(model, x, y)
+        optimizer.update(model, grads)
+        mx.eval(model.parameters(), optimizer.state)
+        toc = time.perf_counter()
+        loss = loss.item()
+        acc = acc.item()
+        losses.append(loss)
+        accs.append(acc)
+        throughput = x.shape[0] / (toc - tic)
+        samples_per_sec.append(throughput)
+        if batch_counter % 25 == 0:
+            print(
+                " | ".join(
+                    (
+                        f"Epoch {epoch:02d} [{batch_counter:03d}]",
+                        f"Train loss {loss:.3f}",
+                        f"Train acc {acc:.3f}",
+                        f"Throughput: {throughput:.2f} samples/second",
+                    )
+                )
+            )
+
+    mean_tr_loss = mx.mean(mx.array(losses))
+    mean_tr_acc = mx.mean(mx.array(accs))
+    samples_per_sec = mx.mean(mx.array(samples_per_sec))
+    return mean_tr_loss, mean_tr_acc, samples_per_sec
+
+
+def test_epoch(model, test_iter):
+    model.train(False)
+    accs = []
+    throughput = []
+    for batch_counter, batch in enumerate(test_iter):
+        x = mx.array(batch["audio"])
+        y = mx.array(batch["label"])
+        tic = time.perf_counter()
+        acc = eval_fn(model, x, y)
+        accs.append(acc.item())
+        toc = time.perf_counter()
+        throughput.append(x.shape[0] / (toc - tic))
+    mean_acc = mx.mean(mx.array(accs))
+    mean_throughput = mx.mean(mx.array(throughput))
+    return mean_acc, mean_throughput
+
+
+def main(args):
+    mx.random.seed(args.seed)
+
+    model = getattr(kwt, args.arch)()
+
+    print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
+
+    optimizer = optim.SGD(learning_rate=args.lr, momentum=0.9, weight_decay=1e-4)
+
+    train_data = prepare_dataset(args.batch_size, "train")
+    val_data = prepare_dataset(args.batch_size, "validation")
+
+    best_params = None
+    best_acc = 0.0
+    best_epoch = 0
+    for epoch in range(args.epochs):
+        tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
+        print(
+            " | ".join(
+                (
+                    f"Epoch: {epoch}",
+                    f"avg. Train loss {tr_loss.item():.3f}",
+                    f"avg. Train acc {tr_acc.item():.3f}",
+                    f"Throughput: {throughput.item():.2f} samples/sec",
+                )
+            )
+        )
+
+        val_acc, val_throughput = test_epoch(model, val_data)
+        print(
+            f"Epoch: {epoch} | Val acc {val_acc.item():.3f} | Throughput: {val_throughput.item():.2f} samples/sec"
+        )
+
+        if val_acc >= best_acc:
+            best_acc = val_acc
+            best_epoch = epoch
+            best_params = model.parameters()
+
+    print(f"Testing best model from epoch {best_epoch}")
+    model.update(best_params)
+    test_data = prepare_dataset(args.batch_size, "test")
+    test_acc, _ = test_epoch(model, test_data)
+    print(f"Test acc -> {test_acc.item():.3f}")
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    if args.cpu:
+        mx.set_default_device(mx.cpu)
+    main(args)
diff --git a/speechcommands/requirements.txt b/speechcommands/requirements.txt
new file mode 100644
index 00000000..5ca13284
--- /dev/null
+++ b/speechcommands/requirements.txt
@@ -0,0 +1 @@
+mlx>=0.0.5
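
Aside (not part of the patch): a quick, illustrative sanity check of the model defined in `kwt.py` above. It instantiates the smallest configuration and runs a forward pass on a random batch shaped like the default 98-frame, 40-bin MFSC input; the shapes follow from the defaults in `parse_kwt_args`.

```python
import mlx.core as mx

import kwt

model = kwt.kwt1()                 # KWT-1: dim=64, heads=1, depth=12
x = mx.random.normal((8, 98, 40))  # (batch, time, freq); a channel axis is added inside the model
logits = model(x)
print(logits.shape)                # expect (8, 35): one logit per Speech Commands keyword
```
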