mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
some updates / simplifications
This commit is contained in:
parent
b4ac7cc1df
commit
992f5cc0fa
@ -1,19 +1,23 @@
|
|||||||
# Training a Vision Transformer on SpeechCommands
|
# Training a Vision Transformer on SpeechCommands
|
||||||
|
|
||||||
An example of training [Keyword Spotting Transformer](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html), a variant of the Vision Transformer, on the [Speech Commands](https://arxiv.org/abs/1804.03209) (v0.02) dataset with MLX. All supervised only configurations from the paper are available. The example also
|
An example of training a Keyword Spotting Transformer[^1] on the Speech
|
||||||
illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to
|
Commands dataset[^2] with MLX. All supervised only configurations from the
|
||||||
load and process an audio dataset.
|
paper are available. The example also illustrates how to use [MLX
|
||||||
|
Data](https://github.com/ml-explore/mlx-data) to load and process an audio
|
||||||
|
dataset.
|
||||||
|
|
||||||
## Pre-requisites
|
## Pre-requisites
|
||||||
|
|
||||||
Install `mlx`
|
Follow the [installation
|
||||||
|
instructions](https://ml-explore.github.io/mlx-data/build/html/install.html)
|
||||||
|
for MLX Data.
|
||||||
|
|
||||||
|
Install the remaining python requirements:
|
||||||
|
|
||||||
```
|
```
|
||||||
pip install mlx==0.0.5
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
At the time of writing, the SpeechCommands dataset is not yet a part of a `mlx-data` release. Install `mlx-data` from source from this [commit](https://github.com/ml-explore/mlx-data/commit/ae3431648b8e1594d63175a8f121d9873aeb9daa).
|
|
||||||
|
|
||||||
## Running the example
|
## Running the example
|
||||||
|
|
||||||
Run the example with:
|
Run the example with:
|
||||||
@ -22,7 +26,7 @@ Run the example with:
|
|||||||
python main.py
|
python main.py
|
||||||
```
|
```
|
||||||
|
|
||||||
By default the example runs on the GPU. To run on the CPU, use:
|
By default the example runs on the GPU. To run it on the CPU, use:
|
||||||
|
|
||||||
```
|
```
|
||||||
python main.py --cpu
|
python main.py --cpu
|
||||||
@ -56,5 +60,10 @@ Test acc -> 0.718
|
|||||||
|
|
||||||
Note that this was run on an M1 Macbook Pro with 16GB RAM.
|
Note that this was run on an M1 Macbook Pro with 16GB RAM.
|
||||||
|
|
||||||
At the time of writing, `mlx` doesn't have built-in `cosine` learning rate schedules, which are used along with the AdamW optimizer in the official implementation. We intend to update this example once these features
|
At the time of writing, `mlx` doesn't have built-in `cosine` learning rate
|
||||||
are added, as well as with appropriate data augmentations.
|
schedules, which are used along with the AdamW optimizer in the official
|
||||||
|
implementation. We intend to update this example once these features are added,
|
||||||
|
as well as with appropriate data augmentations.
|
||||||
|
|
||||||
|
[^1]: Based on the paper [Keyword Transformer: A Self-Attention Model for Keyword Spotting](https://www.isca-speech.org/archive/interspeech_2021/berg21_interspeech.html)
|
||||||
|
[^2]: We use version 0.02. See the [paper](https://arxiv.org/abs/1804.03209) for more details.
|
||||||
|
@ -8,28 +8,16 @@ __all__ = ["KWT", "kwt1", "kwt2", "kwt3"]
|
|||||||
STD = 0.02
|
STD = 0.02
|
||||||
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
class FeedForward(nn.Sequential):
|
||||||
def __init__(self, dim, hidden_dim, dropout=0.0):
|
def __init__(self, dim, hidden_dim, dropout=0.0):
|
||||||
super().__init__()
|
super().__init__(
|
||||||
self.net = nn.Sequential(
|
|
||||||
nn.Linear(dim, hidden_dim),
|
nn.Linear(dim, hidden_dim),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Dropout(dropout) if dropout != 0.0 else Identity(),
|
nn.Dropout(dropout),
|
||||||
nn.Linear(hidden_dim, dim),
|
nn.Linear(hidden_dim, dim),
|
||||||
nn.Dropout(dropout) if dropout != 0.0 else Identity(),
|
nn.Dropout(dropout),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
return self.net(x)
|
|
||||||
|
|
||||||
|
|
||||||
class Identity(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
def __init__(self, dim, heads, dropout=0.0):
|
def __init__(self, dim, heads, dropout=0.0):
|
||||||
@ -38,17 +26,17 @@ class Attention(nn.Module):
|
|||||||
self.scale = dim**-0.5
|
self.scale = dim**-0.5
|
||||||
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||||
self.out = nn.Sequential(
|
self.out = nn.Sequential(
|
||||||
nn.Linear(dim, dim), nn.Dropout(dropout) if dropout != 0.0 else Identity()
|
nn.Linear(dim, dim), nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
b, n, _, h = *x.shape, self.heads
|
b, n, _, h = *x.shape, self.heads
|
||||||
qkv = self.qkv(x)
|
qkv = self.qkv(x)
|
||||||
qkv = qkv.reshape(b, n, 3, h, -1).transpose((2, 0, 3, 1, 4))
|
qkv = qkv.reshape(b, n, 3, h, -1).transpose(2, 0, 3, 1, 4)
|
||||||
q, k, v = qkv
|
q, k, v = qkv
|
||||||
attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale
|
attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale
|
||||||
attn = mx.softmax(attn, axis=-1)
|
attn = mx.softmax(attn, axis=-1)
|
||||||
x = (attn @ v).transpose((0, 2, 1, 3)).reshape(b, n, -1)
|
x = (attn @ v).transpose(0, 2, 1, 3).reshape(b, n, -1)
|
||||||
x = self.out(x)
|
x = self.out(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -56,7 +44,6 @@ class Attention(nn.Module):
|
|||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
def __init__(self, dim, heads, mlp_dim, dropout=0.0):
|
def __init__(self, dim, heads, mlp_dim, dropout=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# self.attn = nn.MultiHeadAttention(dim, heads)
|
|
||||||
self.attn = Attention(dim, heads, dropout=dropout)
|
self.attn = Attention(dim, heads, dropout=dropout)
|
||||||
self.norm1 = nn.LayerNorm(dim)
|
self.norm1 = nn.LayerNorm(dim)
|
||||||
self.ff = FeedForward(dim, mlp_dim, dropout=dropout)
|
self.ff = FeedForward(dim, mlp_dim, dropout=dropout)
|
||||||
@ -89,8 +76,9 @@ class KWT(nn.Module):
|
|||||||
Implements the Keyword Transformer (KWT) [1] model.
|
Implements the Keyword Transformer (KWT) [1] model.
|
||||||
|
|
||||||
KWT is essentially a vision transformer [2] with minor modifications:
|
KWT is essentially a vision transformer [2] with minor modifications:
|
||||||
- Instead of square patches, KWT uses rectangular patches -> a patch across frequency for every timestep
|
- Instead of square patches, KWT uses rectangular patches -> a patch
|
||||||
- KWT modules apply LayerNormalization after attention/feedforward layers, also referred to as PostNorm
|
across frequency for every timestep
|
||||||
|
- KWT modules apply layer normalization after attention/feedforward layers
|
||||||
|
|
||||||
[1] https://arxiv.org/abs/2104.11178
|
[1] https://arxiv.org/abs/2104.11178
|
||||||
[2] https://arxiv.org/abs/2010.11929
|
[2] https://arxiv.org/abs/2010.11929
|
||||||
@ -148,7 +136,7 @@ class KWT(nn.Module):
|
|||||||
-1 * STD / 2, STD / 2, (1, self.num_patches + 1, dim)
|
-1 * STD / 2, STD / 2, (1, self.num_patches + 1, dim)
|
||||||
)
|
)
|
||||||
self.cls_token = mx.random.truncated_normal(-1 * STD / 2, STD / 2, (1, 1, dim))
|
self.cls_token = mx.random.truncated_normal(-1 * STD / 2, STD / 2, (1, 1, dim))
|
||||||
self.dropout = nn.Dropout(emb_dropout) if emb_dropout != 0.0 else Identity()
|
self.dropout = nn.Dropout(emb_dropout)
|
||||||
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
|
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
|
||||||
self.pool = pool
|
self.pool = pool
|
||||||
self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))
|
self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))
|
||||||
|
@ -67,6 +67,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
|||||||
accs = []
|
accs = []
|
||||||
samples_per_sec = []
|
samples_per_sec = []
|
||||||
|
|
||||||
|
model.train(True)
|
||||||
for batch_counter, batch in enumerate(train_iter):
|
for batch_counter, batch in enumerate(train_iter):
|
||||||
x = mx.array(batch["audio"])
|
x = mx.array(batch["audio"])
|
||||||
y = mx.array(batch["label"])
|
y = mx.array(batch["label"])
|
||||||
@ -92,6 +93,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
break
|
||||||
|
|
||||||
mean_tr_loss = mx.mean(mx.array(losses))
|
mean_tr_loss = mx.mean(mx.array(losses))
|
||||||
mean_tr_acc = mx.mean(mx.array(accs))
|
mean_tr_acc = mx.mean(mx.array(accs))
|
||||||
@ -100,13 +102,13 @@ def train_epoch(model, train_iter, optimizer, epoch):
|
|||||||
|
|
||||||
|
|
||||||
def test_epoch(model, test_iter):
|
def test_epoch(model, test_iter):
|
||||||
|
model.train(False)
|
||||||
accs = []
|
accs = []
|
||||||
for batch_counter, batch in enumerate(test_iter):
|
for batch_counter, batch in enumerate(test_iter):
|
||||||
x = mx.array(batch["audio"])
|
x = mx.array(batch["audio"])
|
||||||
y = mx.array(batch["label"])
|
y = mx.array(batch["label"])
|
||||||
acc = eval_fn(model, x, y)
|
acc = eval_fn(model, x, y)
|
||||||
acc_value = acc.item()
|
accs.append(acc.item())
|
||||||
accs.append(acc_value)
|
|
||||||
mean_acc = mx.mean(mx.array(accs))
|
mean_acc = mx.mean(mx.array(accs))
|
||||||
return mean_acc
|
return mean_acc
|
||||||
|
|
||||||
@ -146,7 +148,7 @@ def main(args):
|
|||||||
best_acc = val_acc
|
best_acc = val_acc
|
||||||
best_epoch = epoch
|
best_epoch = epoch
|
||||||
best_params = model.parameters()
|
best_params = model.parameters()
|
||||||
print(f"Testing best model from Epoch {best_epoch}")
|
print(f"Testing best model from epoch {best_epoch}")
|
||||||
model.update(best_params)
|
model.update(best_params)
|
||||||
test_data = prepare_dataset(args.batch_size, "test")
|
test_data = prepare_dataset(args.batch_size, "test")
|
||||||
test_acc = test_epoch(model, test_data)
|
test_acc = test_epoch(model, test_data)
|
||||||
|
@ -1,2 +1 @@
|
|||||||
mlx==0.0.5
|
mlx>=0.0.5
|
||||||
mlx-data
|
|
||||||
|
Loading…
Reference in New Issue
Block a user