mlx-examples/speechcommands/kwt.py
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

__all__ = ["KWT", "kwt1", "kwt2", "kwt3"]


class FeedForward(nn.Sequential):
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )


class Attention(nn.Module):
    def __init__(self, dim, heads, dropout=0.0):
        super().__init__()
        self.heads = heads
        self.scale = dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout))

    def __call__(self, x):
        b, n, _, h = *x.shape, self.heads

        # Project to queries, keys, and values in one pass, then split into
        # (3, batch, heads, tokens, head_dim).
        qkv = self.qkv(x)
        qkv = qkv.reshape(b, n, 3, h, -1).transpose(2, 0, 3, 1, 4)
        q, k, v = qkv

        # Scaled dot-product attention.
        attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale
        attn = mx.softmax(attn, axis=-1)

        # Merge heads back into the embedding dimension.
        x = (attn @ v).transpose(0, 2, 1, 3).reshape(b, n, -1)
        x = self.out(x)
        return x


class Block(nn.Module):
    def __init__(self, dim, heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.attn = Attention(dim, heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)

    def __call__(self, x):
        x = self.norm1(self.attn(x)) + x
        x = self.norm2(self.ff(x)) + x
        return x


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.layers = []
        for _ in range(depth):
            self.layers.append(Block(dim, heads, mlp_dim, dropout=dropout))

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class KWT(nn.Module):
    """
    Implements the Keyword Transformer (KWT) [1] model.

    KWT is essentially a vision transformer [2] with minor modifications:
    - Instead of square patches, KWT uses rectangular patches -> a patch
      across frequency for every timestep
    - KWT modules apply layer normalization after attention/feedforward layers

    [1] https://arxiv.org/abs/2104.11178
    [2] https://arxiv.org/abs/2010.11929

    Parameters
    ----------
    input_res: tuple of ints
        Input resolution (time, frequency)
    patch_res: tuple of ints
        Patch resolution (time, frequency)
    num_classes: int
        Number of classes
    dim: int
        Model embedding dimension
    depth: int
        Number of transformer layers
    heads: int
        Number of attention heads
    mlp_dim: int
        Feedforward hidden dimension
    pool: str
        Pooling type, either "cls" or "mean"
    in_channels: int, optional
        Number of input channels
    dropout: float, optional
        Dropout rate
    emb_dropout: float, optional
        Embedding dropout rate
    """

    def __init__(
        self,
        input_res,
        patch_res,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        pool="mean",
        in_channels=1,
        dropout=0.0,
        emb_dropout=0.0,
    ):
        super().__init__()
        self.num_patches = int(
            (input_res[0] / patch_res[0]) * (input_res[1] / patch_res[1])
        )
        self.dim = dim
        self.patch_embedding = nn.Conv2d(
            in_channels, dim, kernel_size=patch_res, stride=patch_res
        )
        self.pos_embedding = mx.random.truncated_normal(
            -0.01,
            0.01,
            (self.num_patches + 1, dim),
        )
        self.cls_token = mx.random.truncated_normal(-0.01, 0.01, (dim,))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
        self.pool = pool
        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def num_params(self):
        nparams = sum(x.size for k, x in tree_flatten(self.parameters()))
        return nparams

    def __call__(self, x):
        # Add a channel dimension if the input is (batch, time, frequency).
        if x.ndim != 4:
            x = mx.expand_dims(x, axis=-1)

        # Patchify the spectrogram and flatten patches into a token sequence.
        x = self.patch_embedding(x)
        x = x.reshape(x.shape[0], -1, self.dim)
        assert x.shape[1] == self.num_patches

        # Prepend the class token and add the learned positional embedding.
        cls_tokens = mx.broadcast_to(self.cls_token, (x.shape[0], 1, self.dim))
        x = mx.concatenate((cls_tokens, x), axis=1)
        x = x + self.pos_embedding
        x = self.dropout(x)

        x = self.transformer(x)

        # Pool over tokens ("mean") or take the class token ("cls"), then classify.
        x = x.mean(axis=1) if self.pool == "mean" else x[:, 0]
        x = self.mlp_head(x)
        return x


def parse_kwt_args(**kwargs):
    input_res = kwargs.pop("input_res", [98, 40])
    patch_res = kwargs.pop("patch_res", [1, 40])
    num_classes = kwargs.pop("num_classes", 35)
    emb_dropout = kwargs.pop("emb_dropout", 0.1)
    return input_res, patch_res, num_classes, emb_dropout, kwargs


def kwt1(**kwargs):
    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
    return KWT(
        input_res,
        patch_res,
        num_classes,
        dim=64,
        depth=12,
        heads=1,
        mlp_dim=256,
        emb_dropout=emb_dropout,
        **kwargs
    )


def kwt2(**kwargs):
    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
    return KWT(
        input_res,
        patch_res,
        num_classes,
        dim=128,
        depth=12,
        heads=2,
        mlp_dim=512,
        emb_dropout=emb_dropout,
        **kwargs
    )


def kwt3(**kwargs):
    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
    return KWT(
        input_res,
        patch_res,
        num_classes,
        dim=192,
        depth=12,
        heads=3,
        mlp_dim=768,
        emb_dropout=emb_dropout,
        **kwargs
    )
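

# The block below is not part of the original file; it is a minimal usage
# sketch showing how the factory functions above might be exercised. It
# assumes an MLX environment and stands in a random array for a real
# (time, frequency) mel-spectrogram batch from the Speech Commands pipeline.
if __name__ == "__main__":
    # Build the smallest variant with its default 98x40 input resolution,
    # 1x40 patches (one patch per timestep), and 35 output classes.
    model = kwt1()
    print(f"kwt1 parameters: {model.num_params():,}")

    # Fake batch of 4 spectrograms shaped (batch, time, frequency);
    # KWT.__call__ adds the trailing channel dimension itself.
    x = mx.random.normal((4, 98, 40))
    logits = model(x)
    print(logits.shape)  # expected: (4, 35)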