generation works! trying training now

Goekdeniz-Guelmez 2024-10-22 18:25:59 +02:00
parent c1634ce81b
commit b9c57cd429
3 changed files with 537 additions and 327 deletions

View File

@@ -1,14 +1,11 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
-# python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello how are you."
from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
@@ -47,21 +44,6 @@ class ModelArgs(BaseModelArgs):
self.time_step_rank = math.ceil(self.hidden_size / 16)
-class Mamba2Cache:
-def __init__(self):
-self.cache = [None, None]
-def __setitem__(self, idx, value):
-self.cache[idx] = value
-def __getitem__(self, idx):
-return self.cache[idx]
-@property
-def state(self):
-return self.cache
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
@@ -111,7 +93,7 @@ class DepthWiseConv1d(nn.Module):
return y, x[:, -K + 1 :, :]
-class Mamba2Mixer(nn.Module):
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
@@ -124,35 +106,36 @@ class Mamba2Mixer(nn.Module):
self.head_dim = args.hidden_size // args.num_heads
self.n_groups = args.n_groups
-self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
-self.conv1d = DepthWiseConv1d(
-in_channels=self.conv_dim,
-out_channels=self.conv_dim,
-bias=args.use_conv_bias,
-kernel_size=args.conv_kernel,
-groups=self.conv_dim,
-padding=args.conv_kernel - 1
-)
-projection_size = self.intermediate_size + self.conv_dim + self.num_heads
# projection_size = 2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads
projection_size = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(
-self.hidden_size,
args.hidden_size,
projection_size,
bias=args.use_bias
)
-self.A_log = mx.zeros(self.num_heads)
-self.D = mx.ones(self.num_heads)
-self.dt_bias = mx.zeros(self.num_heads)
# self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
self.conv_dim = args.intermediate_size + 2 * args.state_size
self.conv1d = DepthWiseConv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
kernel_size=args.conv_kernel,
bias=args.use_conv_bias,
groups=self.conv_dim,
padding=args.conv_kernel - 1
)
-self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
self.A_log = mx.zeros(args.num_heads)
self.D = mx.ones((args.num_heads,))
self.dt_bias = mx.zeros(args.num_heads)
-self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
-def ssm_step(self, x, state, dt_proj):
def ssm_step(self, x, state, dt):
A = -mx.exp(self.A_log)
D = self.D
-delta = nn.softplus(dt_proj + self.dt_bias)
dt = nn.softplus(dt + self.dt_bias)
B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
@@ -160,13 +143,13 @@ class Mamba2Mixer(nn.Module):
B = B.reshape(batch_size, self.n_groups, self.state_size)
C = C.reshape(batch_size, -1, self.state_size)
-delta = delta.reshape(batch_size, self.num_heads, 1)
dt = dt.reshape(batch_size, self.num_heads, 1)
A = A.reshape(1, self.num_heads, 1)
if state is None:
-new_state = delta * B
new_state = dt * B
else:
-new_state = delta * (B + state * mx.exp(delta * A))
new_state = dt * (B + state * mx.exp(dt * A))
y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
y = y + D * x[:, :self.num_heads]
@@ -180,26 +163,31 @@ class Mamba2Mixer(nn.Module):
outputs = []
for t in range(T):
xt = x[:, t, :]
-xz = self.in_proj(xt)
zxbcdt = self.in_proj(xt)
-x_t, z_t, dt_proj = mx.split(
z, xBC, dt = mx.split(
-xz,
zxbcdt,
-indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
# indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
indices_or_sections=[
self.intermediate_size,
self.intermediate_size + 2 * self.state_size,
self.num_heads
],
axis=-1
)
-conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
-x_t = conv_out.squeeze(1)
-x_t = nn.silu(x_t)
-y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj)
-z_t = nn.silu(z_t)
# Use the new DepthWiseConv1d with caching
conv_out, cache[0] = self.conv1d(mx.expand_dims(z, 1), cache[0])
z = conv_out.squeeze(1)
z = nn.silu(z)
y_t, cache[1] = self.ssm_step(z, cache[1], dt)
xBC = nn.silu(xBC)
# Element-wise multiplication
-output_t = y_t[:, :, None] * z_t[:, None, :]
output_t = y_t[:, :, None] * xBC[:, None, :]
-# Sum across the second dimension to match the intermediate_size
output_t = self.norm(output_t)
output_t = output_t.sum(axis=1)
output_t = self.out_proj(output_t)
outputs.append(output_t)
@@ -207,10 +195,10 @@ class Mamba2Mixer(nn.Module):
return output
-class Mamba2Block(nn.Module):
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
-self.mixer = Mamba2Mixer(args)
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
@@ -222,24 +210,16 @@ class Mamba2(nn.Module):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
-self.layers = [Mamba2Block(args) for idx in range(args.num_hidden_layers)]
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
-def __call__(
-self,
-inputs: mx.array,
-cache=None
-):
-hidden_states = self.embeddings(inputs)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
-cache = Mamba2Cache(len(self.layers))
cache = [None] * len(self.layers)
-for i, layer in enumerate(self.layers):
-hidden_states = layer(hidden_states, cache[i])
-hidden_states = self.norm_f(hidden_states)
-return hidden_states
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
class Model(nn.Module):
@@ -247,7 +227,10 @@ class Model(nn.Module):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
# self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
@@ -261,9 +244,6 @@ class Model(nn.Module):
else:
logits = self.lm_head(x)
-print(logits)
-print(logits.shape)
return logits
def sanitize(self, weights):
@@ -272,8 +252,8 @@ class Model(nn.Module):
weights[k] = v.moveaxis(2, 1)
return weights
-def make_cache(self, batch_size: int = 1):
-return [Mamba2Cache() for _ in range(len(self.layers))]
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
def layers(self):
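
Since this file now routes decoding through per-layer MambaCache objects, a minimal greedy-decoding sketch is shown below (an illustration, not part of the commit; it assumes Model.__call__ accepts (inputs, cache=None) as in the other mlx_lm models, and that make_cache() returns one MambaCache per layer as defined above):

import mlx.core as mx

def greedy_decode(model, prompt_ids, max_new_tokens=20):
    cache = model.make_cache()                          # one MambaCache (conv state, SSM state) per layer
    tokens = mx.array([prompt_ids])                     # (1, prompt_len)
    logits = model(tokens, cache)                       # prime the per-layer states on the prompt
    generated = []
    for _ in range(max_new_tokens):
        next_id = mx.argmax(logits[:, -1, :], axis=-1)  # greedy pick, shape (1,)
        generated.append(int(next_id.item()))
        logits = model(next_id[:, None], cache)         # constant-time step that reuses the cache
    return generated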

View File

@@ -1,246 +1,411 @@
"""
mamba2-minimal
==============
A minimal, single-file implementation of the Mamba-2 model in PyTorch.
> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality**
> Authors: Tri Dao, Albert Gu
> Paper: https://arxiv.org/abs/2405.21060
"""
import json
-import math
from dataclasses import dataclass
-from typing import Union
from typing import Iterable, NamedTuple, TypeAlias, cast
import torch
-import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import LongTensor, Tensor, nn
Device: TypeAlias = str | torch.device | None
@dataclass
class Mamba2Config:
-d_model: int # D
-n_layers: int
-d_head: int # todo: shouldn't this be n_heads?
-d_state: int = 64 # N in paper/comments
-expand_factor: int = 2 # E in paper/comments
-d_conv: int = 4
-n_groups: int = 1 # todo: ??
-A_init_range: tuple = (1, 16)
-dt_min: float = 0.001
-dt_max: float = 0.1
-dt_init_floor: float = 1e-4
-dt_limit: tuple = (0.0, float("inf"))
-conv_init = None
-learnable_init_states: bool = False
-activation: str = "swish" # "swish" or "silu"
-rms_norm_eps: float = 1e-5
-base_std: float = 0.02
-bias: bool = False
-conv_bias: bool = True
-mup: bool = False
-mup_base_width: float = 128 # width=d_model
-chunk_size: int = 256
-use_mem_eff_path: bool = True
-dtype=None
-device=None
d_model: int # model dimension (D)
n_layer: int = 24 # number of Mamba-2 layers in the language model
d_state: int = 128 # state dimension (N)
d_conv: int = 4 # convolution kernel size
expand: int = 2 # expansion factor (E)
headdim: int = 64 # head dimension (P)
chunk_size: int = 64 # matrix partition size (Q)
vocab_size: int = 50277
pad_vocab_size_multiple: int = 16
def __post_init__(self):
-self.d_inner = self.expand_factor * self.d_model # E*D = ED in comments
-self.n_heads = self.d_inner // self.d_head
-assert self.d_inner % self.d_head == 0
-assert (self.d_inner / self.d_head) % 8 == 0, "requierement of causal_conv1d"
-# muP
-if self.mup:
-self.mup_width_mult = self.d_model / self.mup_base_width
self.d_inner = self.expand * self.d_model
assert self.d_inner % self.headdim == 0
self.nheads = self.d_inner // self.headdim
if self.vocab_size % self.pad_vocab_size_multiple != 0:
self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size % self.pad_vocab_size_multiple)
class InferenceCache(NamedTuple):
conv_state: Tensor # (batch, d_inner + 2 * d_state, d_conv)
ssm_state: Tensor # (batch, nheads, headdim, d_state)
@staticmethod
def alloc(batch_size: int, args: Mamba2Config, device: Device = None):
return InferenceCache(
torch.zeros(
batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device
),
torch.zeros(
batch_size, args.nheads, args.headdim, args.d_state, device=device
),
)
class Mamba2LMHeadModel(nn.Module):
def __init__(self, args: Mamba2Config, device: Device = None):
super().__init__()
self.args = args
self.device = device
self.backbone = nn.ModuleDict(
dict(
embedding=nn.Embedding(args.vocab_size, args.d_model, device=device),
layers=nn.ModuleList(
[
nn.ModuleDict(
dict(
mixer=Mamba2(args, device=device),
norm=RMSNorm(args.d_model, device=device),
)
)
for _ in range(args.n_layer)
]
),
norm_f=RMSNorm(args.d_model, device=device),
)
)
self.lm_head = nn.Linear(
args.d_model, args.vocab_size, bias=False, device=device
)
self.lm_head.weight = self.backbone.embedding.weight
def forward(
self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None
) -> tuple[LongTensor, list[InferenceCache]]:
"""
Arguments
input_ids: (batch, seqlen) tokens from `EleutherAI/gpt-neox-20b` tokenizer
h: hidden states for inference step. If present the constant-time
(wrt sequence length) inference path will be taken, input_ids
should have shape (batch, 1) containing the next batch of prompt
token.
Return (logits, h)
logits: (batch, seqlen, vocab_size)
h: updated inference cache after processing `input_ids`
"""
seqlen = input_ids.shape[1]
if h is None:
h = [None for _ in range(self.args.n_layer)]
x = self.backbone.embedding(input_ids)
for i, layer in enumerate(self.backbone.layers):
y, h[i] = layer.mixer(layer.norm(x), h[i])
x = y + x
x = self.backbone.norm_f(x)
logits = self.lm_head(x)
return logits[:, :seqlen], cast(list[InferenceCache], h)
def generate(
self,
input_ids: LongTensor,
max_new_length: int = 20,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 1.0,
eos_token_id: int = 0,
) -> Iterable[tuple[int, list[InferenceCache]]]:
prefix, tokens = input_ids[:-1], input_ids[-1:].unsqueeze(0)
# Process prompt
# The input sequence to forward (non-inference path) must have length multiple that of chunk_size.
# We split out excess tokens so that n_chunked tokens can be processed by one forward call and
# process the rest in multiple inference steps.
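# For example (illustrative numbers, not from the source): with chunk_size=64 and a 150-token
# prefix, n_chunked = (150 // 64) * 64 = 128, so the first 128 prompt tokens go through a single
# forward() call and the remaining 22 are fed one at a time through the step path below.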
n_chunked = (prefix.shape[0] // self.args.chunk_size) * self.args.chunk_size
if n_chunked > 0:
_, h = self(prefix[:n_chunked].unsqueeze(0), None)
else:
h = [
InferenceCache.alloc(1, self.args, device=self.device)
for _ in range(self.args.n_layer)
]
for i in range(n_chunked, prefix.shape[0]):
_, h = self(prefix[i : i + 1].unsqueeze(0), h)
# Generate
for _ in range(max_new_length):
with torch.no_grad():
out, h = self(tokens, h)
logits = out[0, -1]
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, k=top_k)[0][-1]
logits[indices_to_remove] = -torch.inf
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > 0.5
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
sorted_indices_to_remove[0] = False
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = -torch.inf
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
if next_token.item() == eos_token_id:
return
tokens = next_token.unsqueeze(0)
yield cast(int, next_token.item()), h
class Mamba2(nn.Module):
-def __init__(self, config: Mamba2Config):
def __init__(self, args: Mamba2Config, device: Device = None):
super().__init__()
self.args = args
self.device = device
-self.config = config
-self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])
# Order: (z, x, B, C, dt)
d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads
self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device)
conv_dim = args.d_inner + 2 * args.d_state
-def forward(self, x, caches=None):
-if caches is None:
-caches = [None] * self.config.n_layers
-for i, layer in enumerate(self.layers):
-x, caches[i] = layer(x, caches[i])
-if caches[0] == None:
-return x
-else:
-return x, caches
-class ResidualBlock(nn.Module):
-def __init__(self, config: Mamba2Config):
-super().__init__()
-self.config = config
-self.mixer = Mamba2Block(self.config)
-self.norm = RMSNorm(self.config.d_model, self.config.rms_norm_eps, self.config.mup)
-def forward(self, x, cache=None):
-output, cache = self.mixer(self.norm(x), cache)
-output = output + x
-return output, cache
-class Mamba2Block(nn.Module):
-def __init__(self, config: Mamba2Config):
-super().__init__()
-factory_kwargs = {"device": config.device, "dtype": config.dtype}
-self.config = config
-# [z, x, B, C, dt]
-d_in_proj = 2 * self.config.d_inner + 2 * self.config.n_groups * self.config.d_state + self.config.n_heads
-self.in_proj = nn.Linear(self.config.d_model, d_in_proj, bias=self.config.bias)
-conv_dim = self.config.d_inner + 2 * self.config.n_groups * self.config.d_state
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
-bias=self.config.conv_bias,
-kernel_size=self.config.d_conv,
kernel_size=args.d_conv,
groups=conv_dim,
-padding=self.config.d_conv - 1,
padding=args.d_conv - 1,
-**factory_kwargs,
device=device,
)
self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
self.D = nn.Parameter(torch.empty(args.nheads, device=device))
self.norm = RMSNorm(args.d_inner, device=device)
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
-# Initialize log dt bias
-dt = torch.exp(torch.rand(self.config.n_heads) * (math.log(self.config.dt_max) - math.log(self.config.dt_min)) + math.log(self.config.dt_min))
-dt = torch.clamp(dt, min=self.config.dt_init_floor)
-inv_dt = dt + torch.log(-torch.expm1(-dt))
-self.dt_bias = nn.Parameter(inv_dt)
-assert self.config.A_init_range[0] > 0 and self.config.A_init_range[1] >= self.config.A_init_range[0]
-A = torch.empty(self.config.n_heads, dtype=torch.float32).uniform_(*self.config.A_init_range)
-self.A_log = torch.log(A).to(dtype=self.config.dtype)
-self.D = nn.Parameter(torch.ones(self.config.n_heads, device=self.config.device))
-self.norm = RMSNormGated(self.config.d_inner, eps=1e-5, norm_before_gate=False)
-self.out_proj = nn.Linear(self.config.d_inner, self.config.d_model, bias=self.config.bias)
-def forward(self, u, cache=None, seq_idx=None):
-"""
-u: (B, L, D)
-Returns: out : same shape as u
-"""
-batch, length, _ = u.shape
-return_cache = False
-if cache is not None and length > 1:
-cache = None
-return_cache = True
-if cache is not None:
-out, cache = self.step(u, cache)
-return out, cache
-zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
-A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
-initial_states=repeat(self.init_states, "... -> b ...", b=batch) if self.config.learnable_init_states else None
-dt_limit_kwargs = {} if self.config.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.config.dt_limit)
def forward(self, u: Tensor, h: InferenceCache | None = None):
"""
Arguments
u: (batch, seqlen, d_model) input. seqlen should be a multiple of chunk_size.
h: hidden states for inference step. Initialized to 0s if not present.
Return (y, h)
y: (batch, seqlen, d_model) output
h: updated inference cache after processing `u`
"""
if h:
return self.step(u, h)
A = -torch.exp(self.A_log) # (nheads,)
zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
-[self.config.d_inner, self.config.d_inner + 2 * self.config.n_groups * self.config.d_state, self.config.n_heads],
-dim=-1
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
-dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads)
-# 1D Convolution
-xBC = self.act(self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)) # (B, L, self.d_inner + 2 * n_groups * d_state)
-x, B, C = torch.split(xBC, [self.config.d_inner, self.config.n_groups * self.config.d_state, self.config.n_groups * self.config.d_state], dim=-1)
-y = mamba_chunk_scan_combined(
-rearrange(x, "b l (h p) -> b l h p", p=self.config.d_head),
-dt,
-A,
-rearrange(B, "b l (g n) -> b l g n", g=self.config.n_groups),
-rearrange(C, "b l (g n) -> b l g n", g=self.config.n_groups),
-chunk_size=self.config.chunk_size,
-D=self.D,
-z=None,
-seq_idx=seq_idx,
-initial_states=initial_states,
-**dt_limit_kwargs,
-)
# Pad or truncate xBC seqlen to d_conv
conv_state = F.pad(
rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0)
)
xBC = silu(
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :]
) # (batch, seqlen, d_inner + 2 * d_state))
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)
y, ssm_state = ssd(
x * dt.unsqueeze(-1),
A * dt,
rearrange(B, "b l n -> b l 1 n"),
rearrange(C, "b l n -> b l 1 n"),
self.args.chunk_size,
device=self.device,
)
y = y + x * self.D.unsqueeze(-1)
y = rearrange(y, "b l h p -> b l (h p)")
# Multiply "gate" branch and apply extra normalization layer
y = self.norm(y, z)
-out = self.out_proj(y)
-return out, cache
y = self.out_proj(y)
h = InferenceCache(conv_state, ssm_state)
return y, h
-def step(self, u, cache):
-"""
-u: (B, 1, D)
-cache: (h_cache, conv_cache)
-"""
def step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]:
"""Take a single inference step for the current input and hidden state
Unlike attention-based models, RNN-based models (eg Mamba) does not need
to look back at all the past tokens to generate a new token. Instead a
hidden state (initialized to 0s initially) is updated for each input and
passed to the next inference step. This means that the total inference
time is linear with respect to the sequence length instead of quadratic
in attention's case.
Arguments
u: (batch, 1, d_model)
h: initial/running hidden state
Return (y, h)
y: (batch, 1, d_model)
h: updated hidden state
""" """
assert u.shape[1] == 1, "Only one token can be decoded per inference step"
-h_cache, conv_cache = cache
zxbcdt = self.in_proj(u.squeeze(1)) # (batch, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
-zxbcdt = self.in_proj(u.squeeze(1)) # (B, 2D)
-d_mlp = (zxbcdt.shape[-1] - 2 * self.config.d_inner - 2 * self.config.n_groups * self.config.d_state - self.config.n_heads) // 2
-z0, x0, z, xBC, dt = torch.split(zxbcdt, [d_mlp, d_mlp, self.config.d_inner, self.config.d_inner + 2 * self.config.n_groups * self.config.d_state, self.config.n_heads], dim=-1)
-# conv step
-conv_cache.copy_(torch.roll(conv_cache, shifts=-1, dims=-1)) # update state (B, D, W)
-conv_cache[:, :, -1] = xBC
-xBC = torch.sum(conv_cache * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B, D)
-if self.conv1d.bias is not None:
-xBC = xBC + self.conv1d.bias
-xBC = self.act(xBC).to(dtype=x.dtype)
-x, B, C = torch.split(xBC, [self.config.d_inner, self.config.n_groups * self.config.d_state, self.config.n_groups * self.config.d_state], dim=-1)
-A = -torch.exp(self.A_log.float()) # (n_heads)
-A = repeat(A, "h -> h p n", p=self.config.d_head, n=self.config.d_state).to(dtype=torch.float32)
-dt = repeat(dt, "b h -> b h p", p=self.config.d_head)
-dt_bias = repeat(self.dt_bias, "h -> h p", p=self.config.d_head)
-D = repeat(self.D, "h -> h p", p=self.config.d_head)
-B = rearrange(B, "b (g n) -> b g n", g=self.config.n_groups)
-C = rearrange(C, "b (g n) -> b g n", g=self.config.n_groups)
-x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.config.d_head)
-y = selective_state_update(h_cache, x_reshaped, dt, A, B, C, D, z=None, dt_bias=dt_bias, dt_softplus=True)
# Advance convolution input
h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1))
h.conv_state[:, :, -1] = xBC
# Convolution step
xBC = torch.sum(
h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
)
xBC += self.conv1d.bias
xBC = silu(xBC)
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
A = -torch.exp(self.A_log) # (nheads,)
# SSM step
dt = F.softplus(dt + self.dt_bias) # (batch, nheads)
dA = torch.exp(dt * A) # (batch, nheads)
x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)
dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)
y = y + rearrange(self.D, "h -> h 1") * x
y = rearrange(y, "b h p -> b (h p)")
-#if self.rmsnorm:
y = self.norm(y, z)
-if d_mlp > 0:
-y = torch.cat([F.silu(z0) * x0, y], dim=-1)
-out = self.out_proj(y)
-return out.unsqueeze(1), (h_cache, conv_cache)
y = self.out_proj(y)
return y.unsqueeze(1), h
def segsum(x: Tensor, device: Device = None) -> Tensor:
"""Stable segment sum calculation.
`exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.
Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32
"""
T = x.size(-1)
x = repeat(x, "... d -> ... d e", e=T)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
x = x.masked_fill(~mask, 0)
x_segsum = torch.cumsum(x, dim=-2)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
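
To make the segment sum concrete, here is a tiny check of the function above (illustrative only, not part of the commit):

import torch

a = torch.tensor([[0.1, 0.2, 0.3]])   # (batch=1, T=3)
L = torch.exp(segsum(a))
# L[0] is lower-triangular with ones on the diagonal:
# [[1.0,      0.0,      0.0],
#  [exp(0.2), 1.0,      0.0],
#  [exp(0.5), exp(0.3), 1.0]]
# i.e. exp(segsum(a))[i, j] = exp(a[j+1] + ... + a[i]) for j <= i and 0 above the diagonal,
# which is exactly the decay a scalar SSM applies between positions j and i.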
def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
"""Structed State Space Duality (SSD) - the core of Mamba-2
This is almost the exact same minimal SSD code from the blog post.
Arguments
x: (batch, seqlen, n_heads, d_head)
A: (batch, seqlen, n_heads)
B: (batch, seqlen, n_heads, d_state)
C: (batch, seqlen, n_heads, d_state)
Return
y: (batch, seqlen, n_heads, d_head)
Source
1. https://tridao.me/blog/2024/mamba2-part3-algorithm/
2. https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78
"""
assert x.shape[1] % chunk_size == 0
# Rearrange into chunks
# Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel)
# This is not implemented and left as an exercise for the reader 😜
x, A, B, C = [
rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)
]
A = rearrange(A, "b c l h -> b h c l")
A_cumsum = torch.cumsum(A, dim=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
L = torch.exp(segsum(A, device=device))
Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))
new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
return Y, final_state
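
The chunked computation above is equivalent to scanning the recurrence token by token. A naive sequential reference for cross-checking (a sketch, not part of the commit; it assumes the argument shapes ssd() receives from Mamba2.forward, where B and C carry a broadcastable singleton head dimension):

import torch

def ssd_reference(x, A, B, C):
    """Sequential scan: h_t = exp(A_t) * h_{t-1} + x_t B_t^T, then y_t = h_t C_t."""
    b, l, h, p = x.shape
    n = B.shape[-1]
    state = torch.zeros(b, h, p, n, dtype=x.dtype, device=x.device)
    ys = []
    for t in range(l):
        decay = torch.exp(A[:, t])[:, :, None, None]        # (b, h, 1, 1)
        update = x[:, t, :, :, None] * B[:, t, :, None, :]  # (b, h, p, n); B broadcasts over heads
        state = decay * state + update
        ys.append((state * C[:, t, :, None, :]).sum(-1))    # contract the state dim -> (b, h, p)
    return torch.stack(ys, dim=1), state

# ssd(x * dt.unsqueeze(-1), A * dt, B, C, chunk_size) and
# ssd_reference(x * dt.unsqueeze(-1), A * dt, B, C) should agree up to numerical tolerance.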
-# taken straight from https://github.com/johnma2006/mamba-minimal/blob/master/model.py
class RMSNorm(nn.Module):
-def __init__(self, d_model: int, eps: float = 1e-5, use_mup: bool = False):
-super().__init__()
-self.use_mup = use_mup
-self.eps = eps
-# https://arxiv.org/abs/2404.05728, RMSNorm gains prevents muTransfer (section 4.2.3)
-if not use_mup:
-self.weight = nn.Parameter(torch.ones(d_model))
-def forward(self, x):
-output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-if not self.use_mup:
-return output * self.weight
-else:
-return output
def __init__(self, d: int, eps: float = 1e-5, device: Device = None):
"""Gated Root Mean Square Layer Normalization
Paper: https://arxiv.org/abs/1910.07467
"""
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d, device=device))
def forward(self, x, z=None):
if z is not None:
x = x * silu(z)
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
def silu(x):
"""Applies the Sigmoid Linear Unit (SiLU), element-wise.
Define this manually since torch's version doesn't seem to work on MPS.
"""
return x * F.sigmoid(x)
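
End to end, the reference model above can be driven like this (a hypothetical usage sketch, not part of the commit: the tokenizer follows the EleutherAI/gpt-neox-20b note in forward()'s docstring, the config values are placeholders, and real weights would be loaded from a checkpoint, which this diff does not show):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
config = Mamba2Config(d_model=768, n_layer=24, vocab_size=50277)
model = Mamba2LMHeadModel(config)

input_ids = tokenizer("hello how are you.", return_tensors="pt").input_ids[0]
with torch.no_grad():
    for token_id, _h in model.generate(input_ids, max_new_length=20):
        print(tokenizer.decode([token_id]), end="", flush=True)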

View File

@@ -106,14 +106,16 @@ class Mamba2Block(nn.Module):
self.head_dim = args.hidden_size // args.num_heads
self.n_groups = args.n_groups
-projection_size = 2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads
# projection_size = 2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads
projection_size = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(
args.hidden_size,
projection_size,
bias=args.use_bias
)
-self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
# self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
self.conv_dim = args.intermediate_size + 2 * args.state_size
self.conv1d = DepthWiseConv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
@@ -130,62 +132,125 @@ class Mamba2Block(nn.Module):
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
def _ssd(self, x, A, B, C, chunk_size):
batch, seq_len, nheads, head_dim = x.shape
n_state = B.shape[-1]
h = mx.zeros((batch, nheads, head_dim, n_state))
ys = []
for i in range(0, seq_len, chunk_size):
chunk_size_i = min(chunk_size, seq_len - i)
xi = x[:, i:i + chunk_size_i]
Bi = B[:, i:i + chunk_size_i]
Ci = C[:, i:i + chunk_size_i]
for t in range(chunk_size_i):
h = h * mx.exp(A)[:, None, None]
h = h + mx.expand_dims(Bi[:, t], -2) * mx.expand_dims(xi[:, t], -1)
y = mx.sum(h * mx.expand_dims(Ci[:, t], -2), axis=-1)
ys.append(y)
y = mx.stack(ys, axis=1)
return y, h
def __call__(self, x: mx.array, cache) -> mx.array:
if cache is not None:
return self.step(x, cache)
A = -mx.exp(self.A_log)
zxbcdt = self.in_proj(u)
z, xBC, dt = mx.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
axis=-1,
)
dt = mx.softplus(dt + self.dt_bias)
# Use the custom DepthWiseConv1d with cache
xBC = self.conv1d(xBC, cache, cache_idx=0)
xBC = mx.sigmoid(xBC) * xBC # SiLU activation
x, B, C = mx.split(
xBC,
[self.args.d_inner, self.args.d_state, self.args.d_state],
axis=-1
)
x = self._reshape_heads(x, True)
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)
y, ssm_state = self._ssd(
x * mx.expand_dims(dt, -1),
A * dt,
B,
C,
self.args.chunk_size
)
y = y + x * mx.expand_dims(self.D, -1)
y = self._reshape_heads(y, False)
y = self.norm(y, z)
y = self.out_proj(y)
if cache is not None:
cache[1] = ssm_state
return y
-def ssm_step(self, x, state, dt):
-A = -mx.exp(self.A_log)
-D = self.D
-dt = nn.softplus(dt + self.dt_bias)
-B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
-batch_size = B.shape[0]
-B = B.reshape(batch_size, self.n_groups, self.state_size)
-C = C.reshape(batch_size, -1, self.state_size)
-dt = dt.reshape(batch_size, self.num_heads, 1)
-A = A.reshape(1, self.num_heads, 1)
-if state is None:
-new_state = dt * B
-else:
-new_state = dt * (B + state * mx.exp(dt * A))
-y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
-y = y + D * x[:, :self.num_heads]
-return y, new_state
-def __call__(self, x, cache):
-B, T, D = x.shape
-if cache is None:
-cache = [None, None]
-outputs = []
-for t in range(T):
-xt = x[:, t, :]
-zxbcdt = self.in_proj(xt)
-z, xBC, dt = mx.split(
-zxbcdt,
-indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
-axis=-1
-)
-# Use the new DepthWiseConv1d with caching
-conv_out, cache[0] = self.conv1d(mx.expand_dims(z, 1), cache[0])
-z = conv_out.squeeze(1)
-z = nn.silu(z)
-y_t, cache[1] = self.ssm_step(z, cache[1], dt)
-xBC = nn.silu(xBC)
-# Element-wise multiplication
-output_t = y_t[:, :, None] * xBC[:, None, :]
-output_t = self.norm(output_t)
-output_t = output_t.sum(axis=1)
-output_t = self.out_proj(output_t)
-outputs.append(output_t)
-output = mx.stack(outputs, axis=1)
-return output
def step(self, x: mx.array, cache) -> mx.array:
"""Single inference step"""
assert x.shape[1] == 1, "Only one token can be decoded per inference step"
zxbcdt = self.in_proj(mx.squeeze(x, 1))
z, xBC, dt = mx.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
axis=-1,
)
# Use the custom DepthWiseConv1d with cache
xBC = self.conv1d(xBC, cache, cache_idx=0)
xBC = mx.sigmoid(xBC) * xBC # SiLU activation
x, B, C = mx.split(
xBC,
[self.args.d_inner, self.args.d_state, self.args.d_state],
axis=-1
)
A = -mx.exp(self.A_log)
dt = mx.softplus(dt + self.dt_bias)
dA = mx.exp(dt * A)
x = mx.reshape(x, (-1, self.args.nheads, self.args.headdim))
ssm_state = cache[1]
dBx = mx.expand_dims(dt, -1) * mx.expand_dims(B, 1) * mx.expand_dims(x, -1)
ssm_state = ssm_state * mx.expand_dims(mx.expand_dims(dA, -1), -1) + dBx
y = mx.sum(ssm_state * mx.expand_dims(mx.expand_dims(C, 1), 1), axis=-1)
y = y + mx.expand_dims(self.D, -1) * x
y = mx.reshape(y, (-1, self.args.nheads * self.args.headdim))
y = self.norm(y, z)
y = self.out_proj(y)
# Update SSM state in cache
cache[1] = ssm_state
return mx.expand_dims(y, 1)
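
With generation in place and training next (per the commit message), a minimal MLX training-step sketch for the full-sequence path would look like the following (an outline only, not part of the commit: it assumes the surrounding Model class exposes __call__(inputs, cache=None) like the other mlx_lm models, and the optimizer and loss choices are placeholders):

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

def loss_fn(model, inputs, targets):
    logits = model(inputs)  # no cache -> the chunked _ssd sequence path above
    return nn.losses.cross_entropy(logits, targets, reduction="mean")

# tokens: (batch, seqlen) integer array of training data
# loss_and_grad = nn.value_and_grad(model, loss_fn)
# optimizer = optim.AdamW(learning_rate=1e-4)
# loss, grads = loss_and_grad(model, tokens[:, :-1], tokens[:, 1:])
# optimizer.update(model, grads)
# mx.eval(model.parameters(), optimizer.state)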
class ResidualBlock(nn.Module):