import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

__all__ = ["KWT", "kwt1", "kwt2", "kwt3"]


class FeedForward(nn.Sequential):
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )


class Attention(nn.Module):
    def __init__(self, dim, heads, dropout=0.0):
        super().__init__()
        self.heads = heads
        self.scale = dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout))

    def __call__(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.qkv(x)
        qkv = qkv.reshape(b, n, 3, h, -1).transpose(2, 0, 3, 1, 4)
        q, k, v = qkv
        attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale
        attn = mx.softmax(attn, axis=-1)
        x = (attn @ v).transpose(0, 2, 1, 3).reshape(b, n, -1)
        x = self.out(x)
        return x


class Block(nn.Module):
    def __init__(self, dim, heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.attn = Attention(dim, heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)

    def __call__(self, x):
        x = self.norm1(self.attn(x)) + x
        x = self.norm2(self.ff(x)) + x
        return x


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.0):
        super().__init__()

        self.layers = []
        for _ in range(depth):
            self.layers.append(Block(dim, heads, mlp_dim, dropout=dropout))

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class KWT(nn.Module):
    """
    Implements the Keyword Transformer (KWT) [1] model.

    KWT is essentially a vision transformer [2] with minor modifications:
    - Instead of square patches, KWT uses rectangular patches: one patch
      spanning the full frequency axis for every timestep (see the note
      following this class)
    - KWT blocks apply layer normalization after the attention and
      feedforward layers (post-norm)

    [1] https://arxiv.org/abs/2104.11178
    [2] https://arxiv.org/abs/2010.11929

    Parameters
    ----------
    input_res: tuple of ints
        Input resolution (time, frequency)
    patch_res: tuple of ints
        Patch resolution (time, frequency)
    num_classes: int
        Number of classes
    dim: int
        Model embedding dimension
    depth: int
        Number of transformer layers
    heads: int
        Number of attention heads
    mlp_dim: int
        Feedforward hidden dimension
    pool: str
        Pooling type, either "cls" or "mean"
    in_channels: int, optional
        Number of input channels
    dropout: float, optional
        Dropout rate
    emb_dropout: float, optional
        Embedding dropout rate
    """

    def __init__(
        self,
        input_res,
        patch_res,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        pool="mean",
        in_channels=1,
        dropout=0.0,
        emb_dropout=0.0,
    ):
        super().__init__()
        self.num_patches = int(
            (input_res[0] / patch_res[0]) * (input_res[1] / patch_res[1])
        )
        self.dim = dim

        self.patch_embedding = nn.Conv2d(
            in_channels, dim, kernel_size=patch_res, stride=patch_res
        )
        self.pos_embedding = mx.random.truncated_normal(
            -0.01,
            0.01,
            (self.num_patches + 1, dim),
        )
        self.cls_token = mx.random.truncated_normal(-0.01, 0.01, (dim,))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
        self.pool = pool
        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def num_params(self):
        nparams = sum(x.size for k, x in tree_flatten(self.parameters()))
        return nparams

    def __call__(self, x):
        if x.ndim != 4:
            x = mx.expand_dims(x, axis=-1)
        x = self.patch_embedding(x)
        x = x.reshape(x.shape[0], -1, self.dim)
        assert x.shape[1] == self.num_patches

        cls_tokens = mx.broadcast_to(self.cls_token, (x.shape[0], 1, self.dim))
        x = mx.concatenate((cls_tokens, x), axis=1)

        x = x + self.pos_embedding

        x = self.dropout(x)
        x = self.transformer(x)
        x = x.mean(axis=1) if self.pool == "mean" else x[:, 0]
        x = self.mlp_head(x)
        return x

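
# Patch arithmetic with the defaults used by parse_kwt_args below:
# input_res = (98, 40) and patch_res = (1, 40), i.e. each patch covers the full
# 40-bin frequency axis of a single timestep, giving
#     num_patches = (98 / 1) * (40 / 40) = 98
# patch tokens, plus one class token (99 in total), which matches the leading
# dimension of pos_embedding.
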
def parse_kwt_args(**kwargs):
    input_res = kwargs.pop("input_res", [98, 40])
    patch_res = kwargs.pop("patch_res", [1, 40])
    num_classes = kwargs.pop("num_classes", 35)
    emb_dropout = kwargs.pop("emb_dropout", 0.1)
    return input_res, patch_res, num_classes, emb_dropout, kwargs


def kwt1(**kwargs):
    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
    return KWT(
        input_res,
        patch_res,
        num_classes,
        dim=64,
        depth=12,
        heads=1,
        mlp_dim=256,
        emb_dropout=emb_dropout,
        **kwargs
    )


def kwt2(**kwargs):
    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
    return KWT(
        input_res,
        patch_res,
        num_classes,
        dim=128,
        depth=12,
        heads=2,
        mlp_dim=512,
        emb_dropout=emb_dropout,
        **kwargs
    )


def kwt3(**kwargs):
    input_res, patch_res, num_classes, emb_dropout, kwargs = parse_kwt_args(**kwargs)
    return KWT(
        input_res,
        patch_res,
        num_classes,
        dim=192,
        depth=12,
        heads=3,
        mlp_dim=768,
        emb_dropout=emb_dropout,
        **kwargs
    )
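

# A minimal usage sketch: build the smallest variant and run a forward pass on a
# random batch shaped like the default (time=98, frequency=40) input assumed by
# parse_kwt_args. The batch size and random input here are illustrative only.
if __name__ == "__main__":
    model = kwt1()
    x = mx.random.normal((4, 98, 40))  # (batch, time, frequency)
    logits = model(x)  # a channel axis is added internally for the Conv2d patcher
    print(logits.shape)  # (4, 35) with the default 35 classes
    print(model.num_params())  # total number of parameters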