generation works! trying training now

This commit is contained in:
Goekdeniz-Guelmez 2024-10-22 18:25:59 +02:00
parent c1634ce81b
commit b9c57cd429
3 changed files with 537 additions and 327 deletions

View File

@@ -1,14 +1,11 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
# python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello how are you."
from .base import BaseModelArgs
from .cache import MambaCache
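# Python-API equivalent of the CLI invocation in the comment above (a sketch, assuming
# the standard mlx_lm `load`/`generate` helpers and the checkpoint named there):
#
#   from mlx_lm import load, generate
#   model, tokenizer = load("rokyang/mamba2-130m-hf")
#   print(generate(model, tokenizer, prompt="hello how are you.", max_tokens=64))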
@dataclass
class ModelArgs(BaseModelArgs):
@@ -47,21 +44,6 @@ class ModelArgs(BaseModelArgs):
self.time_step_rank = math.ceil(self.hidden_size / 16)
class Mamba2Cache:
def __init__(self):
self.cache = [None, None]
def __setitem__(self, idx, value):
self.cache[idx] = value
def __getitem__(self, idx):
return self.cache[idx]
@property
def state(self):
return self.cache
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
@@ -111,7 +93,7 @@ class DepthWiseConv1d(nn.Module):
return y, x[:, -K + 1 :, :]
class Mamba2Mixer(nn.Module):
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
@@ -124,35 +106,36 @@ class Mamba2Mixer(nn.Module):
self.head_dim = args.hidden_size // args.num_heads
self.n_groups = args.n_groups
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
self.conv1d = DepthWiseConv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
bias=args.use_conv_bias,
kernel_size=args.conv_kernel,
groups=self.conv_dim,
padding=args.conv_kernel - 1
)
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
# projection_size = 2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads
projection_size = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(
self.hidden_size,
args.hidden_size,
projection_size,
bias=args.use_bias
)
self.A_log = mx.zeros(self.num_heads)
self.D = mx.ones(self.num_heads)
self.dt_bias = mx.zeros(self.num_heads)
# self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
self.conv_dim = args.intermediate_size + 2 * args.state_size
self.conv1d = DepthWiseConv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
kernel_size=args.conv_kernel,
bias=args.use_conv_bias,
groups=self.conv_dim,
padding=args.conv_kernel - 1
)
self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
self.A_log = mx.zeros(args.num_heads)
self.D = mx.ones((args.num_heads,))
self.dt_bias = mx.zeros(args.num_heads)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
def ssm_step(self, x, state, dt_proj):
def ssm_step(self, x, state, dt):
A = -mx.exp(self.A_log)
D = self.D
delta = nn.softplus(dt_proj + self.dt_bias)
dt = nn.softplus(dt + self.dt_bias)
B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
@@ -160,13 +143,13 @@ class Mamba2Mixer(nn.Module):
B = B.reshape(batch_size, self.n_groups, self.state_size)
C = C.reshape(batch_size, -1, self.state_size)
delta = delta.reshape(batch_size, self.num_heads, 1)
dt = dt.reshape(batch_size, self.num_heads, 1)
A = A.reshape(1, self.num_heads, 1)
if state is None:
new_state = delta * B
new_state = dt * B
else:
new_state = delta * (B + state * mx.exp(delta * A))
new_state = dt * (B + state * mx.exp(dt * A))
y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
y = y + D * x[:, :self.num_heads]
@@ -180,26 +163,31 @@ class Mamba2Mixer(nn.Module):
outputs = []
for t in range(T):
xt = x[:, t, :]
xz = self.in_proj(xt)
zxbcdt = self.in_proj(xt)
x_t, z_t, dt_proj = mx.split(
xz,
indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
z, xBC, dt = mx.split(
zxbcdt,
# indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
indices_or_sections=[
self.intermediate_size,
self.intermediate_size + 2 * self.state_size,
self.num_heads
],
axis=-1
)
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
x_t = conv_out.squeeze(1)
x_t = nn.silu(x_t)
y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj)
z_t = nn.silu(z_t)
# Use the new DepthWiseConv1d with caching
conv_out, cache[0] = self.conv1d(mx.expand_dims(z, 1), cache[0])
z = conv_out.squeeze(1)
z = nn.silu(z)
y_t, cache[1] = self.ssm_step(z, cache[1], dt)
xBC = nn.silu(xBC)
# Element-wise multiplication
output_t = y_t[:, :, None] * z_t[:, None, :]
output_t = y_t[:, :, None] * xBC[:, None, :]
# Sum across the second dimension to match the intermediate_size
output_t = self.norm(output_t)
output_t = output_t.sum(axis=1)
output_t = self.out_proj(output_t)
outputs.append(output_t)
@@ -207,10 +195,10 @@ class Mamba2Mixer(nn.Module):
return output
class Mamba2Block(nn.Module):
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = Mamba2Mixer(args)
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
@@ -222,24 +210,16 @@ class Mamba2(nn.Module):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [Mamba2Block(args) for idx in range(args.num_hidden_layers)]
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(
self,
inputs: mx.array,
cache=None
):
hidden_states = self.embeddings(inputs)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = Mamba2Cache(len(self.layers))
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, cache[i])
hidden_states = self.norm_f(hidden_states)
return hidden_states
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
class Model(nn.Module):
@@ -247,7 +227,10 @@ class Model(nn.Module):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
# self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
@@ -261,9 +244,6 @@ class Model(nn.Module):
else:
logits = self.lm_head(x)
print(logits)
print(logits.shape)
return logits
def sanitize(self, weights):
@@ -272,8 +252,8 @@ class Model(nn.Module):
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self, batch_size: int = 1):
return [Mamba2Cache() for _ in range(len(self.layers))]
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
def layers(self):

View File

@@ -1,246 +1,411 @@
"""
mamba2-minimal
==============
A minimal, single-file implementation of the Mamba-2 model in PyTorch.
import math
> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality**
> Authors: Tri Dao, Albert Gu
> Paper: https://arxiv.org/abs/2405.21060
"""
import json
from dataclasses import dataclass
from typing import Union
from typing import Iterable, NamedTuple, TypeAlias, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import LongTensor, Tensor, nn
Device: TypeAlias = str | torch.device | None
@dataclass
class Mamba2Config:
d_model: int # D
n_layers: int
d_head: int  # todo: shouldn't this rather be n_heads?
d_state: int = 64 # N in paper/comments
expand_factor: int = 2 # E in paper/comments
d_conv: int = 4
n_groups: int = 1  # todo: ??
A_init_range: tuple = (1, 16)
dt_min: float = 0.001
dt_max: float = 0.1
dt_init_floor: float = 1e-4
dt_limit: tuple = (0.0, float("inf"))
conv_init = None
learnable_init_states: bool = False
activation: str = "swish" # "swish" or "silu"
rms_norm_eps: float = 1e-5
base_std: float = 0.02
bias: bool = False
conv_bias: bool = True
mup: bool = False
mup_base_width: float = 128 # width=d_model
chunk_size: int = 256
use_mem_eff_path: bool = True
dtype=None
device=None
d_model: int # model dimension (D)
n_layer: int = 24 # number of Mamba-2 layers in the language model
d_state: int = 128 # state dimension (N)
d_conv: int = 4 # convolution kernel size
expand: int = 2 # expansion factor (E)
headdim: int = 64 # head dimension (P)
chunk_size: int = 64 # matrix partition size (Q)
vocab_size: int = 50277
pad_vocab_size_multiple: int = 16
def __post_init__(self):
self.d_inner = self.expand_factor * self.d_model # E*D = ED in comments
self.n_heads = self.d_inner // self.d_head
assert self.d_inner % self.d_head == 0
self.d_inner = self.expand * self.d_model
assert self.d_inner % self.headdim == 0
self.nheads = self.d_inner // self.headdim
if self.vocab_size % self.pad_vocab_size_multiple != 0:
self.vocab_size += (
self.pad_vocab_size_multiple
- self.vocab_size % self.pad_vocab_size_multiple
)
assert (self.d_inner / self.d_head) % 8 == 0, "requirement of causal_conv1d"
# muP
if self.mup:
self.mup_width_mult = self.d_model / self.mup_base_width
class InferenceCache(NamedTuple):
conv_state: Tensor # (batch, d_inner + 2 * d_state, d_conv)
ssm_state: Tensor # (batch, nheads, headdim, d_state)
@staticmethod
def alloc(batch_size: int, args: Mamba2Config, device: Device = None):
return InferenceCache(
torch.zeros(
batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device
),
torch.zeros(
batch_size, args.nheads, args.headdim, args.d_state, device=device
),
)
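# Illustrative sanity check of the derived dimensions and cache shapes, assuming the
# added (right-hand) Mamba2Config / InferenceCache definitions above; the 130m-scale
# numbers below are an assumption, not something stated in this commit.
cfg = Mamba2Config(d_model=768, n_layer=24, d_state=128, headdim=64, expand=2)
assert cfg.d_inner == 1536                    # expand * d_model
assert cfg.nheads == 24                       # d_inner // headdim
assert cfg.vocab_size == 50288                # 50277 padded up to a multiple of 16
h0 = InferenceCache.alloc(batch_size=1, args=cfg)
assert h0.conv_state.shape == (1, cfg.d_inner + 2 * cfg.d_state, cfg.d_conv)  # (1, 1792, 4)
assert h0.ssm_state.shape == (1, cfg.nheads, cfg.headdim, cfg.d_state)        # (1, 24, 64, 128)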
class Mamba2LMHeadModel(nn.Module):
def __init__(self, args: Mamba2Config, device: Device = None):
super().__init__()
self.args = args
self.device = device
self.backbone = nn.ModuleDict(
dict(
embedding=nn.Embedding(args.vocab_size, args.d_model, device=device),
layers=nn.ModuleList(
[
nn.ModuleDict(
dict(
mixer=Mamba2(args, device=device),
norm=RMSNorm(args.d_model, device=device),
)
)
for _ in range(args.n_layer)
]
),
norm_f=RMSNorm(args.d_model, device=device),
)
)
self.lm_head = nn.Linear(
args.d_model, args.vocab_size, bias=False, device=device
)
self.lm_head.weight = self.backbone.embedding.weight
def forward(
self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None
) -> tuple[LongTensor, list[InferenceCache]]:
"""
Arguments
input_ids: (batch, seqlen) tokens from `EleutherAI/gpt-neox-20b` tokenizer
h: hidden states for the inference step. If present, the constant-time
(w.r.t. sequence length) inference path is taken, and input_ids
should have shape (batch, 1) containing the next prompt
token.
Return (logits, h)
logits: (batch, seqlen, vocab_size)
h: updated inference cache after processing `input_ids`
"""
seqlen = input_ids.shape[1]
if h is None:
h = [None for _ in range(self.args.n_layer)]
x = self.backbone.embedding(input_ids)
for i, layer in enumerate(self.backbone.layers):
y, h[i] = layer.mixer(layer.norm(x), h[i])
x = y + x
x = self.backbone.norm_f(x)
logits = self.lm_head(x)
return logits[:, :seqlen], cast(list[InferenceCache], h)
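# Shape smoke test for the prompt-processing path above (illustrative only: a made-up
# tiny config, random weights, and the parameters that __init__ leaves as torch.empty
# filled in by hand so the forward pass stays numerically sane).
tiny = Mamba2Config(d_model=64, n_layer=2, d_state=16, headdim=16, chunk_size=8, vocab_size=96)
lm = Mamba2LMHeadModel(tiny)
with torch.no_grad():
    for block in lm.backbone.layers:
        block.mixer.dt_bias.zero_()
        block.mixer.A_log.zero_()          # A = -exp(0) = -1
        block.mixer.D.fill_(1.0)
    tokens = torch.randint(0, tiny.vocab_size, (1, tiny.chunk_size))  # seqlen must be a multiple of chunk_size
    logits, h = lm(tokens)
assert logits.shape == (1, tiny.chunk_size, tiny.vocab_size)
assert h[0].ssm_state.shape == (1, tiny.nheads, tiny.headdim, tiny.d_state)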
def generate(
self,
input_ids: LongTensor,
max_new_length: int = 20,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 1.0,
eos_token_id: int = 0,
) -> Iterable[tuple[int, list[InferenceCache]]]:
prefix, tokens = input_ids[:-1], input_ids[-1:].unsqueeze(0)
# Process prompt
# The input sequence to forward (non-inference path) must have a length that is a multiple of chunk_size.
# We split off the excess tokens so that the first n_chunked tokens can be processed in a single forward
# call and the rest in individual inference steps.
n_chunked = (prefix.shape[0] // self.args.chunk_size) * self.args.chunk_size
if n_chunked > 0:
_, h = self(prefix[:n_chunked].unsqueeze(0), None)
else:
h = [
InferenceCache.alloc(1, self.args, device=self.device)
for _ in range(self.args.n_layer)
]
for i in range(n_chunked, prefix.shape[0]):
_, h = self(prefix[i : i + 1].unsqueeze(0), h)
# Generate
for _ in range(max_new_length):
with torch.no_grad():
out, h = self(tokens, h)
logits = out[0, -1]
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, k=top_k)[0][-1]
logits[indices_to_remove] = -torch.inf
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
sorted_indices_to_remove[0] = False
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = -torch.inf
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
if next_token.item() == eos_token_id:
return
tokens = next_token.unsqueeze(0)
yield cast(int, next_token.item()), h
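# Driving generate() (illustrative; reuses `tiny` and `lm` from the sketch after
# forward above, so the sampled ids are meaningless). A real prompt would instead be
# encoded with the EleutherAI/gpt-neox-20b tokenizer mentioned in the forward docstring.
prompt = torch.randint(0, tiny.vocab_size, (12,))    # 1-D tensor of prompt token ids
with torch.no_grad():
    for token_id, _h in lm.generate(prompt, max_new_length=5, top_k=10, eos_token_id=-1):
        print(token_id)                               # one sampled id per decode step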
class Mamba2(nn.Module):
def __init__(self, config: Mamba2Config):
def __init__(self, args: Mamba2Config, device: Device = None):
super().__init__()
self.args = args
self.device = device
self.config = config
# Order: (z, x, B, C, dt)
d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads
self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device)
self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])
def forward(self, x, caches=None):
if caches is None:
caches = [None] * self.config.n_layers
for i, layer in enumerate(self.layers):
x, caches[i] = layer(x, caches[i])
if caches[0] == None:
return x
else:
return x, caches
class ResidualBlock(nn.Module):
def __init__(self, config: Mamba2Config):
super().__init__()
self.config = config
self.mixer = Mamba2Block(self.config)
self.norm = RMSNorm(self.config.d_model, self.config.rms_norm_eps, self.config.mup)
def forward(self, x, cache=None):
output, cache = self.mixer(self.norm(x), cache)
output = output + x
return output, cache
class Mamba2Block(nn.Module):
def __init__(self, config: Mamba2Config):
super().__init__()
factory_kwargs = {"device": config.device, "dtype": config.dtype}
self.config = config
# [z, x, B, C, dt]
d_in_proj = 2 * self.config.d_inner + 2 * self.config.n_groups * self.config.d_state + self.config.n_heads
self.in_proj = nn.Linear(self.config.d_model, d_in_proj, bias=self.config.bias)
conv_dim = self.config.d_inner + 2 * self.config.n_groups * self.config.d_state
conv_dim = args.d_inner + 2 * args.d_state
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
bias=self.config.conv_bias,
kernel_size=self.config.d_conv,
kernel_size=args.d_conv,
groups=conv_dim,
padding=self.config.d_conv - 1,
**factory_kwargs,
padding=args.d_conv - 1,
device=device,
)
self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
self.D = nn.Parameter(torch.empty(args.nheads, device=device))
self.norm = RMSNorm(args.d_inner, device=device)
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
# Initialize log dt bias
dt = torch.exp(
torch.rand(self.config.n_heads) * (math.log(self.config.dt_max) - math.log(self.config.dt_min))
+ math.log(self.config.dt_min)
)
dt = torch.clamp(dt, min=self.config.dt_init_floor)
inv_dt = dt + torch.log(-torch.expm1(-dt))
self.dt_bias = nn.Parameter(inv_dt)
assert self.config.A_init_range[0] > 0 and self.config.A_init_range[1] >= self.config.A_init_range[0]
A = torch.empty(self.config.n_heads, dtype=torch.float32).uniform_(*self.config.A_init_range)
self.A_log = torch.log(A).to(dtype=self.config.dtype)
self.D = nn.Parameter(torch.ones(self.config.n_heads, device=self.config.device))
self.norm = RMSNormGated(self.config.d_inner, eps=1e-5, norm_before_gate=False)
self.out_proj = nn.Linear(self.config.d_inner, self.config.d_model, bias=self.config.bias)
def forward(self, u, cache=None, seq_idx=None):
def forward(self, u: Tensor, h: InferenceCache | None = None):
"""
u: (B, L, D)
Returns: out : same shape as u
Arguments
u: (batch, seqlen, d_model) input. seqlen should be a multiple of chunk_size.
h: hidden states for inference step. Initialized to 0s if not present.
Return (y, h)
y: (batch, seqlen, d_model) output
h: updated inference cache after processing `u`
"""
if h:
return self.step(u, h)
batch, length, _ = u.shape
return_cache = False
if cache is not None and length > 1:
cache = None
return_cache = True
if cache is not None:
out, cache = self.step(u, cache)
return out, cache
zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
initial_states=repeat(self.init_states, "... -> b ...", b=batch) if self.config.learnable_init_states else None
dt_limit_kwargs = {} if self.config.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.config.dt_limit)
A = -torch.exp(self.A_log) # (nheads,)
zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
[self.config.d_inner, self.config.d_inner + 2 * self.config.n_groups * self.config.d_state, self.config.n_heads],
dim=-1
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads)
# 1D Convolution
xBC = self.act(self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)) # (B, L, self.d_inner + 2 * n_groups * d_state)
x, B, C = torch.split(xBC, [self.config.d_inner, self.config.n_groups * self.config.d_state, self.config.n_groups * self.config.d_state], dim=-1)
y = mamba_chunk_scan_combined(
rearrange(x, "b l (h p) -> b l h p", p=self.config.d_head),
dt,
A,
rearrange(B, "b l (g n) -> b l g n", g=self.config.n_groups),
rearrange(C, "b l (g n) -> b l g n", g=self.config.n_groups),
chunk_size=self.config.chunk_size,
D=self.D,
z=None,
seq_idx=seq_idx,
initial_states=initial_states,
**dt_limit_kwargs,
# Pad or truncate xBC seqlen to d_conv
conv_state = F.pad(
rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0)
)
xBC = silu(
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :]
) # (batch, seqlen, d_inner + 2 * d_state)
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)
y, ssm_state = ssd(
x * dt.unsqueeze(-1),
A * dt,
rearrange(B, "b l n -> b l 1 n"),
rearrange(C, "b l n -> b l 1 n"),
self.args.chunk_size,
device=self.device,
)
y = y + x * self.D.unsqueeze(-1)
y = rearrange(y, "b l h p -> b l (h p)")
# Multiply "gate" branch and apply extra normalization layer
y = self.norm(y, z)
out = self.out_proj(y)
return out, cache
y = self.out_proj(y)
def step(self, u, cache):
"""
u: (B, 1, D)
cache: (h_cache, conv_cache)
h = InferenceCache(conv_state, ssm_state)
return y, h
def step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]:
"""Take a single inference step for the current input and hidden state
Unlike attention-based models, RNN-based models (eg Mamba) does not need
to look back at all the past tokens to generate a new token. Instead a
hidden state (initialized to 0s initially) is updated for each input and
passed to the next inference step. This means that the total inference
time is linear with respect to the sequence length instead of quadratic
in attention's case.
Arguments
u: (batch, 1, d_model)
h: initial/running hidden state
Return (y, h)
y: (batch, 1, d_model)
h: updated hidden state
"""
assert u.shape[1] == 1, "Only one token can be decoded per inference step"
h_cache, conv_cache = cache
zxbcdt = self.in_proj(u.squeeze(1)) # (batch, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
zxbcdt = self.in_proj(u.squeeze(1)) # (B, 2D)
d_mlp = (zxbcdt.shape[-1] - 2 * self.config.d_inner - 2 * self.config.n_groups * self.config.d_state - self.config.n_heads) // 2
z0, x0, z, xBC, dt = torch.split(zxbcdt, [d_mlp, d_mlp, self.config.d_inner, self.config.d_inner + 2 * self.config.n_groups * self.config.d_state, self.config.n_heads], dim=-1)
# Advance convolution input
h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1))
h.conv_state[:, :, -1] = xBC
# Convolution step
xBC = torch.sum(
h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
)
xBC += self.conv1d.bias
xBC = silu(xBC)
# conv step
conv_cache.copy_(torch.roll(conv_cache, shifts=-1, dims=-1)) # update state (B, D, W)
conv_cache[:, :, -1] = xBC
xBC = torch.sum(conv_cache * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B, D)
if self.conv1d.bias is not None:
xBC = xBC + self.conv1d.bias
xBC = self.act(xBC).to(dtype=x.dtype)
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
A = -torch.exp(self.A_log) # (nheads,)
x, B, C = torch.split(xBC, [self.config.d_inner, self.config.n_groups * self.config.d_state, self.config.n_groups * self.config.d_state], dim=-1)
A = -torch.exp(self.A_log.float()) # (n_heads)
A = repeat(A, "h -> h p n", p=self.config.d_head, n=self.config.d_state).to(dtype=torch.float32)
dt = repeat(dt, "b h -> b h p", p=self.config.d_head)
dt_bias = repeat(self.dt_bias, "h -> h p", p=self.config.d_head)
D = repeat(self.D, "h -> h p", p=self.config.d_head)
B = rearrange(B, "b (g n) -> b g n", g=self.config.n_groups)
C = rearrange(C, "b (g n) -> b g n", g=self.config.n_groups)
x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.config.d_head)
y = selective_state_update(h_cache, x_reshaped, dt, A, B, C, D, z=None, dt_bias=dt_bias, dt_softplus=True)
# SSM step
dt = F.softplus(dt + self.dt_bias) # (batch, nheads)
dA = torch.exp(dt * A) # (batch, nheads)
x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)
dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)
y = y + rearrange(self.D, "h -> h 1") * x
y = rearrange(y, "b h p -> b (h p)")
#if self.rmsnorm:
y = self.norm(y, z)
if d_mlp > 0:
y = torch.cat([F.silu(z0) * x0, y], dim=-1)
out = self.out_proj(y)
return out.unsqueeze(1), (h_cache, conv_cache)
y = self.out_proj(y)
return y.unsqueeze(1), h
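# Single-block sketch (illustrative): run the chunked SSD forward over a short
# sequence, then decode one extra token with step() using the returned InferenceCache.
# Assumes the added (right-hand) Mamba2 mixer class; the uninitialized parameters are
# filled in by hand.
blk = Mamba2(Mamba2Config(d_model=64, d_state=16, headdim=16, chunk_size=8))
with torch.no_grad():
    blk.dt_bias.zero_(); blk.A_log.zero_(); blk.D.fill_(1.0)
    u = torch.randn(1, 8, 64)                        # seqlen == chunk_size
    y, h = blk(u)                                    # parallel (chunked) path
    y1, h = blk.step(torch.randn(1, 1, 64), h)       # constant-time decode path
assert y.shape == (1, 8, 64) and y1.shape == (1, 1, 64)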
def segsum(x: Tensor, device: Device = None) -> Tensor:
"""Stable segment sum calculation.
`exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.
Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32
"""
T = x.size(-1)
x = repeat(x, "... d -> ... d e", e=T)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
x = x.masked_fill(~mask, 0)
x_segsum = torch.cumsum(x, dim=-2)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
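# Tiny numerical check of segsum (illustrative): for a length-3 input, exp(segsum(a))
# is lower triangular, with entry (i, j) equal to the product of decays exp(a[j+1..i]).
a = torch.tensor([0.1, 0.2, 0.3])
L = torch.exp(segsum(a))
assert L.shape == (3, 3)
assert torch.allclose(L[2, 0], torch.exp(a[1] + a[2]))
assert torch.all(torch.triu(L, diagonal=1) == 0)     # exp(-inf) above the diagonal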
def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
"""Structed State Space Duality (SSD) - the core of Mamba-2
This is almost the exact same minimal SSD code from the blog post.
Arguments
x: (batch, seqlen, n_heads, d_head)
A: (batch, seqlen, n_heads)
B: (batch, seqlen, n_heads, d_state)
C: (batch, seqlen, n_heads, d_state)
Return
y: (batch, seqlen, n_heads, d_head)
Source
1. https://tridao.me/blog/2024/mamba2-part3-algorithm/
2. https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78
"""
assert x.shape[1] % chunk_size == 0
# Rearrange into chunks
# Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel)
# This is not implemented and left as an exercise for the reader 😜
x, A, B, C = [
rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)
]
A = rearrange(A, "b c l h -> b h c l")
A_cumsum = torch.cumsum(A, dim=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
L = torch.exp(segsum(A, device=device))
Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))
new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
return Y, final_state
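# Shape sketch for ssd (illustrative): one chunk of 8 steps, 2 heads of dim 4,
# a single B/C group of state size 16 (broadcast over heads, as in forward above).
x = torch.randn(1, 8, 2, 4)
A = -torch.rand(1, 8, 2)                   # per-step, per-head log-decay (negative)
B = torch.randn(1, 8, 1, 16)
C = torch.randn(1, 8, 1, 16)
y, final_state = ssd(x, A, B, C, chunk_size=8)
assert y.shape == (1, 8, 2, 4)                       # (batch, seqlen, n_heads, d_head)
assert final_state.shape == (1, 2, 4, 16)            # (batch, n_heads, d_head, d_state)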
# taken straight from https://github.com/johnma2006/mamba-minimal/blob/master/model.py
class RMSNorm(nn.Module):
def __init__(self, d_model: int, eps: float = 1e-5, use_mup: bool = False):
def __init__(self, d: int, eps: float = 1e-5, device: Device = None):
"""Gated Root Mean Square Layer Normalization
Paper: https://arxiv.org/abs/1910.07467
"""
super().__init__()
self.use_mup = use_mup
self.eps = eps
self.weight = nn.Parameter(torch.ones(d, device=device))
# https://arxiv.org/abs/2404.05728, RMSNorm gains prevent muTransfer (section 4.2.3)
if not use_mup:
self.weight = nn.Parameter(torch.ones(d_model))
def forward(self, x, z=None):
if z is not None:
x = x * silu(z)
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
def forward(self, x):
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
if not self.use_mup:
return output * self.weight
else:
return output
def silu(x):
"""Applies the Sigmoid Linear Unit (SiLU), element-wise.
Define this manually since torch's version doesn't seem to work on MPS.
"""
return x * F.sigmoid(x)

View File

@@ -106,14 +106,16 @@ class Mamba2Block(nn.Module):
self.head_dim = args.hidden_size // args.num_heads
self.n_groups = args.n_groups
projection_size = 2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads
# projection_size = 2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads
projection_size = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(
args.hidden_size,
projection_size,
bias=args.use_bias
)
self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
# self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
self.conv_dim = args.intermediate_size + 2 * args.state_size
self.conv1d = DepthWiseConv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
@@ -130,62 +132,125 @@ class Mamba2Block(nn.Module):
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
def ssm_step(self, x, state, dt):
def _ssd(self, x, A, B, C, chunk_size):
batch, seq_len, nheads, head_dim = x.shape
n_state = B.shape[-1]
h = mx.zeros((batch, nheads, head_dim, n_state))
ys = []
for i in range(0, seq_len, chunk_size):
chunk_size_i = min(chunk_size, seq_len - i)
xi = x[:, i:i + chunk_size_i]
Bi = B[:, i:i + chunk_size_i]
Ci = C[:, i:i + chunk_size_i]
for t in range(chunk_size_i):
h = h * mx.exp(A)[:, None, None]
h = h + mx.expand_dims(Bi[:, t], -2) * mx.expand_dims(xi[:, t], -1)
y = mx.sum(h * mx.expand_dims(Ci[:, t], -2), axis=-1)
ys.append(y)
y = mx.stack(ys, axis=1)
return y, h
def __call__(self, x: mx.array, cache) -> mx.array:
if cache is not None:
return self.step(x, cache)
A = -mx.exp(self.A_log)
D = self.D
dt = nn.softplus(dt + self.dt_bias)
zxbcdt = self.in_proj(x)
B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
z, xBC, dt = mx.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
axis=-1,
)
batch_size = B.shape[0]
B = B.reshape(batch_size, self.n_groups, self.state_size)
C = C.reshape(batch_size, -1, self.state_size)
dt = mx.softplus(dt + self.dt_bias)
dt = dt.reshape(batch_size, self.num_heads, 1)
A = A.reshape(1, self.num_heads, 1)
# Use the custom DepthWiseConv1d with cache
xBC = self.conv1d(xBC, cache, cache_idx=0)
xBC = mx.sigmoid(xBC) * xBC # SiLU activation
if state is None:
new_state = dt * B
else:
new_state = dt * (B + state * mx.exp(dt * A))
x, B, C = mx.split(
xBC,
[self.args.d_inner, self.args.d_state, self.args.d_state],
axis=-1
)
y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
y = y + D * x[:, :self.num_heads]
return y, new_state
x = self._reshape_heads(x, True)
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)
def __call__(self, x, cache):
B, T, D = x.shape
if cache is None:
cache = [None, None]
y, ssm_state = self._ssd(
x * mx.expand_dims(dt, -1),
A * dt,
B,
C,
self.args.chunk_size
)
outputs = []
for t in range(T):
xt = x[:, t, :]
zxbcdt = self.in_proj(xt)
y = y + x * mx.expand_dims(self.D, -1)
y = self._reshape_heads(y, False)
y = self.norm(y, z)
y = self.out_proj(y)
z, xBC, dt = mx.split(
zxbcdt,
indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
axis=-1
)
if cache is not None:
cache[1] = ssm_state
# Use the new DepthWiseConv1d with caching
conv_out, cache[0] = self.conv1d(mx.expand_dims(z, 1), cache[0])
z = conv_out.squeeze(1)
z = nn.silu(z)
y_t, cache[1] = self.ssm_step(z, cache[1], dt)
xBC = nn.silu(xBC)
return y
# Element-wise multiplication
output_t = y_t[:, :, None] * xBC[:, None, :]
def step(self, x: mx.array, cache) -> mx.array:
"""Single inference step"""
assert x.shape[1] == 1, "Only one token can be decoded per inference step"
output_t = self.norm(output_t)
output_t = output_t.sum(axis=1)
output_t = self.out_proj(output_t)
outputs.append(output_t)
zxbcdt = self.in_proj(mx.squeeze(x, 1))
z, xBC, dt = mx.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
axis=-1,
)
output = mx.stack(outputs, axis=1)
return output
# Use the custom DepthWiseConv1d with cache
xBC = self.conv1d(xBC, cache, cache_idx=0)
xBC = mx.sigmoid(xBC) * xBC # SiLU activation
x, B, C = mx.split(
xBC,
[self.args.d_inner, self.args.d_state, self.args.d_state],
axis=-1
)
A = -mx.exp(self.A_log)
dt = mx.softplus(dt + self.dt_bias)
dA = mx.exp(dt * A)
x = mx.reshape(x, (-1, self.args.nheads, self.args.headdim))
ssm_state = cache[1]
dBx = mx.expand_dims(dt, -1) * mx.expand_dims(B, 1) * mx.expand_dims(x, -1)
ssm_state = ssm_state * mx.expand_dims(mx.expand_dims(dA, -1), -1) + dBx
y = mx.sum(ssm_state * mx.expand_dims(mx.expand_dims(C, 1), 1), axis=-1)
y = y + mx.expand_dims(self.D, -1) * x
y = mx.reshape(y, (-1, self.args.nheads * self.args.headdim))
y = self.norm(y, z)
y = self.out_proj(y)
# Update SSM state in cache
cache[1] = ssm_state
return mx.expand_dims(y, 1)
class ResidualBlock(nn.Module):