From 992f5cc0fa6ff09b55e601f94c3473c8d6e6b96c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 18 Dec 2023 21:54:19 -0800 Subject: [PATCH] some updates / simplifications --- speechcommands/README.md | 29 ++++++++++++++++++---------- speechcommands/kwt.py | 34 +++++++++++---------------------- speechcommands/main.py | 8 +++++--- speechcommands/requirements.txt | 3 +-- 4 files changed, 36 insertions(+), 38 deletions(-) diff --git a/speechcommands/README.md b/speechcommands/README.md index f63d9da6..9e482b94 100644 --- a/speechcommands/README.md +++ b/speechcommands/README.md @@ -1,19 +1,23 @@ # Training a Vision Transformer on SpeechCommands -An example of training [Keyword Spotting Transformer](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html), a variant of the Vision Transformer, on the [Speech Commands](https://arxiv.org/abs/1804.03209) (v0.02) dataset with MLX. All supervised only configurations from the paper are available.The example also -illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to -load and process an audio dataset. +An example of training a Keyword Spotting Transformer[^1] on the Speech +Commands dataset[^2] with MLX. All supervised only configurations from the +paper are available.The example also illustrates how to use [MLX +Data](https://github.com/ml-explore/mlx-data) to load and process an audio +dataset. ## Pre-requisites -Install `mlx` +Follow the [installation +instructions](https://ml-explore.github.io/mlx-data/build/html/install.html) +for MLX Data. + +Install the remaining python requirements: ``` -pip install mlx==0.0.5 +pip install -r requirements.txt ``` -At the time of writing, the SpeechCommands dataset is not yet a part of a `mlx-data` release. Install `mlx-data` from source from this [commit](https://github.com/ml-explore/mlx-data/commit/ae3431648b8e1594d63175a8f121d9873aeb9daa). - ## Running the example Run the example with: @@ -22,7 +26,7 @@ Run the example with: python main.py ``` -By default the example runs on the GPU. To run on the CPU, use: +By default the example runs on the GPU. To run it on the CPU, use: ``` python main.py --cpu @@ -56,5 +60,10 @@ Test acc -> 0.718 Note that this was run on an M1 Macbook Pro with 16GB RAM. -At the time of writing, `mlx` doesn't have built-in `cosine` learning rate schedules, which is used along with the AdamW optimizer in the official implementaiton. We intend to update this example once these features -are added, as well as with appropriate data augmentations. \ No newline at end of file +At the time of writing, `mlx` doesn't have built-in `cosine` learning rate +schedules, which is used along with the AdamW optimizer in the official +implementaiton. We intend to update this example once these features are added, +as well as with appropriate data augmentations. + +[^1]: Based one the paper [Keyword Transformer: A Self-Attention Model for Keyword Spotting](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html) +[^2]: We use version 0.02. See the [paper]((https://arxiv.org/abs/1804.03209) for more details. diff --git a/speechcommands/kwt.py b/speechcommands/kwt.py index b1ea7608..40b4c71d 100644 --- a/speechcommands/kwt.py +++ b/speechcommands/kwt.py @@ -8,28 +8,16 @@ __all__ = ["KWT", "kwt1", "kwt2", "kwt3"] STD = 0.02 -class FeedForward(nn.Module): +class FeedForward(nn.Sequential): def __init__(self, dim, hidden_dim, dropout=0.0): - super().__init__() - self.net = nn.Sequential( + super().__init__( nn.Linear(dim, hidden_dim), nn.GELU(), - nn.Dropout(dropout) if dropout != 0.0 else Identity(), + nn.Dropout(dropout), nn.Linear(hidden_dim, dim), - nn.Dropout(dropout) if dropout != 0.0 else Identity(), + nn.Dropout(dropout), ) - def __call__(self, x): - return self.net(x) - - -class Identity(nn.Module): - def __init__(self): - super().__init__() - - def __call__(self, x): - return x - class Attention(nn.Module): def __init__(self, dim, heads, dropout=0.0): @@ -38,17 +26,17 @@ class Attention(nn.Module): self.scale = dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=False) self.out = nn.Sequential( - nn.Linear(dim, dim), nn.Dropout(dropout) if dropout != 0.0 else Identity() + nn.Linear(dim, dim), nn.Dropout(dropout) ) def __call__(self, x): b, n, _, h = *x.shape, self.heads qkv = self.qkv(x) - qkv = qkv.reshape(b, n, 3, h, -1).transpose((2, 0, 3, 1, 4)) + qkv = qkv.reshape(b, n, 3, h, -1).transpose(2, 0, 3, 1, 4) q, k, v = qkv attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale attn = mx.softmax(attn, axis=-1) - x = (attn @ v).transpose((0, 2, 1, 3)).reshape(b, n, -1) + x = (attn @ v).transpose(0, 2, 1, 3).reshape(b, n, -1) x = self.out(x) return x @@ -56,7 +44,6 @@ class Attention(nn.Module): class Block(nn.Module): def __init__(self, dim, heads, mlp_dim, dropout=0.0): super().__init__() - # self.attn = nn.MultiHeadAttention(dim, heads) self.attn = Attention(dim, heads, dropout=dropout) self.norm1 = nn.LayerNorm(dim) self.ff = FeedForward(dim, mlp_dim, dropout=dropout) @@ -89,8 +76,9 @@ class KWT(nn.Module): Implements the Keyword Transformer (KWT) [1] model. KWT is essentially a vision transformer [2] with minor modifications: - - Instead of square patches, KWT uses rectangular patches -> a patch across frequency for every timestep - - KWT modules apply LayerNormalization after attention/feedforward layers, also referred to as PostNorm + - Instead of square patches, KWT uses rectangular patches -> a patch + across frequency for every timestep + - KWT modules apply layer normalization after attention/feedforward layers [1] https://arxiv.org/abs/2104.11178 [2] https://arxiv.org/abs/2010.11929 @@ -148,7 +136,7 @@ class KWT(nn.Module): -1 * STD / 2, STD / 2, (1, self.num_patches + 1, dim) ) self.cls_token = mx.random.truncated_normal(-1 * STD / 2, STD / 2, (1, 1, dim)) - self.dropout = nn.Dropout(emb_dropout) if emb_dropout != 0.0 else Identity() + self.dropout = nn.Dropout(emb_dropout) self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout) self.pool = pool self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes)) diff --git a/speechcommands/main.py b/speechcommands/main.py index 35a905be..a02cb089 100644 --- a/speechcommands/main.py +++ b/speechcommands/main.py @@ -67,6 +67,7 @@ def train_epoch(model, train_iter, optimizer, epoch): accs = [] samples_per_sec = [] + model.train(True) for batch_counter, batch in enumerate(train_iter): x = mx.array(batch["audio"]) y = mx.array(batch["label"]) @@ -92,6 +93,7 @@ def train_epoch(model, train_iter, optimizer, epoch): ) ) ) + break mean_tr_loss = mx.mean(mx.array(losses)) mean_tr_acc = mx.mean(mx.array(accs)) @@ -100,13 +102,13 @@ def train_epoch(model, train_iter, optimizer, epoch): def test_epoch(model, test_iter): + model.train(False) accs = [] for batch_counter, batch in enumerate(test_iter): x = mx.array(batch["audio"]) y = mx.array(batch["label"]) acc = eval_fn(model, x, y) - acc_value = acc.item() - accs.append(acc_value) + accs.append(acc.item()) mean_acc = mx.mean(mx.array(accs)) return mean_acc @@ -146,7 +148,7 @@ def main(args): best_acc = val_acc best_epoch = epoch best_params = model.parameters() - print(f"Testing best model from Epoch {best_epoch}") + print(f"Testing best model from epoch {best_epoch}") model.update(best_params) test_data = prepare_dataset(args.batch_size, "test") test_acc = test_epoch(model, test_data) diff --git a/speechcommands/requirements.txt b/speechcommands/requirements.txt index 9c2ad53a..5ca13284 100644 --- a/speechcommands/requirements.txt +++ b/speechcommands/requirements.txt @@ -1,2 +1 @@ -mlx==0.0.5 -mlx-data \ No newline at end of file +mlx>=0.0.5