# mlx-examples/llms/mlx_lm/models/mamba2-prch.py

import math
from dataclasses import dataclass
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

# The fused scan/step kernels and the gated RMSNorm are not defined in this file;
# they are assumed to come from the mamba_ssm package (https://github.com/state-spaces/mamba).
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.selective_state_update import selective_state_update

@dataclass
class Mamba2Config:
    d_model: int  # D
    n_layers: int
    d_head: int  # todo: shouldn't this rather be n_heads?
    d_state: int = 64  # N in paper/comments
    expand_factor: int = 2  # E in paper/comments
    d_conv: int = 4
    n_groups: int = 1  # todo: ??
    A_init_range: tuple = (1, 16)
    dt_min: float = 0.001
    dt_max: float = 0.1
    dt_init_floor: float = 1e-4
    dt_limit: tuple = (0.0, float("inf"))
    conv_init = None
    learnable_init_states: bool = False
    activation: str = "swish"  # "swish" or "silu"
    rms_norm_eps: float = 1e-5
    base_std: float = 0.02
    bias: bool = False
    conv_bias: bool = True
    mup: bool = False
    mup_base_width: float = 128  # width = d_model
    chunk_size: int = 256
    use_mem_eff_path: bool = True
    dtype = None
    device = None

    def __post_init__(self):
        self.d_inner = self.expand_factor * self.d_model  # E*D = ED in comments
        self.n_heads = self.d_inner // self.d_head
        assert self.d_inner % self.d_head == 0
        assert (self.d_inner / self.d_head) % 8 == 0, "requirement of causal_conv1d"

        # muP
        if self.mup:
            self.mup_width_mult = self.d_model / self.mup_base_width
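
# A quick worked example of the derived sizes (illustrative values, not tied to
# any particular checkpoint): with d_model=768, expand_factor=2 and d_head=64,
# __post_init__ gives d_inner = 2 * 768 = 1536 and n_heads = 1536 // 64 = 24,
# and 24 % 8 == 0 satisfies the causal_conv1d-related assert above.
#
#   cfg = Mamba2Config(d_model=768, n_layers=24, d_head=64)
#   assert (cfg.d_inner, cfg.n_heads) == (1536, 24)
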
class Mamba2(nn.Module):
    def __init__(self, config: Mamba2Config):
        super().__init__()

        self.config = config

        self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])

    def forward(self, x, caches=None):
        if caches is None:
            caches = [None] * self.config.n_layers

        for i, layer in enumerate(self.layers):
            x, caches[i] = layer(x, caches[i])

        if caches[0] is None:
            return x
        else:
            return x, caches
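
# Minimal usage sketch (illustrative only; mamba_chunk_scan_combined is a Triton
# kernel, so this assumes a CUDA-capable install of mamba_ssm):
#
#   cfg = Mamba2Config(d_model=256, n_layers=2, d_head=64)
#   model = Mamba2(cfg).cuda()
#   x = torch.randn(1, 128, cfg.d_model, device="cuda")
#   y = model(x)  # (1, 128, 256); no caches were passed, so only y is returned
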
class ResidualBlock(nn.Module):
    def __init__(self, config: Mamba2Config):
        super().__init__()

        self.config = config

        self.mixer = Mamba2Block(self.config)
        self.norm = RMSNorm(self.config.d_model, self.config.rms_norm_eps, self.config.mup)

    def forward(self, x, cache=None):
        output, cache = self.mixer(self.norm(x), cache)
        output = output + x
        return output, cache

class Mamba2Block(nn.Module):
    def __init__(self, config: Mamba2Config):
        super().__init__()
        factory_kwargs = {"device": config.device, "dtype": config.dtype}

        self.config = config

        # order of the fused input projection: [z, x, B, C, dt]
        d_in_proj = 2 * self.config.d_inner + 2 * self.config.n_groups * self.config.d_state + self.config.n_heads
        self.in_proj = nn.Linear(self.config.d_model, d_in_proj, bias=self.config.bias)

        conv_dim = self.config.d_inner + 2 * self.config.n_groups * self.config.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=self.config.conv_bias,
            kernel_size=self.config.d_conv,
            groups=conv_dim,
            padding=self.config.d_conv - 1,
            **factory_kwargs,
        )

        # "swish" and "silu" name the same activation; applied after the depthwise conv
        self.act = nn.SiLU()

        if self.config.learnable_init_states:
            # learnable per-head initial SSM state, expanded to the batch in forward()
            self.init_states = nn.Parameter(
                torch.zeros(self.config.n_heads, self.config.d_head, self.config.d_state, **factory_kwargs)
            )

        # Initialize log dt bias so that softplus(dt_bias) lands in [dt_min, dt_max]
        dt = torch.exp(
            torch.rand(self.config.n_heads) * (math.log(self.config.dt_max) - math.log(self.config.dt_min))
            + math.log(self.config.dt_min)
        )
        dt = torch.clamp(dt, min=self.config.dt_init_floor)
        # inverse of softplus: dt_bias = log(exp(dt) - 1)
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)

        assert self.config.A_init_range[0] > 0 and self.config.A_init_range[1] >= self.config.A_init_range[0]
        A = torch.empty(self.config.n_heads, dtype=torch.float32).uniform_(*self.config.A_init_range)
        # register A_log as a parameter so it is saved and moved with the module
        self.A_log = nn.Parameter(torch.log(A).to(dtype=self.config.dtype))
        self.D = nn.Parameter(torch.ones(self.config.n_heads, device=self.config.device))

        self.norm = RMSNormGated(self.config.d_inner, eps=1e-5, norm_before_gate=False)

        self.out_proj = nn.Linear(self.config.d_inner, self.config.d_model, bias=self.config.bias)
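
    # For orientation, the widths involved in the fused projection (numbers are
    # illustrative, using d_model=768, d_state=64, n_groups=1, d_head=64, so
    # d_inner=1536 and n_heads=24):
    #   d_in_proj = 2*1536 + 2*1*64 + 24 = 3224, split in forward() as
    #   z (1536) | xBC (1536 + 2*64 = 1664) | dt (24),
    # and conv_dim = 1536 + 2*64 = 1664 channels go through the depthwise conv.
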
    def forward(self, u, cache=None, seq_idx=None):
        """
        u: (B, L, D)
        Returns: out, same shape as u
        """
        batch, length, _ = u.shape

        # a cache passed together with a multi-token prompt is ignored; the
        # parallel (chunked scan) path below does not use or build a cache
        return_cache = False
        if cache is not None and length > 1:
            cache = None
            return_cache = True

        if cache is not None:
            out, cache = self.step(u, cache)
            return out, cache

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)
        A = -torch.exp(self.A_log)  # (n_heads,)
        initial_states = repeat(self.init_states, "... -> b ...", b=batch) if self.config.learnable_init_states else None
        dt_limit_kwargs = {} if self.config.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.config.dt_limit)

        z, xBC, dt = torch.split(
            zxbcdt,
            [self.config.d_inner, self.config.d_inner + 2 * self.config.n_groups * self.config.d_state, self.config.n_heads],
            dim=-1,
        )
        dt = F.softplus(dt + self.dt_bias)  # (B, L, n_heads)

        # 1D depthwise convolution; truncate to the original length so the causal
        # padding does not add positions past L
        xBC = self.act(self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, :length])  # (B, L, d_inner + 2 * n_groups * d_state)
        x, B, C = torch.split(xBC, [self.config.d_inner, self.config.n_groups * self.config.d_state, self.config.n_groups * self.config.d_state], dim=-1)

        y = mamba_chunk_scan_combined(
            rearrange(x, "b l (h p) -> b l h p", p=self.config.d_head),
            dt,
            A,
            rearrange(B, "b l (g n) -> b l g n", g=self.config.n_groups),
            rearrange(C, "b l (g n) -> b l g n", g=self.config.n_groups),
            chunk_size=self.config.chunk_size,
            D=self.D,
            z=None,
            seq_idx=seq_idx,
            initial_states=initial_states,
            **dt_limit_kwargs,
        )
        y = rearrange(y, "b l h p -> b l (h p)")

        # Multiply the "gate" branch and apply the extra normalization layer
        y = self.norm(y, z)
        out = self.out_proj(y)
        return out, cache
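
    # Shapes going into mamba_chunk_scan_combined above (B=batch, L=length,
    # h=n_heads, p=d_head, g=n_groups, n=d_state):
    #   x -> (B, L, h, p),  dt -> (B, L, h),  A -> (h,)
    #   B, C -> (B, L, g, n),  D -> (h,),  output y -> (B, L, h, p)
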
    def step(self, u, cache):
        """
        u: (B, 1, D)
        cache: (h_cache, conv_cache)
        """
        h_cache, conv_cache = cache

        zxbcdt = self.in_proj(u.squeeze(1))  # (B, d_in_proj)
        d_mlp = (zxbcdt.shape[-1] - 2 * self.config.d_inner - 2 * self.config.n_groups * self.config.d_state - self.config.n_heads) // 2
        z0, x0, z, xBC, dt = torch.split(zxbcdt, [d_mlp, d_mlp, self.config.d_inner, self.config.d_inner + 2 * self.config.n_groups * self.config.d_state, self.config.n_heads], dim=-1)

        # conv step: shift the rolling buffer left by one and append the new input
        conv_cache.copy_(torch.roll(conv_cache, shifts=-1, dims=-1))  # update state (B, D, W)
        conv_cache[:, :, -1] = xBC
        xBC = torch.sum(conv_cache * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B, D)
        if self.conv1d.bias is not None:
            xBC = xBC + self.conv1d.bias
        xBC = self.act(xBC).to(dtype=u.dtype)

        x, B, C = torch.split(xBC, [self.config.d_inner, self.config.n_groups * self.config.d_state, self.config.n_groups * self.config.d_state], dim=-1)
        A = -torch.exp(self.A_log.float())  # (n_heads,)

        # SSM step
        A = repeat(A, "h -> h p n", p=self.config.d_head, n=self.config.d_state).to(dtype=torch.float32)
        dt = repeat(dt, "b h -> b h p", p=self.config.d_head)
        dt_bias = repeat(self.dt_bias, "h -> h p", p=self.config.d_head)
        D = repeat(self.D, "h -> h p", p=self.config.d_head)
        B = rearrange(B, "b (g n) -> b g n", g=self.config.n_groups)
        C = rearrange(C, "b (g n) -> b g n", g=self.config.n_groups)
        x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.config.d_head)
        y = selective_state_update(h_cache, x_reshaped, dt, A, B, C, D, z=None, dt_bias=dt_bias, dt_softplus=True)
        y = rearrange(y, "b h p -> b (h p)")

        y = self.norm(y, z)
        if d_mlp > 0:
            y = torch.cat([F.silu(z0) * x0, y], dim=-1)

        out = self.out_proj(y)
        return out.unsqueeze(1), (h_cache, conv_cache)
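
    # Cache layout assumed by step() (inferred from the calls above):
    #   h_cache    -> (B, n_heads, d_head, d_state)  SSM state, updated in place
    #                 by selective_state_update
    #   conv_cache -> (B, conv_dim, d_conv)          rolling window of the last
    #                 d_conv inputs to the depthwise convolution
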
# taken straight from https://github.com/johnma2006/mamba-minimal/blob/master/model.py
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5, use_mup: bool = False):
        super().__init__()

        self.use_mup = use_mup
        self.eps = eps

        # https://arxiv.org/abs/2404.05728 : RMSNorm gains prevent muTransfer (section 4.2.3)
        if not use_mup:
            self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        if not self.use_mup:
            return output * self.weight
        else:
            return output
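
# Minimal sketch of the standalone RMSNorm above (illustrative):
#
#   norm = RMSNorm(d_model=512)
#   h = torch.randn(2, 16, 512)
#   out = norm(h)  # same shape; with use_mup=True the learnable gain is dropped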