# mlx-examples/llms/mlx_lm/models/mamba2-prch.py
import math
from dataclasses import dataclass
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
# The fused kernels and gated RMSNorm used below are assumed to come from the
# mamba_ssm package (https://github.com/state-spaces/mamba).
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

@dataclass
class Mamba2Config:
d_model: int # D
n_layers: int
    d_head: int  # head dimension P (TODO: should this rather be n_heads?)
d_state: int = 64 # N in paper/comments
expand_factor: int = 2 # E in paper/comments
d_conv: int = 4
    n_groups: int = 1  # G: number of B/C groups shared across heads
A_init_range: tuple = (1, 16)
dt_min: float = 0.001
dt_max: float = 0.1
dt_init_floor: float = 1e-4
dt_limit: tuple = (0.0, float("inf"))
    conv_init: Union[float, None] = None
learnable_init_states: bool = False
activation: str = "swish" # "swish" or "silu"
rms_norm_eps: float = 1e-5
base_std: float = 0.02
bias: bool = False
conv_bias: bool = True
mup: bool = False
mup_base_width: float = 128 # width=d_model
chunk_size: int = 256
use_mem_eff_path: bool = True
    dtype: Union[torch.dtype, None] = None
    device: Union[torch.device, str, None] = None
def __post_init__(self):
self.d_inner = self.expand_factor * self.d_model # E*D = ED in comments
self.n_heads = self.d_inner // self.d_head
assert self.d_inner % self.d_head == 0
        assert (self.d_inner // self.d_head) % 8 == 0, "requirement of causal_conv1d"
# muP
if self.mup:
self.mup_width_mult = self.d_model / self.mup_base_width
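
# Worked example of the derived sizes above (illustrative values, not from the original
# file): with d_model=768, expand_factor=2, d_head=64, n_groups=1 and d_state=64:
#   d_inner = 2 * 768 = 1536                                      (ED)
#   n_heads = 1536 // 64 = 24, and 24 % 8 == 0 satisfies the assert above
#   conv channels = d_inner + 2 * n_groups * d_state = 1536 + 2 * 64 = 1664
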
class Mamba2(nn.Module):
def __init__(self, config: Mamba2Config):
super().__init__()
self.config = config
self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])
def forward(self, x, caches=None):
if caches is None:
caches = [None] * self.config.n_layers
for i, layer in enumerate(self.layers):
x, caches[i] = layer(x, caches[i])
        if caches[0] is None:
return x
else:
return x, caches
class ResidualBlock(nn.Module):
def __init__(self, config: Mamba2Config):
super().__init__()
self.config = config
self.mixer = Mamba2Block(self.config)
self.norm = RMSNorm(self.config.d_model, self.config.rms_norm_eps, self.config.mup)
def forward(self, x, cache=None):
output, cache = self.mixer(self.norm(x), cache)
output = output + x
return output, cache
class Mamba2Block(nn.Module):
def __init__(self, config: Mamba2Config):
super().__init__()
factory_kwargs = {"device": config.device, "dtype": config.dtype}
self.config = config
# [z, x, B, C, dt]
d_in_proj = 2 * self.config.d_inner + 2 * self.config.n_groups * self.config.d_state + self.config.n_heads
self.in_proj = nn.Linear(self.config.d_model, d_in_proj, bias=self.config.bias)
conv_dim = self.config.d_inner + 2 * self.config.n_groups * self.config.d_state
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
bias=self.config.conv_bias,
kernel_size=self.config.d_conv,
groups=conv_dim,
padding=self.config.d_conv - 1,
**factory_kwargs,
)
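        # For reference, the fused input projection above packs, per token (a sketch of
        # the layout only; G = n_groups, N = d_state):
        #   z   : d_inner              gate branch
        #   xBC : d_inner + 2 * G * N  convolution input, later split into x, B, C
        #   dt  : n_heads              per-head step sizes
        # so d_in_proj = 2 * d_inner + 2 * G * N + n_heads, matching the splits in forward().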
# Initialize log dt bias
dt = torch.exp(
torch.rand(self.config.n_heads) * (math.log(self.config.dt_max) - math.log(self.config.dt_min))
+ math.log(self.config.dt_min)
)
dt = torch.clamp(dt, min=self.config.dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt))  # inverse of softplus
self.dt_bias = nn.Parameter(inv_dt)
        assert self.config.A_init_range[0] > 0 and self.config.A_init_range[1] >= self.config.A_init_range[0]
        A = torch.empty(self.config.n_heads, dtype=torch.float32).uniform_(*self.config.A_init_range)
        self.A_log = nn.Parameter(torch.log(A).to(dtype=self.config.dtype))
        self.D = nn.Parameter(torch.ones(self.config.n_heads, device=self.config.device))
        if self.config.learnable_init_states:
            # one learnable (d_head, d_state) initial state per head
            self.init_states = nn.Parameter(torch.zeros(self.config.n_heads, self.config.d_head, self.config.d_state, **factory_kwargs))
        self.act = nn.SiLU()  # "swish" and "silu" are the same activation
        self.norm = RMSNormGated(self.config.d_inner, eps=1e-5, norm_before_gate=False)
        self.out_proj = nn.Linear(self.config.d_inner, self.config.d_model, bias=self.config.bias)
def forward(self, u, cache=None, seq_idx=None):
"""
u: (B, L, D)
Returns: out : same shape as u
"""
batch, length, _ = u.shape
        # Single-token decoding uses the recurrent step() with the cache. For longer
        # inputs the cache is ignored and a full parallel forward is run; this
        # reference does not rebuild the cache from a multi-token prompt.
        if cache is not None:
            if length > 1:
                cache = None
            else:
                out, cache = self.step(u, cache)
                return out, cache
zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
        A = -torch.exp(self.A_log)  # (n_heads,)
        initial_states = repeat(self.init_states, "... -> b ...", b=batch) if self.config.learnable_init_states else None
dt_limit_kwargs = {} if self.config.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.config.dt_limit)
z, xBC, dt = torch.split(
zxbcdt,
[self.config.d_inner, self.config.d_inner + 2 * self.config.n_groups * self.config.d_state, self.config.n_heads],
dim=-1
)
dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
        # 1D causal convolution; the padding makes the output longer than L, so truncate back
        xBC = self.act(self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, :length])  # (B, L, d_inner + 2 * n_groups * d_state)
x, B, C = torch.split(xBC, [self.config.d_inner, self.config.n_groups * self.config.d_state, self.config.n_groups * self.config.d_state], dim=-1)
y = mamba_chunk_scan_combined(
rearrange(x, "b l (h p) -> b l h p", p=self.config.d_head),
dt,
A,
rearrange(B, "b l (g n) -> b l g n", g=self.config.n_groups),
rearrange(C, "b l (g n) -> b l g n", g=self.config.n_groups),
chunk_size=self.config.chunk_size,
D=self.D,
z=None,
seq_idx=seq_idx,
initial_states=initial_states,
**dt_limit_kwargs,
)
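        # What the fused kernel computes, per head and in chunks of `chunk_size`
        # (a sketch of the underlying recurrence, not additional computation):
        #   h_t = exp(dt_t * A) * h_{t-1} + dt_t * x_t B_t^T      (state h_t has shape (P, N))
        #   y_t = h_t C_t + D * x_t
        # with x reshaped to (B, L, H, P) and B/C shared across the heads of each group.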
y = rearrange(y, "b l h p -> b l (h p)")
# Multiply "gate" branch and apply extra normalization layer
y = self.norm(y, z)
out = self.out_proj(y)
return out, cache
def step(self, u, cache):
"""
u: (B, 1, D)
cache: (h_cache, conv_cache)
"""
h_cache, conv_cache = cache
        zxbcdt = self.in_proj(u.squeeze(1))  # (B, d_in_proj)
        # d_mlp is 0 for the in_proj defined in __init__, so z0 and x0 are empty; the split
        # is kept for compatibility with projections that carry an extra MLP branch.
        d_mlp = (zxbcdt.shape[-1] - 2 * self.config.d_inner - 2 * self.config.n_groups * self.config.d_state - self.config.n_heads) // 2
        z0, x0, z, xBC, dt = torch.split(zxbcdt, [d_mlp, d_mlp, self.config.d_inner, self.config.d_inner + 2 * self.config.n_groups * self.config.d_state, self.config.n_heads], dim=-1)
# conv step
conv_cache.copy_(torch.roll(conv_cache, shifts=-1, dims=-1)) # update state (B, D, W)
conv_cache[:, :, -1] = xBC
xBC = torch.sum(conv_cache * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B, D)
if self.conv1d.bias is not None:
xBC = xBC + self.conv1d.bias
        xBC = self.act(xBC).to(dtype=u.dtype)
x, B, C = torch.split(xBC, [self.config.d_inner, self.config.n_groups * self.config.d_state, self.config.n_groups * self.config.d_state], dim=-1)
A = -torch.exp(self.A_log.float()) # (n_heads)
A = repeat(A, "h -> h p n", p=self.config.d_head, n=self.config.d_state).to(dtype=torch.float32)
dt = repeat(dt, "b h -> b h p", p=self.config.d_head)
dt_bias = repeat(self.dt_bias, "h -> h p", p=self.config.d_head)
D = repeat(self.D, "h -> h p", p=self.config.d_head)
B = rearrange(B, "b (g n) -> b g n", g=self.config.n_groups)
C = rearrange(C, "b (g n) -> b g n", g=self.config.n_groups)
x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.config.d_head)
y = selective_state_update(h_cache, x_reshaped, dt, A, B, C, D, z=None, dt_bias=dt_bias, dt_softplus=True)
y = rearrange(y, "b h p -> b (h p)")
        y = self.norm(y, z)  # gated RMSNorm
if d_mlp > 0:
y = torch.cat([F.silu(z0) * x0, y], dim=-1)
out = self.out_proj(y)
return out.unsqueeze(1), (h_cache, conv_cache)
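
    # Cache layout assumed by step(), matching what the mamba_ssm kernels expect:
    #   conv_cache: (B, d_inner + 2 * n_groups * d_state, d_conv)  rolling window of conv inputs
    #   h_cache:    (B, n_heads, d_head, d_state)                  SSM state, updated in place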
# taken straight from https://github.com/johnma2006/mamba-minimal/blob/master/model.py
class RMSNorm(nn.Module):
def __init__(self, d_model: int, eps: float = 1e-5, use_mup: bool = False):
super().__init__()
self.use_mup = use_mup
self.eps = eps
        # https://arxiv.org/abs/2404.05728: RMSNorm gains prevent muTransfer (section 4.2.3)
if not use_mup:
self.weight = nn.Parameter(torch.ones(d_model))
def forward(self, x):
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
if not self.use_mup:
return output * self.weight
else:
return output
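

# Minimal smoke-test sketch (not part of the original file). It assumes a CUDA device
# and the mamba_ssm Triton kernels, since mamba_chunk_scan_combined and
# selective_state_update run on the GPU. The sizes are illustrative only.
if __name__ == "__main__":
    config = Mamba2Config(d_model=256, n_layers=2, d_head=64)  # d_inner = 512, n_heads = 8
    model = Mamba2(config).to("cuda")
    x = torch.randn(1, 32, config.d_model, device="cuda")  # (B, L, D)
    y = model(x)  # no caches passed, so only the output tensor is returned
    print(y.shape)  # expected: torch.Size([1, 32, 256])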