import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

# Triton kernels and the gated RMSNorm from the mamba_ssm package (state-spaces/mamba),
# used by Mamba2Block.forward() and Mamba2Block.step() below.
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated


@dataclass
class Mamba2Config:
    d_model: int             # D
    n_layers: int
    d_head: int              # todo: shouldn't this rather be n_heads?
    d_state: int = 64        # N in paper/comments
    expand_factor: int = 2   # E in paper/comments
    d_conv: int = 4
    n_groups: int = 1        # number of B/C groups; heads within a group share B and C

    A_init_range: tuple = (1, 16)
    dt_min: float = 0.001
    dt_max: float = 0.1
    dt_init_floor: float = 1e-4
    dt_limit: tuple = (0.0, float("inf"))
    conv_init = None

    learnable_init_states: bool = False
    activation: str = "swish"  # "swish" or "silu"

    rms_norm_eps: float = 1e-5
    base_std: float = 0.02

    bias: bool = False
    conv_bias: bool = True

    mup: bool = False
    mup_base_width: float = 128  # width = d_model

    chunk_size: int = 256
    use_mem_eff_path: bool = True
    dtype = None
    device = None

    def __post_init__(self):
        self.d_inner = self.expand_factor * self.d_model  # E*D = ED in comments
        self.n_heads = self.d_inner // self.d_head
        assert self.d_inner % self.d_head == 0

        assert self.n_heads % 8 == 0, "requirement of causal_conv1d"

        # muP
        if self.mup:
            self.mup_width_mult = self.d_model / self.mup_base_width


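# Example of the derived sizes: d_model=768, expand_factor=2, d_head=64 gives
# d_inner=1536 and n_heads=24, which satisfies the multiple-of-8 assert above.

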
class Mamba2(nn.Module):
    def __init__(self, config: Mamba2Config):
        super().__init__()

        self.config = config

        self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])

    def forward(self, x, caches=None):
        # caches is either None (training / prefill) or a list with one
        # (h_cache, conv_cache) tuple per layer (single-token decoding).
        if caches is None:
            caches = [None] * self.config.n_layers

        for i, layer in enumerate(self.layers):
            x, caches[i] = layer(x, caches[i])

        if caches[0] is None:
            return x
        else:
            return x, caches

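
# A small helper sketch (not in the original file): one way to allocate the empty
# per-layer caches consumed by Mamba2Block.step(), assuming the shapes it uses:
# h_cache is (B, n_heads, d_head, d_state) and conv_cache is
# (B, d_inner + 2 * n_groups * d_state, d_conv). The function name is ours.
def allocate_inference_caches(config: Mamba2Config, batch_size: int, device=None, dtype=None):
    conv_dim = config.d_inner + 2 * config.n_groups * config.d_state
    caches = []
    for _ in range(config.n_layers):
        h_cache = torch.zeros(batch_size, config.n_heads, config.d_head, config.d_state, device=device, dtype=dtype)
        conv_cache = torch.zeros(batch_size, conv_dim, config.d_conv, device=device, dtype=dtype)
        caches.append((h_cache, conv_cache))
    return caches

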
class ResidualBlock(nn.Module):
    def __init__(self, config: Mamba2Config):
        super().__init__()

        self.config = config

        self.mixer = Mamba2Block(self.config)
        self.norm = RMSNorm(self.config.d_model, self.config.rms_norm_eps, self.config.mup)

    def forward(self, x, cache=None):
        output, cache = self.mixer(self.norm(x), cache)
        output = output + x
        return output, cache

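# Note: each block is pre-norm (x + mixer(RMSNorm(x))); the Mamba2 stack above does
# not apply a final normalization after the last layer.

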
class Mamba2Block(nn.Module):
    def __init__(self, config: Mamba2Config):
        super().__init__()
        factory_kwargs = {"device": config.device, "dtype": config.dtype}

        self.config = config

        # in_proj output layout: [z, x, B, C, dt]
        d_in_proj = 2 * self.config.d_inner + 2 * self.config.n_groups * self.config.d_state + self.config.n_heads
        self.in_proj = nn.Linear(self.config.d_model, d_in_proj, bias=self.config.bias)

        conv_dim = self.config.d_inner + 2 * self.config.n_groups * self.config.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=self.config.conv_bias,
            kernel_size=self.config.d_conv,
            groups=conv_dim,
            padding=self.config.d_conv - 1,
            **factory_kwargs,
        )

        # "swish" and "silu" name the same activation
        self.act = nn.SiLU()

        if self.config.learnable_init_states:
            self.init_states = nn.Parameter(torch.zeros(self.config.n_heads, self.config.d_head, self.config.d_state, **factory_kwargs))

        # Initialize the dt bias so that softplus(dt_bias) falls in [dt_min, dt_max]:
        # inv_dt = dt + log(-expm1(-dt)) is the inverse of softplus.
        dt = torch.exp(
            torch.rand(self.config.n_heads) * (math.log(self.config.dt_max) - math.log(self.config.dt_min))
            + math.log(self.config.dt_min)
        )
        dt = torch.clamp(dt, min=self.config.dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)

        # A is a negative scalar per head: A = -exp(A_log), with exp(A_log) drawn from A_init_range.
        assert self.config.A_init_range[0] > 0 and self.config.A_init_range[1] >= self.config.A_init_range[0]
        A = torch.empty(self.config.n_heads, dtype=torch.float32).uniform_(*self.config.A_init_range)
        self.A_log = nn.Parameter(torch.log(A).to(dtype=self.config.dtype))
        self.D = nn.Parameter(torch.ones(self.config.n_heads, device=self.config.device))

        self.norm = RMSNormGated(self.config.d_inner, eps=1e-5, norm_before_gate=False)

        self.out_proj = nn.Linear(self.config.d_inner, self.config.d_model, bias=self.config.bias)

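    # Learned parameters: in_proj (D -> 2*ED + 2*G*N + H), a depthwise causal conv over
    # the xBC channels, dt_bias (H,), A_log (H,), D (H,), the gated RMSNorm weight (ED,)
    # and out_proj (ED -> D), where H = n_heads, ED = d_inner, G = n_groups, N = d_state.
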
    def forward(self, u, cache=None, seq_idx=None):
        """
        u: (B, L, D)
        Returns: out, cache. out has the same shape as u.
        """

        batch, length, _ = u.shape

        # If a cache is passed together with a full sequence (prefill), ignore its
        # contents and take the parallel path below.
        return_cache = False
        if cache is not None and length > 1:
            cache = None
            return_cache = True

        # Single-token decoding path: update the recurrent state instead of re-scanning.
        if cache is not None:
            out, cache = self.step(u, cache)
            return out, cache

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)
        A = -torch.exp(self.A_log)  # (n_heads,)
        initial_states = repeat(self.init_states, "... -> b ...", b=batch) if self.config.learnable_init_states else None
        dt_limit_kwargs = {} if self.config.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.config.dt_limit)

        z, xBC, dt = torch.split(
            zxbcdt,
            [self.config.d_inner, self.config.d_inner + 2 * self.config.n_groups * self.config.d_state, self.config.n_heads],
            dim=-1,
        )
        dt = F.softplus(dt + self.dt_bias)  # (B, L, n_heads)

        # Depthwise causal 1D convolution (the left padding adds d_conv - 1 extra
        # positions, so truncate back to the original length).
        xBC = self.act(self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, :length])  # (B, L, d_inner + 2 * n_groups * d_state)

        x, B, C = torch.split(xBC, [self.config.d_inner, self.config.n_groups * self.config.d_state, self.config.n_groups * self.config.d_state], dim=-1)
        y = mamba_chunk_scan_combined(
            rearrange(x, "b l (h p) -> b l h p", p=self.config.d_head),
            dt,
            A,
            rearrange(B, "b l (g n) -> b l g n", g=self.config.n_groups),
            rearrange(C, "b l (g n) -> b l g n", g=self.config.n_groups),
            chunk_size=self.config.chunk_size,
            D=self.D,
            z=None,
            seq_idx=seq_idx,
            initial_states=initial_states,
            **dt_limit_kwargs,
        )
        y = rearrange(y, "b l h p -> b l (h p)")

        # Multiply by the "gate" branch and apply the extra normalization layer
        y = self.norm(y, z)
        out = self.out_proj(y)
        return out, cache

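    # Single-token inference. Per layer, the cache is a (h_cache, conv_cache) tuple:
    #   h_cache    (B, n_heads, d_head, d_state)                  -- the SSM state
    #   conv_cache (B, d_inner + 2 * n_groups * d_state, d_conv)  -- last d_conv conv inputs
    # selective_state_update applies the per-head recurrence
    #   h_t = exp(dt * A) * h_{t-1} + dt * (x_t outer B_t),  y_t = C_t . h_t + D * x_t,
    # the sequential form of what the chunked scan computes over full sequences in forward().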
    def step(self, u, cache):
        """
        u: (B, 1, D)
        cache: (h_cache, conv_cache)
        """

        h_cache, conv_cache = cache

        zxbcdt = self.in_proj(u.squeeze(1))  # (B, d_in_proj)
        d_mlp = (zxbcdt.shape[-1] - 2 * self.config.d_inner - 2 * self.config.n_groups * self.config.d_state - self.config.n_heads) // 2
        z0, x0, z, xBC, dt = torch.split(zxbcdt, [d_mlp, d_mlp, self.config.d_inner, self.config.d_inner + 2 * self.config.n_groups * self.config.d_state, self.config.n_heads], dim=-1)

        # conv step: shift the window left and append the new input
        conv_cache.copy_(torch.roll(conv_cache, shifts=-1, dims=-1))  # update state (B, D, W)
        conv_cache[:, :, -1] = xBC
        xBC = torch.sum(conv_cache * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B, D)
        if self.conv1d.bias is not None:
            xBC = xBC + self.conv1d.bias
        xBC = self.act(xBC).to(dtype=u.dtype)

        x, B, C = torch.split(xBC, [self.config.d_inner, self.config.n_groups * self.config.d_state, self.config.n_groups * self.config.d_state], dim=-1)
        A = -torch.exp(self.A_log.float())  # (n_heads,)

        # Broadcast the per-head parameters to the (head, head_dim[, state]) layout
        # expected by selective_state_update.
        A = repeat(A, "h -> h p n", p=self.config.d_head, n=self.config.d_state).to(dtype=torch.float32)
        dt = repeat(dt, "b h -> b h p", p=self.config.d_head)
        dt_bias = repeat(self.dt_bias, "h -> h p", p=self.config.d_head)
        D = repeat(self.D, "h -> h p", p=self.config.d_head)
        B = rearrange(B, "b (g n) -> b g n", g=self.config.n_groups)
        C = rearrange(C, "b (g n) -> b g n", g=self.config.n_groups)
        x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.config.d_head)

        y = selective_state_update(h_cache, x_reshaped, dt, A, B, C, D, z=None, dt_bias=dt_bias, dt_softplus=True)
        y = rearrange(y, "b h p -> b (h p)")

        y = self.norm(y, z)
        if d_mlp > 0:
            y = torch.cat([F.silu(z0) * x0, y], dim=-1)
        out = self.out_proj(y)
        return out.unsqueeze(1), (h_cache, conv_cache)


# taken straight from https://github.com/johnma2006/mamba-minimal/blob/master/model.py
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5, use_mup: bool = False):
        super().__init__()

        self.use_mup = use_mup
        self.eps = eps

        # https://arxiv.org/abs/2404.05728: RMSNorm gains prevent muTransfer (section 4.2.3)
        if not use_mup:
            self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        if not self.use_mup:
            return output * self.weight
        else:
            return output
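
# A minimal usage sketch (not part of the original file). It assumes the mamba_ssm
# Triton kernels imported at the top are installed, which in practice requires a
# CUDA device; allocate_inference_caches is the helper sketched above.
if __name__ == "__main__":
    config = Mamba2Config(d_model=512, n_layers=2, d_head=64)  # d_inner=1024, n_heads=16
    model = Mamba2(config).to("cuda")

    # Parallel (chunked-scan) path over a full sequence, no cache.
    x = torch.randn(2, 256, config.d_model, device="cuda")  # (B, L, D)
    y = model(x)
    print(y.shape)  # torch.Size([2, 256, 512])

    # Single-token decoding path: pass per-layer caches so Mamba2Block.step() is used
    # (the caches start from zeros here purely to show the API).
    caches = allocate_inference_caches(config, batch_size=2, device="cuda")
    token = torch.randn(2, 1, config.d_model, device="cuda")
    out, caches = model(token, caches)
    print(out.shape)  # torch.Size([2, 1, 512])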