This commit is contained in:
Goekdeniz-Guelmez 2024-10-30 21:23:13 +01:00
parent ffc7ab06a0
commit 58b448dc0b
4 changed files with 1007 additions and 1285 deletions

View File

@ -324,6 +324,7 @@ class RotatingKVCache(_BaseCache):
class MambaCache(_BaseCache):
    def __init__(self):
        self.cache = [None, None]
        self.offset = 0

    def __setitem__(self, idx, value):
        self.cache[idx] = value
@ -341,129 +342,12 @@ class MambaCache(_BaseCache):
class Mamba2Cache:
    # --- added ---
    def __init__(self, batch_size, conv_dim, kernel_size, num_heads, head_dim, state_size):
        self.conv_states = mx.zeros((batch_size, conv_dim, kernel_size - 1))
        self.ssm_states = mx.zeros((batch_size, num_heads, head_dim, state_size))
        self.seqlen_offset = 0

    def update(self, new_conv_state, new_ssm_state):
        self.conv_states = new_conv_state
        self.ssm_states = new_ssm_state
        self.seqlen_offset += 1

    # --- removed ---
    batch_size: int
    intermediate_size: int
    state_size: int
    conv_kernel: int
    num_heads: int
    head_dim: int

    def __init__(
        self,
        batch_size: int,
        intermediate_size: int,
state_size: int,
conv_kernel: int,
num_heads: int,
head_dim: int
):
self.batch_size = batch_size
self.intermediate_size = intermediate_size
self.state_size = state_size
self.conv_kernel = conv_kernel
self.num_heads = num_heads
self.head_dim = head_dim
# Initialize conv state with proper dimensions
self.conv_dim = self.intermediate_size + 2 * self.state_size
self.conv_state = mx.zeros((batch_size, self.conv_dim, conv_kernel - 1))
# Initialize SSM state
self.ssm_state = mx.zeros((
batch_size,
num_heads,
head_dim,
state_size
))
def update_conv_state(self, x: mx.array) -> mx.array:
"""
Update convolution state for incremental inference.
Args:
x: Input tensor containing projected values (B, conv_in_dim)
Returns:
Combined state tensor of shape (batch_size, conv_dim, kernel_size)
"""
# Handle input shape
if x.ndim == 1:
x = mx.expand_dims(x, 0) # Add batch dimension if needed
# Ensure batch size matches
assert x.shape[0] == self.batch_size, f"Batch size mismatch: {x.shape[0]} vs {self.batch_size}"
# Reshape x to match conv_dim
# The input x contains intermediate_size + 2 * state_size dimensions
x_reshaped = mx.reshape(x, (self.batch_size, -1))
x_padded = mx.pad(
x_reshaped,
[(0, 0), (0, self.conv_dim - x_reshaped.shape[1])],
mode='constant',
constant_values=0
)
# Expand dims for concatenation
x_expanded = mx.expand_dims(x_padded, -1) # Shape: (batch_size, conv_dim, 1)
# Roll the existing state left by 1
rolled_state = mx.roll(self.conv_state, shift=-1, axis=-1)
# Create update mask for the last position
update_pos = self.conv_kernel - 2
state_idx = mx.arange(self.conv_kernel - 1)
update_mask = state_idx == update_pos
# Broadcast mask to match state dimensions
update_mask = mx.broadcast_to(
mx.reshape(update_mask, (1, 1, -1)),
rolled_state.shape
)
# Update state with padded input
x_broadcast = mx.broadcast_to(x_expanded, (self.batch_size, self.conv_dim, 1))
self.conv_state = mx.where(
update_mask,
x_broadcast,
rolled_state
)
# Return concatenated state for convolution
return mx.concatenate([self.conv_state, x_expanded], axis=-1)
def update_ssm_state(self, dA: mx.array, dBx: mx.array) -> mx.array:
"""
Update SSM state for incremental inference.
Args:
dA: State transition tensor of shape (batch_size, num_heads)
dBx: Input projection tensor of shape (batch_size, num_heads, head_dim, state_size)
Returns:
Updated SSM state of shape (batch_size, num_heads, head_dim, state_size)
"""
# Add necessary dimensions to dA for broadcasting
# dA shape: (batch_size, num_heads) -> (batch_size, num_heads, 1, 1)
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
# Ensure dBx has the correct shape
assert dBx.shape[-1] == self.state_size, f"dBx state dimension mismatch: {dBx.shape[-1]} vs {self.state_size}"
assert dBx.shape[-2] == self.head_dim, f"dBx head dimension mismatch: {dBx.shape[-2]} vs {self.head_dim}"
# Update state: state = dA * state + dBx
self.ssm_state = dA * self.ssm_state + dBx
return self.ssm_state
@classmethod
def get_cache(
cls,
args,
batch_size: int,
max_seq_length: Optional[int]
) -> "Mamba2Cache":
"""Create a new cache instance with the given parameters."""
return cls(
batch_size=batch_size,
intermediate_size=args.intermediate_size,
state_size=args.state_size,
conv_kernel=args.conv_kernel,
num_heads=args.num_heads,
head_dim=args.head_dim
)
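The removed update_conv_state above maintains a rolling window of the last conv_kernel - 1 inputs by shifting left and writing the newest column. A minimal sketch of that pattern with made-up sizes, assuming only mlx; illustrative only, not code from this commit:

import mlx.core as mx

batch, conv_dim, kernel_size = 1, 6, 4
conv_state = mx.zeros((batch, conv_dim, kernel_size - 1))   # holds the last K-1 inputs
x_t = mx.random.normal((batch, conv_dim))                   # newest projected input

rolled = mx.roll(conv_state, shift=-1, axis=-1)             # drop the oldest column
x_col = mx.expand_dims(x_t, -1)                             # (batch, conv_dim, 1)
conv_state = mx.concatenate([rolled[..., :-1], x_col], axis=-1)
print(conv_state.shape)                                     # (1, 6, 3)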

View File

@ -1,275 +1,7 @@
import math
from dataclasses import dataclass, field
from typing import Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
use_cache: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.groups = groups if groups is not None else in_channels
# Ensure in_channels and out_channels are the same for depthwise conv
assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
# Ensure groups is equal to in_channels for depthwise conv
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
# Initialize weight with shape (out_channels, kernel_size, 1)
self.weight = mx.random.normal((out_channels, kernel_size, 1))
self.bias = mx.zeros((out_channels,)) if bias else None
def __call__(self, x, cache=None):
B, L, C = x.shape
_, K, _ = self.weight.shape
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=self.groups)
if self.bias is not None:
y = y + self.bias
return y, x[:, -K + 1 :, :]
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.intermediate_size = args.intermediate_size
self.time_step_rank = args.time_step_rank
self.conv_kernel_size = args.conv_kernel
self.hidden_size = args.hidden_size
self.state_size = args.state_size
self.num_heads = args.num_heads
self.head_dim = args.hidden_size // args.num_heads
self.n_groups = args.n_groups
# projection_size = 2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads
projection_size = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(
args.hidden_size,
projection_size,
bias=args.use_bias
)
# self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
self.conv_dim = args.intermediate_size + 2 * args.state_size
self.conv1d = DepthWiseConv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
kernel_size=args.conv_kernel,
bias=args.use_conv_bias,
groups=self.conv_dim,
padding=args.conv_kernel - 1
)
self.A_log = mx.zeros(args.num_heads)
self.D = mx.ones((args.num_heads,))
self.dt_bias = mx.zeros(args.num_heads)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
def ssm_step(self, x, state, dt):
A = -mx.exp(self.A_log)
D = self.D
dt = nn.softplus(dt + self.dt_bias)
B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
batch_size = B.shape[0]
B = B.reshape(batch_size, self.n_groups, self.state_size)
C = C.reshape(batch_size, -1, self.state_size)
dt = dt.reshape(batch_size, self.num_heads, 1)
A = A.reshape(1, self.num_heads, 1)
if state is None:
new_state = dt * B
else:
new_state = dt * (B + state * mx.exp(dt * A))
y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
y = y + D * x[:, :self.num_heads]
return y, new_state
def __call__(self, x, cache):
B, T, D = x.shape
if cache is None:
cache = [None, None]
outputs = []
for t in range(T):
xt = x[:, t, :]
zxbcdt = self.in_proj(xt)
z, xBC, dt = mx.split(
zxbcdt,
# indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
indices_or_sections=[
self.intermediate_size,
self.intermediate_size + 2 * self.state_size,
self.num_heads
],
axis=-1
)
# Use the new DepthWiseConv1d with caching
conv_out, cache[0] = self.conv1d(mx.expand_dims(z, 1), cache[0])
z = conv_out.squeeze(1)
z = nn.silu(z)
y_t, cache[1] = self.ssm_step(z, cache[1], dt)
xBC = nn.silu(xBC)
# Element-wise multiplication
output_t = y_t[:, :, None] * xBC[:, None, :]
output_t = self.norm(output_t)
output_t = output_t.sum(axis=1)
output_t = self.out_proj(output_t)
outputs.append(output_t)
output = mx.stack(outputs, axis=1)
return output
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
return self.mixer(self.norm(x), cache) + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
# self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
def layers(self):
return self.backbone.layers
# ------
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@ -296,130 +28,79 @@ class ModelArgs(BaseModelArgs):
    # --- added ---
    time_step_max: float
    time_step_floor: float
    rescale_prenorm_residual: bool
    rms_norm: bool
    chunk_size: int
    tie_word_embeddings: bool
    intermediate_size: int = None
    time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
    time_step_rank: Union[int, str] = "auto"
    model_type: str = "mamba2"

    def __post_init__(self):
        self.intermediate_size = int(self.expand * self.hidden_size)  # E*D = ED
        if not hasattr(self, "head_dim"):
            self.head_dim = self.hidden_size // self.num_heads
        if self.time_step_rank == "auto":
            self.time_step_rank = math.ceil(self.hidden_size / 16)


def selective_scan(x, A, B, C, chunk_size):
    """
    Selective scan implementation for training.

    Arguments
        x: (batch, seqlen, n_heads, d_head)
        A: (batch, seqlen, n_heads)
        B: (batch, seqlen, n_heads, d_state)
        C: (batch, seqlen, n_heads, d_state)

    Return
        y: (batch, seqlen, n_heads, d_head)
    """
    assert x.shape[1] % chunk_size == 0

    # Reshape into chunks
    def chunk_reshape(m):
        shape = list(m.shape)
        shape[1:2] = [shape[1] // chunk_size, chunk_size]
        return m.reshape(shape)

    x, A, B, C = map(chunk_reshape, (x, A, B, C))
    A = mx.transpose(A, [0, 3, 1, 2])

    # Compute cumulative sums
    A_cumsum = mx.cumsum(A, axis=-1)

    # Process chunks
    L = mx.exp(selective_cumsum(A))
    Y_diag = mx.einsum('bclhn,bcshn,bhcls,bcshp->bclhp', C, B, L, x)

    decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum)
    states = mx.einsum('bclhn,bhcl,bclhp->bchpn', B, decay_states, x)

    initial_states = mx.zeros_like(states[:, :1])
    states = mx.concatenate([initial_states, states], axis=1)
    decay_chunk = mx.exp(selective_cumsum(mx.pad(A_cumsum[..., -1], ((0,0), (0,0), (1,0)))))
    new_states = mx.einsum('bhzc,bchpn->bzhpn', decay_chunk, states)
    states = new_states[:, :-1]

    state_decay_out = mx.exp(A_cumsum)
    Y_off = mx.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)

    Y = (Y_diag + Y_off).reshape((-1, x.shape[1] * chunk_size, *Y_diag.shape[-2:]))
    return Y


def selective_cumsum(x: mx.array) -> mx.array:
    """Stable selective cumulative sum calculation."""
    T = x.shape[-1]
    x = mx.repeat(x[..., None], T, axis=-1)
    mask = mx.tril(mx.ones((T, T)), k=-1)
    x = x * mask
    x_cumsum = mx.cumsum(x, axis=-2)
    mask = mx.tril(mx.ones((T, T)), k=0)
    return mx.where(mask, x_cumsum, float('-inf'))


    # --- removed ---
    time_step_max: float
    time_step_floor: float
    rescale_prenorm_residual: bool
    use_cache: bool
    rms_norm: bool
    chunk_size: int
    tie_word_embeddings: bool
    use_cache: bool = True
    time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
    time_step_rank: Union[int, str] = "auto"
    model_type: str = "mamba2"

    def __post_init__(self):
        if not hasattr(self, "intermediate_size"):
            self.intermediate_size = int(self.expand * self.hidden_size)
        if not hasattr(self, "head_dim"):
            self.head_dim = self.hidden_size // self.num_heads
        if self.time_step_rank == "auto":
            self.time_step_rank = math.ceil(self.hidden_size / 16)


class MambaRMSNormGated(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = mx.ones((hidden_size,))
        self.variance_epsilon = eps

    def __call__(self, hidden_states, gate=None):
        if gate is not None:
            hidden_states = hidden_states * nn.silu(gate)
        variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
        hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


def silu(x):
    return x * mx.sigmoid(x)


def ssd(x, A, B, C, chunk_size):
    # Replace einsum operations with explicit reshape and matrix multiply
    batch, seqlen, nheads, dim = x.shape
    B = mx.expand_dims(B, axis=2)
    C = mx.expand_dims(C, axis=2)

    state = mx.zeros((batch, nheads, dim, B.shape[-1]))
    outputs = []
    for i in range(0, seqlen, chunk_size):
        chunk = slice(i, min(i + chunk_size, seqlen))
        dA = mx.exp(mx.expand_dims(A[chunk], axis=0))

        # Replace einsum with explicit operations
        x_chunk = x[:, chunk]  # [batch, chunk_size, nheads, dim]
        x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1])  # [batch, nheads, dim, chunk_size]
        B_chunk = B[:, chunk]  # [batch, chunk_size, state_size]
        dBx = mx.matmul(x_chunk, B_chunk)  # [batch, nheads, dim, state_size]

        state = state * mx.expand_dims(dA, axis=-1) + dBx

        # Replace einsum with explicit operations
        C_chunk = C[:, chunk]  # [batch, chunk_size, state_size]
        y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1]))  # [batch, nheads, dim, chunk_size]
        y = mx.transpose(y, [0, 3, 1, 2])  # [batch, chunk_size, nheads, dim]
        outputs.append(y)

    return mx.concatenate(outputs, axis=1), state


class DepthWiseConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.groups = groups if groups is not None else in_channels
assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
self.weight = mx.random.normal((in_channels, 1, kernel_size))
self.bias = mx.zeros((out_channels,)) if bias else None
def __call__(self, x: mx.array, cache=None) -> mx.array:
B, L, C = x.shape
K = self.kernel_size
assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
if cache is not None and 'conv_states' in cache:
conv_states = cache['conv_states']
if conv_states is not None:
assert conv_states.shape[0] == B, "Cache batch size mismatch"
assert conv_states.shape[2] == C, "Cache channel count mismatch"
x = mx.concatenate([conv_states, x], axis=1)
# Process each channel independently
outputs = []
for c in range(C):
x_c = x[:, :, c]
x_c = mx.expand_dims(x_c, axis=1)
w_c = self.weight[c]
if w_c.ndim == 2:
w_c = mx.expand_dims(w_c, axis=0)
elif w_c.ndim == 1:
w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
# Apply convolution
y_c = mx.conv_general(
x_c,
w_c,
stride=1,
padding=0
)
if self.bias is not None:
y_c = y_c + self.bias[c]
outputs.append(mx.squeeze(y_c, axis=1))
y = mx.stack(outputs, axis=-1)
# Update cache
if cache is not None:
cache['conv_states'] = x[:, -K+1:, :] if x.shape[1] >= K else x
return y
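The MambaRMSNormGated removed in this hunk is, when a gate is passed, mathematically just RMSNorm applied to hidden_states * silu(gate) with a learned scale. A quick equivalence check, assuming mlx; illustrative only, not code from this commit:

import mlx.core as mx
import mlx.nn as nn

hidden_size, eps = 8, 1e-6
h = mx.random.normal((2, hidden_size))
gate = mx.random.normal((2, hidden_size))

# gated variant, as in the removed MambaRMSNormGated (weight initialised to ones)
gated_in = h * nn.silu(gate)
variance = mx.mean(gated_in ** 2, axis=-1, keepdims=True)
manual = gated_in * mx.rsqrt(variance + eps)

# reference: plain RMSNorm on the pre-gated input
reference = nn.RMSNorm(hidden_size, eps=eps)(gated_in)
print(mx.allclose(manual, reference))  # True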
class Mamba2Block(nn.Module):
@ -427,165 +108,250 @@ class Mamba2Block(nn.Module):
        super().__init__()
        self.args = args

        # --- removed ---
        d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
        self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)

        conv_dim = args.intermediate_size + 2 * args.state_size
        self.conv1d = DepthWiseConv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            kernel_size=args.conv_kernel,
            groups=conv_dim,
            bias=args.use_conv_bias,
            padding=args.conv_kernel - 1
        )

        self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
        self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
        self.D = mx.random.normal((args.num_heads,)) * args.initializer_range

        self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
        self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)

        if args.rescale_prenorm_residual:
            layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
            self.out_proj.weight = self.out_proj.weight * layer_scale

    def __call__(self, x: mx.array, cache=None):
        if cache is not None:
            return self.step(x, cache)

        # Regular forward pass code remains the same...
        d_model = self.args.intermediate_size
        d_state = self.args.state_size
        n_heads = self.args.num_heads

        A = -mx.exp(self.A_log)
        zxbcdt = self.in_proj(x)

        splits = [d_model, d_model + 2 * d_state, n_heads]
        z = zxbcdt[:, :, :splits[0]]
        xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]]
        dt = zxbcdt[:, :, -splits[2]:]

        # --- added ---
        # Project input to get various components [z, x, B, C, dt]
        projection_size = (2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads)
        self.in_proj = nn.Linear(
            args.hidden_size,
            projection_size,
            bias=args.use_bias
        )

        # Convolution layer
        conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            kernel_size=args.conv_kernel,
            groups=conv_dim,
            padding=args.conv_kernel - 1,
            bias=args.use_conv_bias
        )

        # SSM parameters
        self.dt_bias = mx.zeros(args.num_heads)
        self.A_log = mx.zeros(args.num_heads)
        self.D = mx.ones(args.num_heads)

        # Output projections
        self.norm = nn.RMSNorm(args.intermediate_size, eps=args.layer_norm_epsilon)
        self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)

    def __call__(self, u: mx.array, cache=None) -> mx.array:
        # return self.forward_training(x) if x.shape[1] > 1 else self.forward_inference(x, cache)

    # def forward_training(self, u: mx.array) -> mx.array:
    #     # Reset cache during training
    #     self.cache = None

    #     # Input projection and splitting
    #     zxbcdt = self.in_proj(u)
    #     z, xBC, dt = mx.split(
    #         zxbcdt,
    #         [
    #             self.args.hidden_size,
    #             self.args.hidden_size + 2 * self.args.state_size
    #         ],
    #         axis=-1
    #     )

    #     # Time step processing
    #     dt = mx.clip(
    #         nn.softplus(dt + self.dt_bias),
    #         self.args.time_step_min,
    #         self.args.time_step_max
    #     )

    #     # Convolution processing
    #     xBC_t = mx.transpose(xBC, [0, 2, 1])
    #     conv_out = self.conv1d(xBC_t)
    #     xBC = mx.transpose(conv_out, [0, 2, 1])[:, :u.shape[1]]
    #     xBC = mx.sigmoid(xBC) * xBC  # SiLU
# # Split states
# x, B, C = mx.split(
# xBC,
# [self.args.hidden_size, self.args.state_size],
# axis=-1
# )
# # Reshape for selective scan
# x = x.reshape((-1, x.shape[1], self.args.num_heads, self.args.head_dim))
# A = -mx.exp(self.A_log)
# # Apply selective scan
# y = selective_scan(
# x * dt[..., None],
# A * dt,
# B[..., None, :],
# C[..., None, :],
# self.args.chunk_size
# )
# # Output processing
# y = y + x * self.D[None, None, :, None]
# y = y.reshape((-1, y.shape[1], self.args.hidden_size))
# y = self.norm(y, z)
# y = self.out_proj(y)
# return y
# def forward_inference(self, u: mx.array, cache=None) -> mx.array:
# """
# u: (B, 1, D)
# cache: (h_cache, conv_cache)
# """
# """Single token processing during inference."""
# assert u.shape[1] == 1, "Inference mode expects single token"
# batch_size = u.shape[0]
# # Use provided cache or create new one
# self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None)
# # Project input
# zxbcdt = self.in_proj(u.squeeze(1)) # (B, 2D)
# d_mlp = (zxbcdt.shape[-1] - 2 * self.args.hidden_size - 2 * self.args.n_groups * self.args.state_size - self.args.num_heads) // 2
# # (1, 768) (1, 0) (1, 0) (1, 256) (1, 0) (1, 3328)
# y0, z0, x0, z, xBC, dt = mx.split(
# zxbcdt,
# [
# d_mlp,
# d_mlp,
# self.args.hidden_size,
# self.args.hidden_size + 2 * self.args.n_groups * self.args.state_size,
# self.args.num_heads
# ],
# axis=-1
# )
# # Update convolution state and apply
# conv_state = self.cache.update_conv_state(xBC)
# xBC = mx.sum(conv_state[:, :, -1] * mx.transpose(self.conv1d.weight, [1, 0, 2]), axis=-1) # (B, D) (4, 1792)
# if self.args.use_conv_bias:
# xBC = xBC + self.conv1d.bias
# xBC = mx.sigmoid(xBC) * xBC # SiLU (4, 1792)
# # Split states and ensure proper shapes
# a0, x, B, C = mx.split(
# xBC, # (4, 1792)
# [
# self.args.hidden_size,
# self.args.n_groups * self.args.state_size,
# self.args.n_groups * self.args.state_size
# ],
# axis=-1
# )
# # SSM step with explicit shapes
# A = -mx.exp(self.A_log) # (num_heads) (24,)
# print(A.shape) # (24,)
# print(dt.shape) # (1, 3328)
# dA = mx.exp(dt * A[None, :]) # Shape: (batch_size, num_heads) <------- here is the error
# # Reshape x considering intermediate size
# # x shape should be (batch_size * num_heads, head_dim)
# x = mx.reshape(x, (batch_size, self.args.num_heads, -1))
# assert x.shape[-1] == self.args.head_dim, f"Head dimension mismatch: {x.shape[-1]} vs {self.args.head_dim}"
# B = mx.reshape(B, (batch_size, -1)) # Should be (batch_size, state_size)
# C = mx.reshape(C, (batch_size, -1)) # Should be (batch_size, state_size)
# # Compute dBx with explicit shapes
# dBx = mx.einsum('bh,bs,bhd->bhds', dt, B, x)
# ssm_state = self.cache.update_ssm_state(dA, dBx)
# y = mx.einsum('bhds,bs->bhd', ssm_state, C)
# y = y + x * self.D[None, :, None]
# y = mx.reshape(y, (batch_size, self.args.hidden_size))
# # Output processing
# y = self.norm(y, z)
# if d_mlp > 0:
# y = mx.cat([nn.silu(z0) * x0, y], axis=-1)
# y = self.out_proj(y)
# return mx.expand_dims(y, 1)
assert u.shape[1] == 1, "Inference mode expects single token"
batch_size = u.shape[0]
# Use provided cache or create new one
self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None)
# Project input
zxbcdt = self.in_proj(u.squeeze(1)) # (B, projection_size)
# Calculate splits based on model dimensions
d_mlp = self.args.intermediate_size
d_state = self.args.state_size * self.args.n_groups
# Split the projection into its components
splits = [
d_mlp, # y0
d_mlp, # z0
self.args.hidden_size, # x0
self.args.hidden_size, # z
d_state * 2, # xBC (includes both B and C)
self.args.num_heads # dt
]
y0, z0, x0, z, xBC, dt = mx.split(zxbcdt, splits[:-1], axis=-1)
# Update convolution state and apply
conv_state = self.cache.update_conv_state(xBC)
xBC = mx.sum(conv_state[:, :, -1] * mx.transpose(self.conv1d.weight, [1, 0, 2]), axis=-1)
if self.args.use_conv_bias:
xBC = xBC + self.conv1d.bias
xBC = mx.sigmoid(xBC) * xBC # SiLU
# Split states and reshape
x, BC = mx.split(xBC, [self.args.intermediate_size], axis=-1)
B, C = mx.split(BC, [d_state], axis=-1)
# Reshape for SSM computation
x = mx.reshape(x, (batch_size, self.args.num_heads, -1)) # (B, H, head_dim)
B = mx.reshape(B, (batch_size, self.args.num_heads, -1)) # (B, H, state_per_head)
C = mx.reshape(C, (batch_size, self.args.num_heads, -1)) # (B, H, state_per_head)
# Process dt to match expected shape
dt = mx.reshape(dt, (batch_size, self.args.num_heads)) # (B, H)
        dt = mx.clip(
            nn.softplus(dt + self.dt_bias),
            self.args.time_step_min,
            self.args.time_step_max
        )

        # SSM step
        A = -mx.exp(self.A_log)  # (H,)
        dA = mx.exp(dt * A[None, :])  # (B, H)

        # Compute dBx
        dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, x)

        # Update SSM state and compute output
        ssm_state = self.cache.update_ssm_state(dA, dBx)
        y = mx.einsum('bhds,bhs->bhd', ssm_state, C)
        y = y + x * self.D[None, :, None]

        # Reshape output
        y = mx.reshape(y, (batch_size, self.args.hidden_size))

        # Final output processing
        if d_mlp > 0:
            y = mx.concat([nn.silu(z0) * x0, y], axis=-1)
        y = self.norm(y + z)
        y = self.out_proj(y)

        return mx.expand_dims(y, 1)  # (B, 1, D)

        # --- removed ---
        dt = mx.clip(
            nn.softplus(dt + self.dt_bias),
            self.args.time_step_min,
            self.args.time_step_max
        )
        dt = mx.maximum(dt, self.args.time_step_floor)

        xBC = silu(self.conv1d(xBC))

        x = xBC[:, :, :d_model]
        B = xBC[:, :, d_model:d_model + d_state]
        C = xBC[:, :, -d_state:]

        b, l, hp = x.shape
        h = self.args.num_heads
        p = hp // h
        x = mx.reshape(x, (b, l, h, p))

        y, ssm_state = ssd(x * mx.expand_dims(dt, -1), A * dt, B, C, self.args.chunk_size)
        y = y + x * mx.expand_dims(self.D, -1)
        y = mx.reshape(y, (b, l, h * p))

        y = self.norm(y, z)
        y = self.out_proj(y)

        if self.args.residual_in_fp32:
            y = y.astype(mx.float32)

        return y
def step(self, u: mx.array, cache):
batch_size = u.shape[0]
seq_len = u.shape[1]
outputs = []
# Initialize cache if needed
if cache.conv_states is None:
conv_dim = self.args.intermediate_size + 2 * self.args.state_size
cache.conv_states = mx.zeros((
batch_size,
self.args.conv_kernel - 1,
conv_dim
))
if cache.ssm_state is None:
cache.ssm_state = mx.zeros((
batch_size,
self.args.num_heads,
self.args.head_dim,
self.args.state_size
))
for pos in range(seq_len):
u_t = u[:, pos:pos+1, :]
zxbcdt = self.in_proj(u_t)
d_model = self.args.intermediate_size
d_state = self.args.state_size
n_heads = self.args.num_heads
z = zxbcdt[:, :, :d_model]
xBC = zxbcdt[:, :, d_model:d_model + 2*d_state + d_model]
dt = zxbcdt[:, :, -(n_heads):]
dt = mx.reshape(dt, (batch_size, n_heads))
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.args.time_step_min,
self.args.time_step_max
)
dt = mx.maximum(dt, self.args.time_step_floor)
# Create a temporary cache dictionary for the convolution
conv_cache = {'conv_states': cache.conv_states}
xBC = self.conv1d(xBC, cache=conv_cache)
cache.conv_states = conv_cache['conv_states']
xBC = silu(xBC)
x = xBC[:, :, :d_model]
B = xBC[:, :, d_model:d_model + d_state]
C = xBC[:, :, -d_state:]
x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
x = mx.squeeze(x, axis=1)
B = mx.reshape(B, (batch_size, 1, d_state))
B = mx.broadcast_to(B, (batch_size, n_heads, d_state))
B = mx.expand_dims(B, axis=2)
C = mx.reshape(C, (batch_size, 1, d_state))
C = mx.broadcast_to(C, (batch_size, n_heads, d_state))
C = mx.expand_dims(C, axis=3)
A = -mx.exp(self.A_log)
dA = mx.exp(dt * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
x = mx.expand_dims(x, axis=3)
dBx = mx.matmul(x, B)
cache.ssm_state = cache.ssm_state * dA + dBx
y = mx.matmul(cache.ssm_state, C)
y = mx.squeeze(y, axis=-1)
y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
y = self.norm(y + z)
y = self.out_proj(y)
if self.args.residual_in_fp32:
y = y.astype(mx.float32)
outputs.append(y)
return mx.concatenate(outputs, axis=1)
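Per token, the step path above reduces to the diagonal SSM recurrence: the state is decayed by exp(dt * A) and incremented by dt * (x outer B), then the output contracts the state with C and adds the D skip. A toy version of one such update with made-up sizes, assuming mlx; illustrative only, not code from this commit:

import mlx.core as mx

heads, head_dim, state_size = 2, 3, 4
A = -mx.ones((heads,))                      # negative, so exp(dt * A) decays the state
D = mx.ones((heads,))
h = mx.zeros((heads, head_dim, state_size)) # running SSM state

x = mx.random.normal((heads, head_dim))     # per-head input for this token
B = mx.random.normal((heads, state_size))
C = mx.random.normal((heads, state_size))
dt = mx.array(0.5)

dA = mx.exp(dt * A)[:, None, None]                        # (heads, 1, 1)
dBx = dt * x[:, :, None] * B[:, None, :]                  # (heads, head_dim, state_size)
h = dA * h + dBx                                          # state update
y = mx.sum(h * C[:, None, :], axis=-1) + D[:, None] * x   # (heads, head_dim)
print(y.shape)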
class ResidualBlock(nn.Module):
@ -594,11 +360,12 @@ class ResidualBlock(nn.Module):
        self.mixer = Mamba2Block(args)
        self.norm = nn.RMSNorm(args.hidden_size)

    # --- removed ---
    def __call__(self, x: mx.array, cache):
        return self.mixer(self.norm(x), cache) + x
    # --- added ---
    def __call__(self, x: mx.array, cache=None) -> mx.array:
        # x : (B, L, D)
        return self.mixer(self.norm(x), cache) + x  # (B, L, D)


# --- removed ---
class Mamba2(nn.Module):
# --- added ---
class Mamba2Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
@ -606,12 +373,15 @@ class Mamba2(nn.Module):
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    # --- removed ---
    def __call__(self, x: mx.array, cache):
        x = self.embeddings(x)
        if cache is None:
            cache = [None] * len(self.layers)
        for layer, c in zip(self.layers, cache):
            x = layer(x, c)
        return self.norm_f(x)

    # --- added ---
    def __call__(self, x: mx.array, cache=None) -> mx.array:
        # x : (B, L)
        x = self.embeddings(x)
        # x : (B, L, D)
        if cache is None:
            cache = [None] * len(self.layers)

        for layer, layer_cache in zip(self.layers, cache):
            x = layer(x, layer_cache)

        return self.norm_f(x)
@ -619,14 +389,13 @@ class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        # --- removed ---
        self.model_type = args.model_type
        self.backbone = Mamba2(args)
        # --- added ---
        self.backbone = Mamba2Model(args)

        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    # --- removed ---
    def __call__(self, inputs: mx.array, cache=None):
    # --- added ---
    def __call__(self, inputs: mx.array, cache=None) -> mx.array:
        # inputs : (B, L)
        B, T = inputs.shape
        x = self.backbone(inputs, cache)
@ -638,23 +407,18 @@ class Model(nn.Module):
        return logits

    # --- removed ---
    def make_cache(self):
        return [Mamba2Cache() for _ in range(len(self.layers))]

    def sanitize(self, weights):
        sanitized = {}
        for k, v in weights.items():
            if "conv1d.weight" in k:
                # Ensure weights are in correct shape (channels, 1, kernel_size)
                if v.ndim == 2:
                    v = mx.expand_dims(v, axis=1)
                elif v.ndim == 1:
                    v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0)
                sanitized[k] = v
            else:
                sanitized[k] = v
        return sanitized

    @property
    def layers(self):
        return self.backbone.layers

    # --- added ---
    def make_cache(self, batch_size=1):
        return [Mamba2Cache(
            batch_size=batch_size,
            hidden_size=self.args.hidden_size,
            state_size=self.args.state_size,
            conv_kernel=self.args.conv_kernel,
            num_heads=self.args.num_heads,
            head_dim=self.args.head_dim
        ) for _ in range(len(self.backbone.layers))]

    def sanitize(self, weights):
        for k, v in weights.items():
            if "conv1d.weight" in k and v.ndim == 3:
                weights[k] = v.moveaxis(2, 1)
        return weights
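Downstream, make_cache plus a per-token cache argument follows the usual mlx-lm decode pattern. A rough greedy-decode sketch under that assumption; model and prompt are placeholders rather than objects defined in this diff:

import mlx.core as mx

def greedy_decode(model, prompt, max_new_tokens=16):
    cache = model.make_cache()                   # one cache object per layer
    tokens = list(prompt)
    y = mx.array([tokens])                       # (B=1, L) prompt pass
    for _ in range(max_new_tokens):
        logits = model(y, cache=cache)
        next_token = int(mx.argmax(logits[:, -1, :], axis=-1).item())
        tokens.append(next_token)
        y = mx.array([[next_token]])             # then feed one token at a time
    return tokens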

View File

@ -1,437 +1,490 @@
""" # coding=utf-8
mamba2-minimal # Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
============== #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MAMBA2 model."""
A minimal, single-file implementation of the Mamba-2 model in PyTorch. import math
> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality**
> Authors: Tri Dao, Albert Gu
> Paper: https://arxiv.org/abs/2405.21060
"""
import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable, NamedTuple, TypeAlias, cast from typing import Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.utils.checkpoint
from einops import rearrange, repeat from torch import nn
from torch import LongTensor, Tensor, nn from torch.nn import CrossEntropyLoss
Device: TypeAlias = str | torch.device | None logger = logging.get_logger(__name__)
@dataclass def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
class Mamba2Config: """
d_model: int # model dimension (D) Padding x tensor with `pad_size` on the seq_len dim (dim=1)
n_layer: int = 24 # number of Mamba-2 layers in the language model
d_state: int = 128 # state dimension (N)
d_conv: int = 4 # convolution kernel size
expand: int = 2 # expansion factor (E)
headdim: int = 64 # head dimension (P)
chunk_size: int = 64 # matrix partition size (Q)
vocab_size: int = 50277
pad_vocab_size_multiple: int = 16
def __post_init__(self): Assumes that we only have tensors of either size 4 or 3
self.d_inner = self.expand * self.d_model """
assert self.d_inner % self.headdim == 0 pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
self.nheads = self.d_inner // self.headdim
if self.vocab_size % self.pad_vocab_size_multiple != 0: return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
self.vocab_size += (
self.pad_vocab_size_multiple
- self.vocab_size % self.pad_vocab_size_multiple
)
class InferenceCache(NamedTuple): def reshape_into_chunks(input_tensor, pad_size, chunk_size):
conv_state: Tensor # (batch, d_inner + 2 * d_state, d_conv) """
ssm_state: Tensor # (batch, nheads, headdim, d_state) Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
simultaneously splitting it into chunk sequences.
@staticmethod Assumes that we only have tensors of either size 4 or 3
def alloc(batch_size: int, args: Mamba2Config, device: Device = None): """
return InferenceCache( # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
torch.zeros( input_tensor = pad_tensor_by_size(input_tensor, pad_size)
batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device
), if len(input_tensor.shape) == 3:
torch.zeros( # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
batch_size, args.nheads, args.headdim, args.d_state, device=device return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
), else:
# [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
return input_tensor.reshape(
input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
) )
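pad_tensor_by_size and reshape_into_chunks above just pad the sequence dimension up to a multiple of chunk_size and fold it into (batch, n_chunks, chunk_size, ...). A small shape walkthrough with torch and made-up sizes; illustrative only, not code from this commit:

import torch

batch, seq_len, heads, chunk_size = 2, 10, 4, 8
x = torch.randn(batch, seq_len, heads)
pad = (chunk_size - seq_len % chunk_size) % chunk_size   # 6 extra positions here
x = torch.nn.functional.pad(x, (0, 0, 0, pad))           # pad only the seq_len dim
chunks = x.reshape(batch, -1, chunk_size, heads)
print(chunks.shape)                                      # torch.Size([2, 2, 8, 4])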
class Mamba2LMHeadModel(nn.Module): def segment_sum(input_tensor):
def __init__(self, args: Mamba2Config, device: Device = None): """
More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
"""
chunk_size = input_tensor.size(-1)
# 1. expand input tensor to have an additional dimension and repeat along that dimension
# [..., chunk_size] -> [..., chunk_size, chunk_size]
input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
# 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
input_tensor = input_tensor.masked_fill(~mask, 0)
# 3. compute actual cumsum
tensor_segsum = torch.cumsum(input_tensor, dim=-2)
# 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
return tensor_segsum
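segment_sum builds a lower-triangular matrix of partial sums a[j+1] + ... + a[i], so exp(segment_sum(A)) is the 1-semiseparable decay matrix used by the chunked scan: ones on the diagonal, cumulative products of exp(a) below it, exact zeros above it. A quick numerical check with torch; illustrative only, not code from this commit:

import torch

def toy_segment_sum(a):
    T = a.size(-1)
    x = a[..., None].expand(*a.size(), T)
    x = x.masked_fill(~torch.tril(torch.ones(T, T, dtype=torch.bool), diagonal=-1), 0)
    x = torch.cumsum(x, dim=-2)
    return x.masked_fill(~torch.tril(torch.ones(T, T, dtype=torch.bool), diagonal=0), -torch.inf)

a = torch.tensor([0.1, 0.2, 0.3])
print(torch.exp(toy_segment_sum(a)))
# lower triangle holds exp of the partial sums of a; exp(-inf) gives zeros above the diagonal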
class Mamba2Cache:
"""
Arguments:
config: ModelArgs
batch_size: int
dtype: torch.dtype
device: torch.device
Attributes:
seqlen_offset: int
dtype: torch.dtype
conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel_size]
ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size]
"""
def __init__(
self, config: ModelArgs, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
):
self.seqlen_offset = 0
self.dtype = dtype
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = int(config.expand * config.hidden_size)
self.conv_states = {
i: torch.zeros(
batch_size,
self.intermediate_size + 2 * config.n_groups * config.state_size,
self.conv_kernel_size,
device=device,
dtype=dtype,
)
for i in range(config.num_hidden_layers)
}
self.ssm_states = {
i: torch.zeros(
batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype
)
for i in range(config.num_hidden_layers)
}
def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]
    def reset(self):
        # conv_states / ssm_states are dicts of per-layer tensors, so zero each tensor
        for conv_state in self.conv_states.values():
            conv_state.zero_()
        for ssm_state in self.ssm_states.values():
            ssm_state.zero_()
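In the single-token decode path, update_conv_state amounts to shifting the per-layer convolution buffer left by one and overwriting the last column (cache_position is clamped to the end of the window). A minimal torch sketch of that rolling update with made-up sizes; illustrative only, not code from this commit:

import torch

batch, conv_dim, kernel = 1, 4, 3
conv_state = torch.zeros(batch, conv_dim, kernel)
for step in range(5):                                  # pretend to decode five tokens
    new_col = torch.full((batch, conv_dim), float(step + 1))
    conv_state = conv_state.roll(shifts=-1, dims=-1)   # drop the oldest column
    conv_state[:, :, -1] = new_col                     # write the newest input
print(conv_state[0, 0])                                # tensor([3., 4., 5.]); only the last `kernel` inputs survive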
class MambaRMSNormGated(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__() super().__init__()
self.args = args self.weight = nn.Parameter(torch.ones(hidden_size))
self.device = device self.variance_epsilon = eps
self.backbone = nn.ModuleDict( def forward(self, hidden_states, gate=None):
dict( input_dtype = hidden_states.dtype
embedding=nn.Embedding(args.vocab_size, args.d_model, device=device), hidden_states = hidden_states
layers=nn.ModuleList(
[ if gate is not None:
nn.ModuleDict( hidden_states = hidden_states * nn.functional.silu(gate)
dict( variance = hidden_states.pow(2).mean(-1, keepdim=True)
mixer=Mamba2(args, device=device), hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
norm=RMSNorm(args.d_model, device=device),
) return self.weight * hidden_states
)
for _ in range(args.n_layer)
] class Mamba2Mixer(nn.Module):
), def __init__(self, config: ModelArgs):
norm_f=RMSNorm(args.d_model, device=device), super().__init__()
self.num_heads = config.num_heads
self.hidden_size = config.hidden_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = int(config.expand * self.hidden_size)
self.time_step_rank = int(config.time_step_rank)
self.use_conv_bias = config.use_conv_bias
self.act = nn.silu
self.layer_norm_epsilon = config.layer_norm_epsilon
self.rms_norm = config.rms_norm
self.n_groups = config.n_groups
self.head_dim = config.head_dim
self.chunk_size = config.chunk_size
self.time_step_limit = config.time_step_limit
self.time_step_min = config.time_step_min
self.time_step_max = config.time_step_max
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
bias=config.use_conv_bias,
kernel_size=config.conv_kernel,
groups=self.conv_dim,
padding=config.conv_kernel - 1,
)
# projection of the input hidden states
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
self.in_proj = nn.Linear(
self.hidden_size,
projection_size,
bias=config.use_bias,
)
self.dt_bias = torch.ones(self.num_heads)
A = torch.arange(1, self.num_heads + 1)
self.A_log = torch.log(A)
self.D = torch.ones(self.num_heads)
self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
def forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# Gated MLP's linear projection
projected_states = self.in_proj(input_states.squeeze(1))
d_mlp = (
projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2
_, _, gate, hidden_states, dt = projected_states.split(
[d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
)
# Convolution sequence transformation
if cache_params is not None:
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
ssm_state = ssm_state.to(hidden_states.device)
if cache_params.seqlen_offset > 0:
conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
# handle batched generation - states are copied through
conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states
cache_params.conv_states[self.layer_idx].copy_(conv_state)
hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
if self.use_conv_bias:
hidden_states += self.conv1d.bias
hidden_states = self.act(hidden_states)[:, None, ...] # [batch, 1, intermediate_size] : decoding
else:
hidden_states = hidden_states.transpose(1,2)
conv_state = nn.functional.pad(
hidden_states,
(self.conv_kernel_size - hidden_states.shape[-1], 0)
)
cache_params.conv_states[self.layer_idx].copy_(conv_state)
hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len]
else:
ssm_state = torch.zeros(
(batch_size, self.num_heads, self.head_dim, self.ssm_state_size),
device=hidden_states.device
) )
) hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2))
self.lm_head = nn.Linear(
args.d_model, args.vocab_size, bias=False, device=device
)
self.lm_head.weight = self.backbone.embedding.weight
@staticmethod hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1)
def from_pretrained(huggingface_model_id: str, device: Device = None): A = -torch.exp(self.A_log.float()) # [num_heads]
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils.hub import cached_file
config_path = cached_file(huggingface_model_id, CONFIG_NAME) if cache_params is not None and cache_params.seqlen_offset > 0:
assert config_path, "Failed to get huggingface config file" # Note: there is no need to pad parameter matrices here, as there is just one new token
state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME) # for batched generation
assert state_dict_path, "Failed to get huggingface state dict file" dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
# [num_heads] -> [num_heads, head_dim]
dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
config = json.load(open(config_path)) dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
args = Mamba2Config( dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max)
d_model=config["d_model"], A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
n_layer=config["n_layer"], # [bsz, num_heads, head_dim, state_size]
vocab_size=config["vocab_size"], dA = torch.exp(dt[..., None] * A)
pad_vocab_size_multiple=config["pad_vocab_size_multiple"],
)
map_location = "cpu" if device is None else device # Discretize B
state_dict = torch.load( # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
state_dict_path, weights_only=True, map_location=map_location, mmap=True # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
) B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
model = Mamba2LMHeadModel(args, device=device) B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
model.load_state_dict(state_dict) B = B.reshape(batch_size, -1, B.shape[-1])
model.eval() # [bsz, num_heads, head_dim, state_size]
return model dB = dt[..., None] * B[..., None, :]
# Discretize x into dB
# [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
dBx = dB * hidden_states[..., None]
# State calculation
cache_params.ssm_states[self.layer_idx].copy_(
cache_params.ssm_states[self.layer_idx] * dA + dBx
)
# Subsequent output
# [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
C = C.reshape(batch_size, -1, C.shape[-1])
# [bsz, num_heads, head_dim]
ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n]
# Reshape ssm_states to merge the first two dimensions
ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
y = torch.bmm(ssm_states_reshaped, C_reshaped)
y = y.view(batch_size, self.num_heads, self.head_dim)
# D skip connection
# [num_heads] -> [num_heads, head_dim]
D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
y = (y + hidden_states * D).to(y.dtype)
# [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
y = y.reshape(batch_size, -1)[:, None, ...]
else:
# begin ssd naive implementation without einsums
dt = nn.functional.softplus(dt + self.dt_bias)
dt = torch.clamp(dt, self.time_step_min)
hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
pad_size = self.chunk_size - (seq_len % self.chunk_size)
D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
# Discretize x and A
hidden_states = hidden_states * dt[..., None]
A = A.to(hidden_states.dtype) * dt
# Rearrange into blocks/chunks
hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
# [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
A = A.permute(0, 3, 1, 2)
A_cumsum = torch.cumsum(A, dim=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
# This is the analog of a causal mask
L = torch.exp(segment_sum(A))
# First, contraction of C and B to get G (attention-weights like)
G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n)
G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
# Step 2: Compute M, equivalent to applying attention mask to weights
M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
M = M_intermediate.sum(dim=-1)
# Step 3: Compute Y_diag (apply to values)
Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
# permute back B * decay states
states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
if cache_params is not None and cache_params.seqlen_offset > 0:
previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...]
else:
previous_states = torch.zeros_like(states[:, :1])
states = torch.cat([previous_states, states], dim=1)
decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
states_permuted = states.permute(0, 2, 1, 3, 4)
result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
new_states = result.permute(0, 2, 1, 3, 4)
states, ssm_state = new_states[:, :-1], new_states[:, -1]
# Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
# compute Yoff
C_times_states = (C[..., None, :] * states[:, :, None, ...])
state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
y = Y_diag + Y_off
# [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
y = y + D_residual
# Cutting off padded chunks
if pad_size > 0:
y = y[:, :seq_len, :, :]
y = y.reshape(batch_size, seq_len, -1)
if ssm_state is not None and cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
scan_output = self.norm(y, gate)
# end ssd naive
# 4. Final linear projection
contextualized_states = self.out_proj(scan_output) # [batch, seq_len, hidden_size]
return contextualized_states
class Mamba2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
class Mamba2Block(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mixer = Mamba2Mixer(config)
def forward( def forward(
self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None
) -> tuple[LongTensor, list[InferenceCache]]:
"""
Arguments
input_ids: (batch, seqlen) tokens from `EleutherAI/gpt-neox-20b` tokenizer
h: hidden states for inference step. If present the constant-time
(wrt sequence length) inference path will be taken, input_ids
should have shape (batch, 1) containing the next batch of prompt
token.
Return (logits, h)
logits: (batch, seqlen, vocab_size)
h: updated inference cache after processing `input_ids`
"""
seqlen = input_ids.shape[1]
if h is None:
h = [None for _ in range(self.args.n_layer)]
x = self.backbone.embedding(input_ids)
for i, layer in enumerate(self.backbone.layers):
y, h[i] = layer.mixer(layer.norm(x), h[i])
x = y + x
x = self.backbone.norm_f(x)
logits = self.lm_head(x)
return logits[:, :seqlen], cast(list[InferenceCache], h)
def generate(
self, self,
input_ids: LongTensor, hidden_states,
max_new_length: int = 20, cache_params: Optional[Mamba2Cache] = None,
temperature: float = 1.0, cache_position: Optional[torch.LongTensor] = None,
top_k: int = 50, ):
top_p: float = 1.0, x = self.mixer(
eos_token_id: int = 0, self.norm(hidden_states), cache_params=cache_params, cache_position=cache_position
) -> Iterable[tuple[int, list[InferenceCache]]]: )
prefix, tokens = input_ids[:-1], input_ids[-1:].unsqueeze(0) return x + hidden_states
# Process prompt
# The input sequence to forward (non-inference path) must have length multiple that of chunk_size. class Mamba2Model(nn.Module):
# We split out excess tokens so that n_chunked tokens can be processed by one forward call and def __init__(self, config):
# process the rest in multiple inference steps. super().__init__(config)
n_chunked = (prefix.shape[0] // self.args.chunk_size) * self.args.chunk_size self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
if n_chunked > 0: self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
_, h = self(prefix[:n_chunked].unsqueeze(0), None)
self.gradient_checkpointing = False
self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
cache_params: Optional[Mamba2Cache] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.embeddings(input_ids)
if use_cache:
if cache_params is None:
cache_params = Mamba2Cache(
self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
)
cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
else: else:
h = [ cache_params = None
InferenceCache.alloc(1, self.args, device=self.device)
for _ in range(self.args.n_layer)
]
for i in range(n_chunked, prefix.shape[0]):
_, h = self(prefix[i : i + 1].unsqueeze(0), h)
# Generate hidden_states = inputs_embeds
for _ in range(max_new_length): for mixer_block in self.layers:
with torch.no_grad(): hidden_states = mixer_block(
out, h = self(tokens, h) hidden_states,
logits = out[0, -1] cache_params=cache_params,
if temperature != 1.0: cache_position=cache_position,
logits = logits / temperature )
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, k=top_k)[0][-1] if use_cache:
logits[indices_to_remove] = -torch.inf cache_params.seqlen_offset += inputs_embeds.shape[1]
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True) return self.norm_f(hidden_states), cache_params if use_cache else None
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > 0.5
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
sorted_indices_to_remove[0] = False
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = -torch.inf
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
if next_token.item() == eos_token_id:
return
tokens = next_token.unsqueeze(0)
yield cast(int, next_token.item()), h
class Mamba2(nn.Module):
def __init__(self, args: Mamba2Config, device: Device = None):
super().__init__()
self.args = args
self.device = device
# Order: (z, x, B, C, dt) class Mamba2ForCausalLM(nn.Module):
d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads def __init__(self, config):
self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device) super().__init__(config)
self.backbone = Mamba2Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
conv_dim = args.d_inner + 2 * args.d_state def forward(
self.conv1d = nn.Conv1d( self,
in_channels=conv_dim, input_ids: Optional[torch.LongTensor] = None,
out_channels=conv_dim, cache_params: Optional[Mamba2Cache] = None,
kernel_size=args.d_conv, use_cache: Optional[bool] = None,
groups=conv_dim, cache_position: Optional[torch.Tensor] = None,
padding=args.d_conv - 1, ):
device=device, mamba2_outputs = self.backbone(
input_ids,
cache_params=cache_params,
use_cache=use_cache,
cache_position=cache_position,
) )
hidden_states = mamba2_outputs[0]
self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device)) logits = self.lm_head(hidden_states)
self.A_log = nn.Parameter(torch.empty(args.nheads, device=device)) return logits, mamba2_outputs.cache_params, mamba2_outputs.hidden_states
self.D = nn.Parameter(torch.empty(args.nheads, device=device))
self.norm = RMSNorm(args.d_inner, device=device)
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
def forward(self, u: Tensor, h: InferenceCache | None = None):
"""
Arguments
u: (batch, seqlen, d_model) input. seqlen should be a multiple of chunk_size.
h: hidden states for inference step. Initialized to 0s if not present.
Return (y, h)
y: (batch, seqlen, d_model) output
h: updated inference cache after processing `u`
"""
if h:
return self.step(u, h)
A = -torch.exp(self.A_log) # (nheads,)
zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads)
# Pad or truncate xBC seqlen to d_conv
conv_state = F.pad(
rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0)
)
xBC = silu(
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :]
        )  # (batch, seqlen, d_inner + 2 * d_state)
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)
y, ssm_state = ssd(
x * dt.unsqueeze(-1),
A * dt,
rearrange(B, "b l n -> b l 1 n"),
rearrange(C, "b l n -> b l 1 n"),
self.args.chunk_size,
device=self.device,
)
y = y + x * self.D.unsqueeze(-1)
y = rearrange(y, "b l h p -> b l (h p)")
y = self.norm(y, z)
y = self.out_proj(y)
h = InferenceCache(conv_state, ssm_state)
return y, h
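
    # Worked size example (illustrative values only, not taken from any real config):
    # with d_model=768, expand=2 (so d_inner=1536), d_state=128 and nheads=24, in_proj
    # emits d_in_proj = 2*1536 + 2*128 + 24 = 3352 features per position, which
    # forward() splits back into z (1536), xBC (1536 + 2*128 = 1792) and dt (24).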
def step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]:
"""Take a single inference step for the current input and hidden state
        Unlike attention-based models, RNN-based models (e.g. Mamba) do not need
        to look back at all past tokens to generate a new token. Instead, a
        hidden state (initialized to zeros) is updated for each input and
        passed to the next inference step. This means the total inference
        time is linear in the sequence length, rather than quadratic as in
        attention.
Arguments
u: (batch, 1, d_model)
h: initial/running hidden state
Return (y, h)
y: (batch, 1, d_model)
h: updated hidden state
"""
assert u.shape[1] == 1, "Only one token can be decoded per inference step"
zxbcdt = self.in_proj(u.squeeze(1)) # (batch, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
# Advance convolution input
h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1))
h.conv_state[:, :, -1] = xBC
# Convolution step
xBC = torch.sum(
h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
)
xBC += self.conv1d.bias
xBC = silu(xBC)
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
A = -torch.exp(self.A_log) # (nheads,)
# SSM step
dt = F.softplus(dt + self.dt_bias) # (batch, nheads)
dA = torch.exp(dt * A) # (batch, nheads)
x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)
dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)
y = y + rearrange(self.D, "h -> h 1") * x
y = rearrange(y, "b h p -> b (h p)")
y = self.norm(y, z)
y = self.out_proj(y)
return y.unsqueeze(1), h
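
# Hypothetical, self-contained sketch (not used by the model) of the per-token SSM
# update that Mamba2.step performs above: h_t = exp(dt * A) * h_{t-1} + dt * B_t x_t,
# followed by the readout y_t = h_t C_t + D * x_t. All sizes here are made up.
def _ssm_step_sketch():
    import torch
    import torch.nn.functional as F

    nheads, headdim, d_state = 4, 8, 16
    h = torch.zeros(1, nheads, headdim, d_state)   # running SSM state
    A = -torch.exp(torch.randn(nheads))            # per-head decay rate
    D = torch.randn(nheads)
    for _ in range(5):                             # five decode steps
        x = torch.randn(1, nheads, headdim)
        B = torch.randn(1, d_state)
        C = torch.randn(1, d_state)
        dt = F.softplus(torch.randn(1, nheads))
        dA = torch.exp(dt * A)                                    # (1, nheads)
        h = h * dA[..., None, None] + torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
        y = torch.einsum("bhpn,bn->bhp", h, C) + D[:, None] * x   # (1, nheads, headdim)
    return y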
def segsum(x: Tensor, device: Device = None) -> Tensor:
"""Stable segment sum calculation.
`exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.
Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32
"""
T = x.size(-1)
x = repeat(x, "... d -> ... d e", e=T)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
x = x.masked_fill(~mask, 0)
x_segsum = torch.cumsum(x, dim=-2)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
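
# Illustrative-only check (hypothetical helper, not used by the model): for a 1-D
# input a, exp(segsum(a))[i, j] should equal exp(a[j+1] + ... + a[i]) for j <= i
# and 0 for j > i, i.e. the 1-semiseparable decay matrix L used by ssd() below.
def _segsum_naive_check(T: int = 4) -> bool:
    import torch
    a = torch.randn(T)
    L = torch.exp(segsum(a))
    ref = torch.zeros(T, T)
    for i in range(T):
        for j in range(T):
            if j <= i:
                ref[i, j] = torch.exp(a[j + 1 : i + 1].sum())
    return torch.allclose(L, ref, atol=1e-5)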
def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
"""Structed State Space Duality (SSD) - the core of Mamba-2
This is almost the exact same minimal SSD code from the blog post.
Arguments
x: (batch, seqlen, n_heads, d_head)
A: (batch, seqlen, n_heads)
B: (batch, seqlen, n_heads, d_state)
C: (batch, seqlen, n_heads, d_state)
Return
y: (batch, seqlen, n_heads, d_head)
Source
1. https://tridao.me/blog/2024/mamba2-part3-algorithm/
2. https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78
"""
assert x.shape[1] % chunk_size == 0
# Rearrange into chunks
# Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel)
# This is not implemented and left as an exercise for the reader 😜
x, A, B, C = [
rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)
]
A = rearrange(A, "b c l h -> b h c l")
A_cumsum = torch.cumsum(A, dim=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
L = torch.exp(segsum(A, device=device))
Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))
new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
return Y, final_state
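
# Hypothetical sanity sketch (not called anywhere in the model): the chunked SSD
# computation should be independent of the chunk size, so running ssd() with
# chunk_size 4 and 8 on the same random inputs should agree up to numerical error.
def _ssd_chunk_invariance_sketch() -> bool:
    import torch
    b, l, h, p, n = 1, 16, 2, 4, 8          # batch, seqlen, heads, headdim, d_state
    x = torch.randn(b, l, h, p)
    A = -torch.rand(b, l, h)                 # negative, so states decay
    B = torch.randn(b, l, h, n)
    C = torch.randn(b, l, h, n)
    y4, _ = ssd(x, A, B, C, chunk_size=4)
    y8, _ = ssd(x, A, B, C, chunk_size=8)
    return torch.allclose(y4, y8, atol=1e-4)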
class RMSNorm(nn.Module):
def __init__(self, d: int, eps: float = 1e-5, device: Device = None):
"""Gated Root Mean Square Layer Normalization
Paper: https://arxiv.org/abs/1910.07467
"""
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d, device=device))
def forward(self, x, z=None):
if z is not None:
x = x * silu(z)
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
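
# Hedged note, not part of the model: with a gate z, this layer normalizes the gated
# activation u = x * silu(z), i.e. y = u * rsqrt(mean(u**2, -1) + eps) * weight, which
# matches the "norm before out_proj" placement in Mamba2.forward and Mamba2.step above.
def _gated_rmsnorm_sketch() -> bool:
    import torch
    d = 8
    norm = RMSNorm(d)
    x, z = torch.randn(2, d), torch.randn(2, d)
    u = x * silu(z)
    ref = u * torch.rsqrt(u.pow(2).mean(-1, keepdim=True) + norm.eps) * norm.weight
    return torch.allclose(norm(x, z), ref, atol=1e-6)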
def silu(x):
"""Applies the Sigmoid Linear Unit (SiLU), element-wise.
Define this manually since torch's version doesn't seem to work on MPS.
"""
return x * F.sigmoid(x)

View File

@@ -32,259 +32,272 @@ class ModelArgs(BaseModelArgs):
    rms_norm: bool
    chunk_size: int
    tie_word_embeddings: bool
    intermediate_size: int = None
    time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
    time_step_rank: Union[int, str] = "auto"
    model_type: str = "mamba2"

    def __post_init__(self):
        self.intermediate_size = int(self.expand * self.hidden_size)  # E*D = ED
        if not hasattr(self, "head_dim"):
            self.head_dim = self.hidden_size // self.num_heads
        if self.time_step_rank == "auto":
            self.time_step_rank = math.ceil(self.hidden_size / 16)


class MambaRMSNormGated(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = mx.ones(hidden_size)
        self.variance_epsilon = eps

    def __call__(self, hidden_states, gate=None):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.astype(mx.float32)
        if gate is not None:
            hidden_states = hidden_states * nn.silu(gate.astype(mx.float32))
        variance = mx.mean(mx.square(hidden_states), axis=-1, keepdims=True)
        hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.astype(input_dtype)


def selective_scan(x, A, B, C, chunk_size):
    """
    Selective scan implementation for training.

    Arguments
        x: (batch, seqlen, n_heads, d_head)
        A: (batch, seqlen, n_heads)
        B: (batch, seqlen, n_heads, d_state)
        C: (batch, seqlen, n_heads, d_state)
    Return
        y: (batch, seqlen, n_heads, d_head)
    """
    assert x.shape[1] % chunk_size == 0
# Reshape into chunks
def chunk_reshape(m):
shape = list(m.shape)
shape[1:2] = [shape[1] // chunk_size, chunk_size]
return m.reshape(shape)
x, A, B, C = map(chunk_reshape, (x, A, B, C))
A = mx.transpose(A, [0, 3, 1, 2])
# Compute cumulative sums
A_cumsum = mx.cumsum(A, axis=-1)
# Process chunks
L = mx.exp(selective_cumsum(A))
Y_diag = mx.einsum('bclhn,bcshn,bhcls,bcshp->bclhp', C, B, L, x)
decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum)
states = mx.einsum('bclhn,bhcl,bclhp->bchpn', B, decay_states, x)
initial_states = mx.zeros_like(states[:, :1])
states = mx.concatenate([initial_states, states], axis=1)
decay_chunk = mx.exp(selective_cumsum(mx.pad(A_cumsum[..., -1], ((0,0), (0,0), (1,0)))))
new_states = mx.einsum('bhzc,bchpn->bzhpn', decay_chunk, states)
states = new_states[:, :-1]
state_decay_out = mx.exp(A_cumsum)
Y_off = mx.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
Y = (Y_diag + Y_off).reshape((-1, x.shape[1] * chunk_size, *Y_diag.shape[-2:]))
return Y
def selective_cumsum(x: mx.array) -> mx.array:
"""Stable selective cumulative sum calculation."""
T = x.shape[-1]
x = mx.repeat(x[..., None], T, axis=-1)
mask = mx.tril(mx.ones((T, T)), k=-1)
x = x * mask
x_cumsum = mx.cumsum(x, axis=-2)
mask = mx.tril(mx.ones((T, T)), k=0)
return mx.where(mask, x_cumsum, float('-inf'))
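
# Hypothetical shape sketch (not used by the model): selective_cumsum turns a
# per-step decay vector of length T into a (T, T) lower-triangular matrix whose
# exp() gives the pairwise decay factors, mirroring segsum in the PyTorch reference;
# everything strictly above the diagonal decays to exp(-inf) == 0.
def _selective_cumsum_sketch() -> bool:
    a = mx.random.normal((4,))
    L = mx.exp(selective_cumsum(a))                      # (4, 4) decay matrix
    strictly_upper = L * (1 - mx.tril(mx.ones((4, 4))))  # should be all zeros
    return strictly_upper.sum().item() == 0.0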
class Mamba2Block(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        # Internal cache state
        self.conv_state = None
        self.ssm_state = None

        # Project input to get various components
        d_in_proj = (2 * args.intermediate_size + 2 * self.args.n_groups * args.state_size + args.num_heads)
        self.in_proj = nn.Linear(
            args.hidden_size,
            d_in_proj,
            bias=args.use_bias
        )

        conv_dim = args.intermediate_size + 2 * self.args.n_groups * args.state_size
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            kernel_size=args.conv_kernel,
            groups=conv_dim,
            padding=args.conv_kernel - 1,
            bias=args.use_conv_bias
        )

        # SSM parameters
        dt_init_floor = math.log(args.time_step_floor)
        self.dt_bias = mx.zeros((args.num_heads,)) * args.initializer_range
        self.A_log = mx.zeros((args.num_heads,)) * args.initializer_range
        self.D = mx.zeros((args.num_heads,)) * args.initializer_range

        # Output projections
        self.norm = nn.RMSNorm(args.intermediate_size, eps=args.layer_norm_epsilon)
        self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)

    def __call__(self, x: mx.array, cache=None) -> mx.array:
        return self.forward_training(x) if x.shape[1] > 1 else self.forward_inference(x, cache)

    def forward_training(self, u: mx.array) -> mx.array:
        # Reset cache during training
        self.cache = None

        # Input projection and splitting
        zxbcdt = self.in_proj(u)
        z, xBC, dt = mx.split(
            zxbcdt,
            [
                self.args.intermediate_size,
                self.args.intermediate_size + 2 * self.args.state_size
            ],
            axis=-1
        )

        # Time step processing
        dt = mx.clip(
            nn.softplus(dt + self.dt_bias),
            self.args.time_step_min,
            self.args.time_step_max
        )

        # Convolution processing
        xBC_t = mx.transpose(xBC, [0, 2, 1])
        conv_out = self.conv1d(xBC_t)
        xBC = mx.transpose(conv_out, [0, 2, 1])[:, :u.shape[1]]
        xBC = mx.sigmoid(xBC) * xBC  # SiLU

        # Split states
        x, B, C = mx.split(
            xBC,
            [self.args.intermediate_size, self.args.state_size],
            axis=-1
        )

        # Reshape for selective scan
        x = x.reshape((-1, x.shape[1], self.args.num_heads, self.args.head_dim))
        A = -mx.exp(self.A_log)

        # Apply selective scan
        y = selective_scan(
            x * dt[..., None],
            A * dt,
            B[..., None, :],
            C[..., None, :],
            self.args.chunk_size
        )

        # Output processing
        y = y + x * self.D[None, None, :, None]
        y = y.reshape((-1, y.shape[1], self.args.intermediate_size))
        y = self.norm(y, z)
        y = self.out_proj(y)

        return y

    def forward_inference(self, u: mx.array, cache=None) -> mx.array:
        """Single token processing during inference."""
        assert u.shape[1] == 1, "Inference mode expects single token"

        batch_size = u.shape[0]
        # Use provided cache or create new one
        self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None)

        # Project input
        zxbcdt = self.in_proj(mx.squeeze(u, 1))
        parts = mx.split(
            zxbcdt,
            [
                self.args.intermediate_size,
                self.args.intermediate_size + 2 * self.args.state_size
            ],
            axis=-1
        )
        z, xBC = parts[0], parts[1]
        dt = zxbcdt[:, -self.args.num_heads:]  # Extract dt separately

        # Update convolution state and apply
        conv_state = self.cache.update_conv_state(xBC)
        xBC = mx.sum(
            conv_state * mx.transpose(self.conv1d.weight, [1, 0, 2]),
            axis=-1
        )
        if self.args.use_conv_bias:
            xBC = xBC + self.conv1d.bias
        xBC = mx.sigmoid(xBC) * xBC  # SiLU

        # Split states and ensure proper shapes
        x_splits = mx.split(
            xBC,
            [self.args.intermediate_size, self.args.state_size],
            axis=-1
        )
        x, B, C = x_splits[0], x_splits[1], x_splits[2]

        # Process time steps - ensure proper broadcasting
        dt = mx.reshape(dt, (batch_size, self.args.num_heads))
        dt = mx.clip(
            nn.softplus(dt + self.dt_bias[None, :]),
            self.args.time_step_min,
            self.args.time_step_max
        )

        # SSM step with explicit shapes
        A = -mx.exp(self.A_log)
        dA = mx.exp(dt * A[None, :])  # Shape: (batch_size, num_heads)

        # Reshape x considering intermediate size
        # x shape should be (batch_size * num_heads, head_dim)
        x = mx.reshape(x, (batch_size, self.args.num_heads, -1))
        assert x.shape[-1] == self.args.head_dim, f"Head dimension mismatch: {x.shape[-1]} vs {self.args.head_dim}"

        # Reshape B and C for ssm computation
        B = mx.reshape(B, (batch_size, -1))  # Should be (batch_size, state_size)
        C = mx.reshape(C, (batch_size, -1))  # Should be (batch_size, state_size)

        # Compute dBx with explicit shapes
        dBx = mx.einsum('bh,bs,bhd->bhds', dt, B, x)

        ssm_state = self.cache.update_ssm_state(dA, dBx)
        y = mx.einsum('bhds,bs->bhd', ssm_state, C)
        y = y + x * self.D[None, :, None]
        y = mx.reshape(y, (batch_size, self.args.intermediate_size))

        # Output processing
        y = self.norm(y, z)
        y = self.out_proj(y)

        return mx.expand_dims(y, 1)


class Mamba2Mixer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        # Model dimensions
        self.hidden_size = args.hidden_size
        self.num_heads = args.num_heads
        self.head_dim = args.head_dim
        self.ssm_state_size = args.state_size
        self.n_groups = args.n_groups
        self.intermediate_size = int(args.expand * args.hidden_size)

        # Convolution parameters
        self.conv_kernel = args.conv_kernel
        self.use_conv_bias = args.use_conv_bias

        # Time step parameters
        self.time_step_rank = int(args.time_step_rank)
        self.time_step_min = args.time_step_min
        self.time_step_max = args.time_step_max

        # Processing parameters
        self.chunk_size = args.chunk_size
        self.layer_norm_epsilon = args.layer_norm_epsilon

        # Calculate dimensions
        self.conv_dim = (self.intermediate_size +
                         2 * self.n_groups * self.ssm_state_size)
        projection_size = (self.intermediate_size +
                           self.conv_dim +
                           self.num_heads)

        # Initialize layers
        self.in_proj = nn.Linear(
            self.hidden_size,
            projection_size,
            bias=args.use_bias
        )

        # Convolution layer
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            kernel_size=self.conv_kernel,
            groups=self.conv_dim,
            padding=self.conv_kernel - 1,
            bias=self.use_conv_bias
        )

        # Initialize parameters
        self.dt_bias = mx.ones(self.num_heads)
        A = mx.arange(1, self.num_heads + 1)
        self.A_log = mx.log(A)
        self.D = mx.ones(self.num_heads)

        # Output layers
        self.norm = MambaRMSNormGated(
            self.intermediate_size,
            eps=self.layer_norm_epsilon
        )
        self.out_proj = nn.Linear(
            self.intermediate_size,
            self.hidden_size,
            bias=args.use_bias
        )

    def reshape_into_chunks(self, tensor, pad_size, chunk_size):
        if pad_size > 0:
            pad_shape = list(tensor.shape)
            pad_shape[1] = pad_size
            padding = mx.zeros(pad_shape, dtype=tensor.dtype)
            tensor = mx.concatenate([tensor, padding], axis=1)
        chunk_shape = list(tensor.shape)
        chunk_shape[1] = -1
        chunk_shape.insert(2, chunk_size)
        return tensor.reshape(chunk_shape)

    def segment_sum(self, x):
        return mx.cumsum(x, axis=-1)

    def process_single_token(self, hidden_states, B, C, dt, cache):
        batch_size = hidden_states.shape[0]

        # Process convolution state
        if cache is not None:
            conv_state = cache.conv_states
            # Roll the conv state and update the last position
            conv_state = mx.roll(conv_state, shift=-1, axis=-1)
            # Create new conv state with updated last position
            new_conv_state = mx.array(conv_state)
            new_conv_state = new_conv_state.at[:, :, -1].add(hidden_states)
            conv_state = new_conv_state

            # Compute convolution
            conv_out = mx.sum(conv_state * self.conv1d.weight[:, 0, :], axis=-1)
            if self.use_conv_bias:
                conv_out = conv_out + self.conv1d.bias
            # Apply SiLU activation
            conv_out = mx.sigmoid(conv_out) * conv_out
        else:
            # Initialize new cache
            conv_state = mx.zeros((batch_size, self.conv_dim, self.conv_kernel - 1))
            conv_out = self.conv1d(hidden_states)
            conv_out = mx.sigmoid(conv_out) * conv_out

        # Process SSM
        dt = mx.clip(
            nn.softplus(dt + self.dt_bias),
            self.time_step_min,
            self.time_step_max
        )
        A = -mx.exp(self.A_log)
        dA = mx.exp(dt * A[None, :])

        if cache is not None:
            ssm_state = cache.ssm_states
        else:
            ssm_state = mx.zeros(
                (batch_size, self.num_heads, self.head_dim, self.ssm_state_size)
            )

        # Compute SSM updates
        dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, hidden_states)
        next_state = ssm_state * dA[:, :, None, None] + dBx
        y = mx.einsum('bhds,bhs->bhd', next_state, C)

        # Add skip connection
        y = y + hidden_states * self.D[None, :, None]

        return y, conv_state, next_state

    def process_long_sequence(self, hidden_states, B, C, dt, ssm_state):
        batch_size, seq_len = hidden_states.shape[:2]
        pad_size = self.chunk_size - (seq_len % self.chunk_size)

        # Reshape into chunks
        x_chunks = self.reshape_into_chunks(hidden_states, pad_size, self.chunk_size)
        B_chunks = self.reshape_into_chunks(B, pad_size, self.chunk_size)
        C_chunks = self.reshape_into_chunks(C, pad_size, self.chunk_size)

        # Process time steps
        dt = nn.softplus(dt + self.dt_bias)
        dt = mx.clip(dt, self.time_step_min, self.time_step_max)

        # Prepare matrices
        A = -mx.exp(self.A_log)
        A = A * dt[:, None]

        # Process chunks
        A_chunks = self.reshape_into_chunks(
            mx.broadcast_to(A, (batch_size, seq_len + pad_size, self.num_heads)),
            pad_size,
            self.chunk_size
        )

        # Compute cumulative sums
        A_cumsum = mx.cumsum(A_chunks, axis=-1)
        L = mx.exp(self.segment_sum(A_chunks))

        # Process diagonal blocks
        G = mx.einsum('...lhn,...shn->...lsh', C_chunks, B_chunks)
        M = G * L[..., None, :]
        Y_diag = mx.einsum('...lsh,...sh->...lh', M, x_chunks)

        # Process off-diagonal blocks
        decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum)
        B_decay = B_chunks * decay_states[..., None]
        states = mx.einsum('...shn,...sh->...hn', B_decay, x_chunks)

        # Combine results
        y = Y_diag + states

        # Remove padding if necessary
        if pad_size > 0:
            y = y[:, :seq_len]

        return y, ssm_state

    def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array:
        batch_size, seq_len, _ = x.shape

        # Project input
        projected_states = self.in_proj(x.squeeze(1))

        # Calculate d_mlp based on projection size
        d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 *
                 self.n_groups * self.ssm_state_size - self.num_heads) // 2

        # Split projections with corrected dimensions
        splits = [
            d_mlp,                   # z0
            d_mlp,                   # x0
            self.intermediate_size,  # gate
            self.conv_dim,           # hidden_states
            self.num_heads           # dt
        ]
        z0, x0, x1, gate, hidden_states, dt = projected_states.split(splits, axis=-1)

        # Split hidden states into components
        x_conv, BC = mx.split(hidden_states, [self.intermediate_size], axis=-1)
        B, C = mx.split(BC, [self.n_groups * self.ssm_state_size], axis=-1)

        # Process based on sequence length
        if seq_len > 1 and cache is None:
            y, next_state = self.process_long_sequence(
                x_conv, B, C, dt,
                mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size))
            )
        else:
            # Reshape for single token processing
            x_conv = x_conv.reshape(batch_size, -1, self.head_dim)
            B = B.reshape(batch_size, self.num_heads, -1)
            C = C.reshape(batch_size, self.num_heads, -1)
            y, conv_state, next_state = self.process_single_token(x_conv, B, C, dt, cache)

        if cache is not None:
            cache.update(conv_state, next_state)

        # Apply normalization and final projection
        y = self.norm(y) * gate
        return self.out_proj(y)
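
# Worked size example (assumed, illustrative config only): with hidden_size=768,
# expand=2, n_groups=1, state_size=128 and num_heads=24, the mixer's in_proj emits
# intermediate_size + conv_dim + num_heads features per token:
#   intermediate_size = 2 * 768            = 1536
#   conv_dim          = 1536 + 2 * 1 * 128 = 1792
#   projection_size   = 1536 + 1792 + 24   = 3352
# which __call__ then splits back into the gate, the conv/SSM stream (x, B, C) and dt.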
class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.mixer = Mamba2Mixer(args)
        self.norm = nn.RMSNorm(args.hidden_size)

    def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array:
        return self.mixer(self.norm(x), cache) + x


class Mamba2Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

@@ -295,19 +308,20 @@ class Mamba2Model(nn.Module):
    def __call__(self, x: mx.array, cache=None) -> mx.array:
        x = self.embeddings(x)
        if cache is None:
            cache = [None] * len(self.layers)
        for layer, layer_cache in zip(self.layers, cache):
            x = layer(x, layer_cache)
        return self.norm_f(x)
class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.backbone = Mamba2Model(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

@@ -324,17 +338,24 @@ class Model(nn.Module):
        return logits

    def make_cache(self, batch_size=1):
        return [
            Mamba2Cache(
                batch_size=batch_size,
                conv_dim=self.args.intermediate_size + 2 * self.args.n_groups * self.args.state_size,
                kernel_size=self.args.conv_kernel,
                num_heads=self.args.num_heads,
                head_dim=self.args.head_dim,
                state_size=self.args.state_size,
            )
            for _ in range(len(self.backbone.layers))
        ]

    def sanitize(self, weights):
        for k, v in weights.items():
            if "conv1d.weight" in k and v.ndim == 3:
                weights[k] = v.moveaxis(2, 1)
        return weights
@property
def layers(self):
return self.backbone.layers
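
# Layout note with a hedged sketch: PyTorch checkpoints store depthwise Conv1d
# weights as (out_channels, in_channels // groups, kernel_size), while MLX's
# nn.Conv1d expects (out_channels, kernel_size, in_channels // groups) — assuming
# current MLX conventions — which is why sanitize() above moves axis 2 to axis 1.
def _conv_weight_layout_sketch():
    torch_style = mx.zeros((1792, 1, 4))            # (C_out, C_in // groups, K) from a checkpoint
    mlx_style = mx.moveaxis(torch_style, 2, 1)      # (C_out, K, C_in // groups) for mx Conv1d
    return torch_style.shape, mlx_style.shape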