import math
from dataclasses import dataclass
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

# Fused Triton kernels from the mamba_ssm package (assumed installed): the chunked SSD scan,
# the single-token state update and the gated RMSNorm used by Mamba2Block below.
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated


@dataclass
class Mamba2Config:
    d_model: int  # D
    n_layers: int
    d_head: int  # todo: shouldn't this rather be n_heads?
    d_state: int = 64  # N in paper/comments
    expand_factor: int = 2  # E in paper/comments
    d_conv: int = 4
    n_groups: int = 1  # number of groups for B and C (GQA-like sharing across heads)

    A_init_range: tuple = (1, 16)
    dt_min: float = 0.001
    dt_max: float = 0.1
    dt_init_floor: float = 1e-4
    dt_limit: tuple = (0.0, float("inf"))
    conv_init: Union[float, None] = None

    learnable_init_states: bool = False
    activation: str = "swish"  # "swish" or "silu"

    rms_norm_eps: float = 1e-5
    base_std: float = 0.02

    bias: bool = False
    conv_bias: bool = True

    mup: bool = False
    mup_base_width: float = 128  # width=d_model

    chunk_size: int = 256
    use_mem_eff_path: bool = True
    dtype: torch.dtype = None
    device: str = None

    def __post_init__(self):
        self.d_inner = self.expand_factor * self.d_model  # E*D = ED in comments
        self.n_heads = self.d_inner // self.d_head
        assert self.d_inner % self.d_head == 0

        assert (self.d_inner // self.d_head) % 8 == 0, "requirement of causal_conv1d"

        # muP
        if self.mup:
            self.mup_width_mult = self.d_model / self.mup_base_width
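
# Example of the derived sizes (illustrative values): Mamba2Config(d_model=768, n_layers=24, d_head=64)
# gives d_inner = 1536 and n_heads = 24, satisfying both asserts above (1536 % 64 == 0, 24 % 8 == 0).
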
class Mamba2(nn.Module):
    def __init__(self, config: Mamba2Config):
        super().__init__()

        self.config = config

        self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])

    def forward(self, x, caches=None):
        if caches is None:
            caches = [None] * self.config.n_layers

        for i, layer in enumerate(self.layers):
            x, caches[i] = layer(x, caches[i])

        # no caches were provided: only return the output
        if caches[0] is None:
            return x
        else:
            return x, caches


class ResidualBlock(nn.Module):
    def __init__(self, config: Mamba2Config):
        super().__init__()

        self.config = config

        self.mixer = Mamba2Block(self.config)
        self.norm = RMSNorm(self.config.d_model, self.config.rms_norm_eps, self.config.mup)

    def forward(self, x, cache=None):
        output, cache = self.mixer(self.norm(x), cache)
        output = output + x
        return output, cache


class Mamba2Block(nn.Module):
    def __init__(self, config: Mamba2Config):
        super().__init__()
        factory_kwargs = {"device": config.device, "dtype": config.dtype}

        self.config = config

        # order of the input projection outputs: [z, x, B, C, dt]
        d_in_proj = 2 * self.config.d_inner + 2 * self.config.n_groups * self.config.d_state + self.config.n_heads
        self.in_proj = nn.Linear(self.config.d_model, d_in_proj, bias=self.config.bias)

        conv_dim = self.config.d_inner + 2 * self.config.n_groups * self.config.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=self.config.conv_bias,
            kernel_size=self.config.d_conv,
            groups=conv_dim,
            padding=self.config.d_conv - 1,
            **factory_kwargs,
        )

        # activation used after the convolution and in step(); "swish" and "silu" both map to SiLU
        self.act = nn.SiLU()

        if self.config.learnable_init_states:
            # learnable initial SSM state, one (d_head, d_state) matrix per head
            # (shape assumed from the reference Mamba2 implementation)
            self.init_states = nn.Parameter(torch.zeros(self.config.n_heads, self.config.d_head, self.config.d_state, **factory_kwargs))

        # Initialize log dt bias
        dt = torch.exp(
            torch.rand(self.config.n_heads) * (math.log(self.config.dt_max) - math.log(self.config.dt_min))
            + math.log(self.config.dt_min)
        )
        dt = torch.clamp(dt, min=self.config.dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt))  # inverse softplus, so that softplus(dt_bias) recovers dt
        self.dt_bias = nn.Parameter(inv_dt)

        assert self.config.A_init_range[0] > 0 and self.config.A_init_range[1] >= self.config.A_init_range[0]
        A = torch.empty(self.config.n_heads, dtype=torch.float32).uniform_(*self.config.A_init_range)
        self.A_log = nn.Parameter(torch.log(A).to(dtype=self.config.dtype))  # registered as a parameter so it trains and moves with the module
        self.D = nn.Parameter(torch.ones(self.config.n_heads, device=self.config.device))

        self.norm = RMSNormGated(self.config.d_inner, eps=1e-5, norm_before_gate=False)

        self.out_proj = nn.Linear(self.config.d_inner, self.config.d_model, bias=self.config.bias)

    def forward(self, u, cache=None, seq_idx=None):
        """
        u: (B, L, D)
        Returns: out, same shape as u
        """

        batch, length, _ = u.shape

        # a cache passed together with a multi-token input means we are prefilling a prompt:
        # drop the cache and fall through to the parallel (chunked) path below
        return_cache = False
        if cache is not None and length > 1:
            cache = None
            return_cache = True

        if cache is not None:
            out, cache = self.step(u, cache)
            return out, cache

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)
        A = -torch.exp(self.A_log)  # (n_heads)
        initial_states = repeat(self.init_states, "... -> b ...", b=batch) if self.config.learnable_init_states else None
        dt_limit_kwargs = {} if self.config.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.config.dt_limit)

        z, xBC, dt = torch.split(
            zxbcdt,
            [self.config.d_inner, self.config.d_inner + 2 * self.config.n_groups * self.config.d_state, self.config.n_heads],
            dim=-1,
        )
        dt = F.softplus(dt + self.dt_bias)  # (B, L, n_heads)

        # 1D causal convolution: padding=d_conv-1 lengthens the output, so keep only the first `length` steps
        xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :length].transpose(1, 2))  # (B, L, d_inner + 2 * n_groups * d_state)

        x, B, C = torch.split(xBC, [self.config.d_inner, self.config.n_groups * self.config.d_state, self.config.n_groups * self.config.d_state], dim=-1)

        y = mamba_chunk_scan_combined(
            rearrange(x, "b l (h p) -> b l h p", p=self.config.d_head),
            dt,
            A,
            rearrange(B, "b l (g n) -> b l g n", g=self.config.n_groups),
            rearrange(C, "b l (g n) -> b l g n", g=self.config.n_groups),
            chunk_size=self.config.chunk_size,
            D=self.D,
            z=None,
            seq_idx=seq_idx,
            initial_states=initial_states,
            **dt_limit_kwargs,
        )
        y = rearrange(y, "b l h p -> b l (h p)")

        # multiply by the "gate" branch and apply the extra normalization layer
        y = self.norm(y, z)
        out = self.out_proj(y)
        return out, cache

    def step(self, u, cache):
        """
        u: (B, 1, D)
        cache: (h_cache, conv_cache)
        """

        h_cache, conv_cache = cache

        zxbcdt = self.in_proj(u.squeeze(1))  # (B, d_in_proj)
        d_mlp = (zxbcdt.shape[-1] - 2 * self.config.d_inner - 2 * self.config.n_groups * self.config.d_state - self.config.n_heads) // 2
        z0, x0, z, xBC, dt = torch.split(zxbcdt, [d_mlp, d_mlp, self.config.d_inner, self.config.d_inner + 2 * self.config.n_groups * self.config.d_state, self.config.n_heads], dim=-1)

        # conv step
        conv_cache.copy_(torch.roll(conv_cache, shifts=-1, dims=-1))  # shift the rolling window (B, conv_dim, d_conv)
        conv_cache[:, :, -1] = xBC
        xBC = torch.sum(conv_cache * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B, conv_dim)
        if self.conv1d.bias is not None:
            xBC = xBC + self.conv1d.bias
        xBC = self.act(xBC).to(dtype=u.dtype)

        x, B, C = torch.split(xBC, [self.config.d_inner, self.config.n_groups * self.config.d_state, self.config.n_groups * self.config.d_state], dim=-1)
        A = -torch.exp(self.A_log.float())  # (n_heads)

        # SSM step: broadcast the per-head quantities to the shapes expected by selective_state_update
        A = repeat(A, "h -> h p n", p=self.config.d_head, n=self.config.d_state).to(dtype=torch.float32)
        dt = repeat(dt, "b h -> b h p", p=self.config.d_head)
        dt_bias = repeat(self.dt_bias, "h -> h p", p=self.config.d_head)
        D = repeat(self.D, "h -> h p", p=self.config.d_head)
        B = rearrange(B, "b (g n) -> b g n", g=self.config.n_groups)
        C = rearrange(C, "b (g n) -> b g n", g=self.config.n_groups)
        x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.config.d_head)

        y = selective_state_update(h_cache, x_reshaped, dt, A, B, C, D, z=None, dt_bias=dt_bias, dt_softplus=True)
        y = rearrange(y, "b h p -> b (h p)")

        y = self.norm(y, z)
        if d_mlp > 0:
            y = torch.cat([F.silu(z0) * x0, y], dim=-1)
        out = self.out_proj(y)
        return out.unsqueeze(1), (h_cache, conv_cache)


# taken straight from https://github.com/johnma2006/mamba-minimal/blob/master/model.py
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5, use_mup: bool = False):
        super().__init__()

        self.use_mup = use_mup
        self.eps = eps

        # https://arxiv.org/abs/2404.05728: RMSNorm gains prevent muTransfer (section 4.2.3)
        if not use_mup:
            self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        if not self.use_mup:
            return output * self.weight
        else:
            return output
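

# Minimal usage sketch with illustrative hyperparameters; it assumes a CUDA device and the
# mamba_ssm Triton kernels imported at the top (the chunked scan only runs on GPU).
if __name__ == "__main__":
    config = Mamba2Config(d_model=256, n_layers=2, d_head=64)  # derived: d_inner=512, n_heads=8
    model = Mamba2(config).to("cuda")

    x = torch.randn(2, 64, config.d_model, device="cuda")  # (B, L, D)
    y = model(x)  # no caches given: parallel (chunked) forward, output keeps the (B, L, D) shape
    print(y.shape)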