# mlx-examples/speechcommands/kwt.py

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

__all__ = ["KWT", "kwt1", "kwt2", "kwt3"]


class FeedForward(nn.Sequential):
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )


class Attention(nn.Module):
    def __init__(self, dim, heads, dropout=0.0):
        super().__init__()
        self.heads = heads
        self.scale = dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout))

    def __call__(self, x):
        b, n, _, h = *x.shape, self.heads

        # Project to queries, keys, and values, then split into heads:
        # (b, n, 3 * dim) -> 3 x (b, h, n, dim // h)
        qkv = self.qkv(x)
        qkv = qkv.reshape(b, n, 3, h, -1).transpose(2, 0, 3, 1, 4)
        q, k, v = qkv

        # Scaled dot-product attention over the tokens
        attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale
        attn = mx.softmax(attn, axis=-1)

        # Merge heads back into the embedding dimension and project out
        x = (attn @ v).transpose(0, 2, 1, 3).reshape(b, n, -1)
        x = self.out(x)
        return x


class Block(nn.Module):
    def __init__(self, dim, heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.attn = Attention(dim, heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)

    def __call__(self, x):
        # Post-norm residuals: LayerNorm follows the attention/feedforward
        # sublayers (see the KWT class docstring)
        x = self.norm1(self.attn(x)) + x
        x = self.norm2(self.ff(x)) + x
        return x


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.layers = []
        for _ in range(depth):
            self.layers.append(Block(dim, heads, mlp_dim, dropout=dropout))

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class KWT(nn.Module):
    """
    Implements the Keyword Transformer (KWT) [1] model.

    KWT is essentially a vision transformer [2] with minor modifications:
    - Instead of square patches, KWT uses rectangular patches -> a patch
      across frequency for every timestep
    - KWT modules apply layer normalization after attention/feedforward layers

    [1] https://arxiv.org/abs/2104.11178
    [2] https://arxiv.org/abs/2010.11929

    Parameters
    ----------
    input_res: tuple of ints
        Input resolution (time, frequency)
    patch_res: tuple of ints
        Patch resolution (time, frequency)
    num_classes: int
        Number of classes
    dim: int
        Model embedding dimension
    depth: int
        Number of transformer layers
    heads: int
        Number of attention heads
    mlp_dim: int
        Feedforward hidden dimension
    pool: str
        Pooling type, either "cls" or "mean"
    in_channels: int, optional
        Number of input channels
    dropout: float, optional
        Dropout rate
    emb_dropout: float, optional
        Embedding dropout rate
    """

    def __init__(
        self,
        input_res,
        patch_res,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        pool="mean",
        in_channels=1,
        dropout=0.0,
        emb_dropout=0.0,
    ):
        super().__init__()

        # Non-overlapping rectangular patches: with the default patch_res,
        # one patch per time step spanning the full frequency axis
        self.num_patches = int(
            (input_res[0] / patch_res[0]) * (input_res[1] / patch_res[1])
        )
        self.dim = dim
        self.patch_embedding = nn.Conv2d(
            in_channels, dim, kernel_size=patch_res, stride=patch_res
        )

        # Learnable positional embedding and class token
        self.pos_embedding = mx.random.truncated_normal(
            -0.01,
            0.01,
            (self.num_patches + 1, dim),
        )
        self.cls_token = mx.random.truncated_normal(-0.01, 0.01, (dim,))

        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
        self.pool = pool
        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def num_params(self):
        nparams = sum(x.size for _, x in tree_flatten(self.parameters()))
        return nparams

    def __call__(self, x):
        # Add a trailing channel dimension if the input is (batch, time, freq)
        if x.ndim != 4:
            x = mx.expand_dims(x, axis=-1)

        # Patchify: (batch, time, freq, channels) -> (batch, num_patches, dim)
        x = self.patch_embedding(x)
        x = x.reshape(x.shape[0], -1, self.dim)
        assert x.shape[1] == self.num_patches

        # Prepend the class token and add positional embeddings
        cls_tokens = mx.broadcast_to(self.cls_token, (x.shape[0], 1, self.dim))
        x = mx.concatenate((cls_tokens, x), axis=1)
        x = x + self.pos_embedding
        x = self.dropout(x)

        x = self.transformer(x)

        # Pool over patches (mean) or take the class token, then classify
        x = x.mean(axis=1) if self.pool == "mean" else x[:, 0]
        x = self.mlp_head(x)
        return x
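

# Construction sketch (illustrative, not part of the original module): with the
# defaults used by parse_kwt_args below, the 98 x 40 input is split into 98
# patches of shape 1 x 40, so the transformer sees 98 + 1 tokens (patches plus
# the class token).
#
#   model = KWT(
#       input_res=[98, 40],  # (time, frequency)
#       patch_res=[1, 40],   # one patch per time step, full frequency band
#       num_classes=35,
#       dim=64,
#       depth=12,
#       heads=1,
#       mlp_dim=256,
#   )
#   logits = model(mx.zeros((1, 98, 40)))  # -> shape (1, 35)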


def parse_kwt_args(**kwargs):
    # Defaults for the Speech Commands example: 98 time frames x 40 frequency
    # bins, one patch per time step, 35 keyword classes
    input_res = kwargs.pop("input_res", [98, 40])
    patch_res = kwargs.pop("patch_res", [1, 40])
    num_classes = kwargs.pop("num_classes", 35)
    emb_dropout = kwargs.pop("emb_dropout", 0.1)
    return input_res, patch_res, num_classes, emb_dropout, kwargs


def kwt1(**kwargs):
    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
    return KWT(
        input_res,
        patch_res,
        num_classes,
        dim=64,
        depth=12,
        heads=1,
        mlp_dim=256,
        emb_dropout=emb_dropout,
        **kwargs
    )


def kwt2(**kwargs):
    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
    return KWT(
        input_res,
        patch_res,
        num_classes,
        dim=128,
        depth=12,
        heads=2,
        mlp_dim=512,
        emb_dropout=emb_dropout,
        **kwargs
    )


def kwt3(**kwargs):
    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
    return KWT(
        input_res,
        patch_res,
        num_classes,
        dim=192,
        depth=12,
        heads=3,
        mlp_dim=768,
        emb_dropout=emb_dropout,
        **kwargs
    )
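

# Smoke-test sketch (illustrative, not part of the original module): build the
# smallest variant and run a forward pass on random features shaped like the
# defaults in parse_kwt_args.
if __name__ == "__main__":
    model = kwt1()
    model.eval()  # disable dropout for inference
    x = mx.random.normal((4, 98, 40))  # (batch, time, frequency)
    logits = model(x)
    print(logits.shape)  # (4, 35)
    print(f"{model.num_params() / 1e6:.2f}M parameters")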