some updates / simplifications

This commit is contained in:
Awni Hannun 2023-12-18 21:54:19 -08:00
parent b4ac7cc1df
commit 992f5cc0fa
4 changed files with 36 additions and 38 deletions

View File

@ -1,19 +1,23 @@
# Training a Vision Transformer on SpeechCommands # Training a Vision Transformer on SpeechCommands
An example of training [Keyword Spotting Transformer](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html), a variant of the Vision Transformer, on the [Speech Commands](https://arxiv.org/abs/1804.03209) (v0.02) dataset with MLX. All supervised only configurations from the paper are available. The example also An example of training a Keyword Spotting Transformer[^1] on the Speech
illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to Commands dataset[^2] with MLX. All supervised only configurations from the
load and process an audio dataset. paper are available. The example also illustrates how to use [MLX
Data](https://github.com/ml-explore/mlx-data) to load and process an audio
dataset.
## Pre-requisites ## Pre-requisites
Install `mlx` Follow the [installation
instructions](https://ml-explore.github.io/mlx-data/build/html/install.html)
for MLX Data.
Install the remaining python requirements:
``` ```
pip install mlx==0.0.5 pip install -r requirements.txt
``` ```
At the time of writing, the SpeechCommands dataset is not yet a part of a `mlx-data` release. Install `mlx-data` from source from this [commit](https://github.com/ml-explore/mlx-data/commit/ae3431648b8e1594d63175a8f121d9873aeb9daa).
## Running the example ## Running the example
Run the example with: Run the example with:
@ -22,7 +26,7 @@ Run the example with:
python main.py python main.py
``` ```
By default the example runs on the GPU. To run on the CPU, use: By default the example runs on the GPU. To run it on the CPU, use:
``` ```
python main.py --cpu python main.py --cpu
@ -56,5 +60,10 @@ Test acc -> 0.718
Note that this was run on an M1 Macbook Pro with 16GB RAM. Note that this was run on an M1 Macbook Pro with 16GB RAM.
At the time of writing, `mlx` doesn't have built-in `cosine` learning rate schedules, which are used along with the AdamW optimizer in the official implementation. We intend to update this example once these features At the time of writing, `mlx` doesn't have built-in `cosine` learning rate
are added, as well as with appropriate data augmentations. schedules, which are used along with the AdamW optimizer in the official
implementation. We intend to update this example once these features are added,
as well as with appropriate data augmentations.
[^1]: Based on the paper [Keyword Transformer: A Self-Attention Model for Keyword Spotting](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html)
[^2]: We use version 0.02. See the [paper](https://arxiv.org/abs/1804.03209) for more details.

View File

@ -8,28 +8,16 @@ __all__ = ["KWT", "kwt1", "kwt2", "kwt3"]
STD = 0.02 STD = 0.02
class FeedForward(nn.Module): class FeedForward(nn.Sequential):
def __init__(self, dim, hidden_dim, dropout=0.0): def __init__(self, dim, hidden_dim, dropout=0.0):
super().__init__() super().__init__(
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim), nn.Linear(dim, hidden_dim),
nn.GELU(), nn.GELU(),
nn.Dropout(dropout) if dropout != 0.0 else Identity(), nn.Dropout(dropout),
nn.Linear(hidden_dim, dim), nn.Linear(hidden_dim, dim),
nn.Dropout(dropout) if dropout != 0.0 else Identity(), nn.Dropout(dropout),
) )
def __call__(self, x):
return self.net(x)
class Identity(nn.Module):
def __init__(self):
super().__init__()
def __call__(self, x):
return x
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, dim, heads, dropout=0.0): def __init__(self, dim, heads, dropout=0.0):
@ -38,17 +26,17 @@ class Attention(nn.Module):
self.scale = dim**-0.5 self.scale = dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=False) self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.out = nn.Sequential( self.out = nn.Sequential(
nn.Linear(dim, dim), nn.Dropout(dropout) if dropout != 0.0 else Identity() nn.Linear(dim, dim), nn.Dropout(dropout)
) )
def __call__(self, x): def __call__(self, x):
b, n, _, h = *x.shape, self.heads b, n, _, h = *x.shape, self.heads
qkv = self.qkv(x) qkv = self.qkv(x)
qkv = qkv.reshape(b, n, 3, h, -1).transpose((2, 0, 3, 1, 4)) qkv = qkv.reshape(b, n, 3, h, -1).transpose(2, 0, 3, 1, 4)
q, k, v = qkv q, k, v = qkv
attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale
attn = mx.softmax(attn, axis=-1) attn = mx.softmax(attn, axis=-1)
x = (attn @ v).transpose((0, 2, 1, 3)).reshape(b, n, -1) x = (attn @ v).transpose(0, 2, 1, 3).reshape(b, n, -1)
x = self.out(x) x = self.out(x)
return x return x
@ -56,7 +44,6 @@ class Attention(nn.Module):
class Block(nn.Module): class Block(nn.Module):
def __init__(self, dim, heads, mlp_dim, dropout=0.0): def __init__(self, dim, heads, mlp_dim, dropout=0.0):
super().__init__() super().__init__()
# self.attn = nn.MultiHeadAttention(dim, heads)
self.attn = Attention(dim, heads, dropout=dropout) self.attn = Attention(dim, heads, dropout=dropout)
self.norm1 = nn.LayerNorm(dim) self.norm1 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, mlp_dim, dropout=dropout) self.ff = FeedForward(dim, mlp_dim, dropout=dropout)
@ -89,8 +76,9 @@ class KWT(nn.Module):
Implements the Keyword Transformer (KWT) [1] model. Implements the Keyword Transformer (KWT) [1] model.
KWT is essentially a vision transformer [2] with minor modifications: KWT is essentially a vision transformer [2] with minor modifications:
- Instead of square patches, KWT uses rectangular patches -> a patch across frequency for every timestep - Instead of square patches, KWT uses rectangular patches -> a patch
- KWT modules apply LayerNormalization after attention/feedforward layers, also referred to as PostNorm across frequency for every timestep
- KWT modules apply layer normalization after attention/feedforward layers
[1] https://arxiv.org/abs/2104.11178 [1] https://arxiv.org/abs/2104.11178
[2] https://arxiv.org/abs/2010.11929 [2] https://arxiv.org/abs/2010.11929
@ -148,7 +136,7 @@ class KWT(nn.Module):
-1 * STD / 2, STD / 2, (1, self.num_patches + 1, dim) -1 * STD / 2, STD / 2, (1, self.num_patches + 1, dim)
) )
self.cls_token = mx.random.truncated_normal(-1 * STD / 2, STD / 2, (1, 1, dim)) self.cls_token = mx.random.truncated_normal(-1 * STD / 2, STD / 2, (1, 1, dim))
self.dropout = nn.Dropout(emb_dropout) if emb_dropout != 0.0 else Identity() self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout) self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
self.pool = pool self.pool = pool
self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes)) self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

View File

@ -67,6 +67,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
accs = [] accs = []
samples_per_sec = [] samples_per_sec = []
model.train(True)
for batch_counter, batch in enumerate(train_iter): for batch_counter, batch in enumerate(train_iter):
x = mx.array(batch["audio"]) x = mx.array(batch["audio"])
y = mx.array(batch["label"]) y = mx.array(batch["label"])
@ -92,6 +93,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
) )
) )
) )
break
mean_tr_loss = mx.mean(mx.array(losses)) mean_tr_loss = mx.mean(mx.array(losses))
mean_tr_acc = mx.mean(mx.array(accs)) mean_tr_acc = mx.mean(mx.array(accs))
@ -100,13 +102,13 @@ def train_epoch(model, train_iter, optimizer, epoch):
def test_epoch(model, test_iter): def test_epoch(model, test_iter):
model.train(False)
accs = [] accs = []
for batch_counter, batch in enumerate(test_iter): for batch_counter, batch in enumerate(test_iter):
x = mx.array(batch["audio"]) x = mx.array(batch["audio"])
y = mx.array(batch["label"]) y = mx.array(batch["label"])
acc = eval_fn(model, x, y) acc = eval_fn(model, x, y)
acc_value = acc.item() accs.append(acc.item())
accs.append(acc_value)
mean_acc = mx.mean(mx.array(accs)) mean_acc = mx.mean(mx.array(accs))
return mean_acc return mean_acc
@ -146,7 +148,7 @@ def main(args):
best_acc = val_acc best_acc = val_acc
best_epoch = epoch best_epoch = epoch
best_params = model.parameters() best_params = model.parameters()
print(f"Testing best model from Epoch {best_epoch}") print(f"Testing best model from epoch {best_epoch}")
model.update(best_params) model.update(best_params)
test_data = prepare_dataset(args.batch_size, "test") test_data = prepare_dataset(args.batch_size, "test")
test_acc = test_epoch(model, test_data) test_acc = test_epoch(model, test_data)

View File

@ -1,2 +1 @@
mlx==0.0.5 mlx>=0.0.5
mlx-data