mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
some updates / simplifications
This commit is contained in:
parent
b4ac7cc1df
commit
992f5cc0fa
@ -1,19 +1,23 @@
|
||||
# Training a Vision Transformer on SpeechCommands
|
||||
|
||||
An example of training [Keyword Spotting Transformer](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html), a variant of the Vision Transformer, on the [Speech Commands](https://arxiv.org/abs/1804.03209) (v0.02) dataset with MLX. All supervised only configurations from the paper are available. The example also
|
||||
illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to
|
||||
load and process an audio dataset.
|
||||
An example of training a Keyword Spotting Transformer[^1] on the Speech
|
||||
Commands dataset[^2] with MLX. All supervised only configurations from the
|
||||
paper are available. The example also illustrates how to use [MLX
|
||||
Data](https://github.com/ml-explore/mlx-data) to load and process an audio
|
||||
dataset.
|
||||
|
||||
## Pre-requisites
|
||||
|
||||
Install `mlx`
|
||||
Follow the [installation
|
||||
instructions](https://ml-explore.github.io/mlx-data/build/html/install.html)
|
||||
for MLX Data.
|
||||
|
||||
Install the remaining python requirements:
|
||||
|
||||
```
|
||||
pip install mlx==0.0.5
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
At the time of writing, the SpeechCommands dataset is not yet a part of a `mlx-data` release. Install `mlx-data` from source from this [commit](https://github.com/ml-explore/mlx-data/commit/ae3431648b8e1594d63175a8f121d9873aeb9daa).
|
||||
|
||||
## Running the example
|
||||
|
||||
Run the example with:
|
||||
@ -22,7 +26,7 @@ Run the example with:
|
||||
python main.py
|
||||
```
|
||||
|
||||
By default the example runs on the GPU. To run on the CPU, use:
|
||||
By default the example runs on the GPU. To run it on the CPU, use:
|
||||
|
||||
```
|
||||
python main.py --cpu
|
||||
@ -56,5 +60,10 @@ Test acc -> 0.718
|
||||
|
||||
Note that this was run on an M1 MacBook Pro with 16GB RAM.
|
||||
|
||||
At the time of writing, `mlx` doesn't have built-in `cosine` learning rate schedules, which are used along with the AdamW optimizer in the official implementation. We intend to update this example once these features
|
||||
are added, as well as with appropriate data augmentations.
|
||||
At the time of writing, `mlx` doesn't have built-in `cosine` learning rate
|
||||
schedules, which are used along with the AdamW optimizer in the official
|
||||
implementation. We intend to update this example once these features are added,
|
||||
as well as with appropriate data augmentations.
|
||||
|
||||
[^1]: Based on the paper [Keyword Transformer: A Self-Attention Model for Keyword Spotting](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html)
|
||||
[^2]: We use version 0.02. See the [paper](https://arxiv.org/abs/1804.03209) for more details.
|
||||
|
@ -8,28 +8,16 @@ __all__ = ["KWT", "kwt1", "kwt2", "kwt3"]
|
||||
STD = 0.02
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
class FeedForward(nn.Sequential):
|
||||
def __init__(self, dim, hidden_dim, dropout=0.0):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
super().__init__(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout) if dropout != 0.0 else Identity(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout) if dropout != 0.0 else Identity(),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads, dropout=0.0):
|
||||
@ -38,17 +26,17 @@ class Attention(nn.Module):
|
||||
self.scale = dim**-0.5
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||
self.out = nn.Sequential(
|
||||
nn.Linear(dim, dim), nn.Dropout(dropout) if dropout != 0.0 else Identity()
|
||||
nn.Linear(dim, dim), nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.qkv(x)
|
||||
qkv = qkv.reshape(b, n, 3, h, -1).transpose((2, 0, 3, 1, 4))
|
||||
qkv = qkv.reshape(b, n, 3, h, -1).transpose(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv
|
||||
attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale
|
||||
attn = mx.softmax(attn, axis=-1)
|
||||
x = (attn @ v).transpose((0, 2, 1, 3)).reshape(b, n, -1)
|
||||
x = (attn @ v).transpose(0, 2, 1, 3).reshape(b, n, -1)
|
||||
x = self.out(x)
|
||||
return x
|
||||
|
||||
@ -56,7 +44,6 @@ class Attention(nn.Module):
|
||||
class Block(nn.Module):
|
||||
def __init__(self, dim, heads, mlp_dim, dropout=0.0):
|
||||
super().__init__()
|
||||
# self.attn = nn.MultiHeadAttention(dim, heads)
|
||||
self.attn = Attention(dim, heads, dropout=dropout)
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.ff = FeedForward(dim, mlp_dim, dropout=dropout)
|
||||
@ -89,8 +76,9 @@ class KWT(nn.Module):
|
||||
Implements the Keyword Transformer (KWT) [1] model.
|
||||
|
||||
KWT is essentially a vision transformer [2] with minor modifications:
|
||||
- Instead of square patches, KWT uses rectangular patches -> a patch across frequency for every timestep
|
||||
- KWT modules apply LayerNormalization after attention/feedforward layers, also referred to as PostNorm
|
||||
- Instead of square patches, KWT uses rectangular patches -> a patch
|
||||
across frequency for every timestep
|
||||
- KWT modules apply layer normalization after attention/feedforward layers
|
||||
|
||||
[1] https://arxiv.org/abs/2104.00769
|
||||
[2] https://arxiv.org/abs/2010.11929
|
||||
@ -148,7 +136,7 @@ class KWT(nn.Module):
|
||||
-1 * STD / 2, STD / 2, (1, self.num_patches + 1, dim)
|
||||
)
|
||||
self.cls_token = mx.random.truncated_normal(-1 * STD / 2, STD / 2, (1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout) if emb_dropout != 0.0 else Identity()
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
|
||||
self.pool = pool
|
||||
self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))
|
||||
|
@ -67,6 +67,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
accs = []
|
||||
samples_per_sec = []
|
||||
|
||||
model.train(True)
|
||||
for batch_counter, batch in enumerate(train_iter):
|
||||
x = mx.array(batch["audio"])
|
||||
y = mx.array(batch["label"])
|
||||
@ -92,6 +93,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
)
|
||||
)
|
||||
)
|
||||
break
|
||||
|
||||
mean_tr_loss = mx.mean(mx.array(losses))
|
||||
mean_tr_acc = mx.mean(mx.array(accs))
|
||||
@ -100,13 +102,13 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
||||
|
||||
|
||||
def test_epoch(model, test_iter):
|
||||
model.train(False)
|
||||
accs = []
|
||||
for batch_counter, batch in enumerate(test_iter):
|
||||
x = mx.array(batch["audio"])
|
||||
y = mx.array(batch["label"])
|
||||
acc = eval_fn(model, x, y)
|
||||
acc_value = acc.item()
|
||||
accs.append(acc_value)
|
||||
accs.append(acc.item())
|
||||
mean_acc = mx.mean(mx.array(accs))
|
||||
return mean_acc
|
||||
|
||||
@ -146,7 +148,7 @@ def main(args):
|
||||
best_acc = val_acc
|
||||
best_epoch = epoch
|
||||
best_params = model.parameters()
|
||||
print(f"Testing best model from Epoch {best_epoch}")
|
||||
print(f"Testing best model from epoch {best_epoch}")
|
||||
model.update(best_params)
|
||||
test_data = prepare_dataset(args.batch_size, "test")
|
||||
test_acc = test_epoch(model, test_data)
|
||||
|
@ -1,2 +1 @@
|
||||
mlx==0.0.5
|
||||
mlx-data
|
||||
mlx>=0.0.5
|
||||
|
Loading…
Reference in New Issue
Block a user