save push

Goekdeniz-Guelmez 2024-11-06 16:35:46 +01:00
parent 58b448dc0b
commit 906f972d36
10 changed files with 2777 additions and 1012 deletions

View File

@ -341,13 +341,28 @@ class MambaCache(_BaseCache):
self.cache = v
class Mamba2Cache:
def __init__(self, batch_size, conv_dim, kernel_size, num_heads, head_dim, state_size):
self.conv_states = mx.zeros((batch_size, conv_dim, kernel_size - 1))
self.ssm_states = mx.zeros((batch_size, num_heads, head_dim, state_size))
self.seqlen_offset = 0
def update(self, new_conv_state, new_ssm_state):
self.conv_states = new_conv_state
self.ssm_states = new_ssm_state
self.seqlen_offset += 1
class Mamba2Cache(_BaseCache):
    def __init__(self, batch_size, conv_kernel):
        self.conv_kernel: int = conv_kernel
        self.conv_states: list = [None]
        self.ssm_states: list = [None]
        self.seqlen_offset = 0

    def reset(self):
        self.conv_states = [None]
        self.ssm_states = [None]
        self.seqlen_offset = 0

    def update(self, layer_idx: int, new_conv_state: mx.array, cache_position: mx.array) -> mx.array:
        # mx.array has no .clamp/.roll/.zero_ methods; use the functional MLX ops
        # and rebind the per-layer state (mirrors the torch reference semantics).
        conv_state = self.conv_states[layer_idx]
        cache_position = mx.clip(cache_position, 0, self.conv_kernel - 1)
        conv_state = mx.roll(conv_state, shift=-1, axis=-1)
        conv_state[:, :, cache_position] = new_conv_state
        self.conv_states[layer_idx] = conv_state
        return self.conv_states[layer_idx]
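For single-token decoding the cache only ever drops the oldest timestep and appends the newest one, so the roll-and-write pattern above can also be expressed with plain concatenation. A minimal sketch, not part of the commit; shift_conv_window and all sizes are illustrative, assuming a (batch, conv_dim, kernel_size) window:

import mlx.core as mx

def shift_conv_window(conv_state: mx.array, new_column: mx.array) -> mx.array:
    # Drop the oldest timestep and append the newest along the last axis,
    # keeping a fixed (batch, conv_dim, kernel_size) window.
    return mx.concatenate([conv_state[:, :, 1:], new_column[:, :, None]], axis=-1)

batch, conv_dim, kernel = 2, 8, 4
state = mx.zeros((batch, conv_dim, kernel))
for step in range(6):
    x_t = mx.random.normal((batch, conv_dim))  # features of the newest token
    state = shift_conv_window(state, x_t)
print(state.shape)  # (2, 8, 4)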

View File

@ -1,424 +0,0 @@
import math
from dataclasses import dataclass, field
from typing import Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import Mamba2Cache
@dataclass
class ModelArgs(BaseModelArgs):
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
use_cache: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
intermediate_size: int = None
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
self.intermediate_size = int(self.expand * self.hidden_size) # E*D = ED
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
def selective_scan(x, A, B, C, chunk_size):
"""
Selective scan implementation for training.
Arguments
x: (batch, seqlen, n_heads, d_head)
A: (batch, seqlen, n_heads)
B: (batch, seqlen, n_heads, d_state)
C: (batch, seqlen, n_heads, d_state)
Return
y: (batch, seqlen, n_heads, d_head)
"""
assert x.shape[1] % chunk_size == 0
# Reshape into chunks
def chunk_reshape(m):
shape = list(m.shape)
shape[1:2] = [shape[1] // chunk_size, chunk_size]
return m.reshape(shape)
x, A, B, C = map(chunk_reshape, (x, A, B, C))
A = mx.transpose(A, [0, 3, 1, 2])
# Compute cumulative sums
A_cumsum = mx.cumsum(A, axis=-1)
# Process chunks
L = mx.exp(selective_cumsum(A))
Y_diag = mx.einsum('bclhn,bcshn,bhcls,bcshp->bclhp', C, B, L, x)
decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum)
states = mx.einsum('bclhn,bhcl,bclhp->bchpn', B, decay_states, x)
initial_states = mx.zeros_like(states[:, :1])
states = mx.concatenate([initial_states, states], axis=1)
decay_chunk = mx.exp(selective_cumsum(mx.pad(A_cumsum[..., -1], ((0,0), (0,0), (1,0)))))
new_states = mx.einsum('bhzc,bchpn->bzhpn', decay_chunk, states)
states = new_states[:, :-1]
state_decay_out = mx.exp(A_cumsum)
Y_off = mx.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
Y = (Y_diag + Y_off).reshape((-1, x.shape[1] * chunk_size, *Y_diag.shape[-2:]))
return Y
def selective_cumsum(x: mx.array) -> mx.array:
"""Stable selective cumulative sum calculation."""
T = x.shape[-1]
x = mx.repeat(x[..., None], T, axis=-1)
mask = mx.tril(mx.ones((T, T)), k=-1)
x = x * mask
x_cumsum = mx.cumsum(x, axis=-2)
mask = mx.tril(mx.ones((T, T)), k=0)
return mx.where(mask, x_cumsum, float('-inf'))
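The docstring above fixes the expected shapes of selective_scan; a quick shape check, assuming the two functions defined above, a sequence length that is a multiple of chunk_size, and random values chosen purely for illustration:

import mlx.core as mx

batch, seqlen, n_heads, d_head, d_state, chunk = 2, 8, 4, 16, 32, 4
x = mx.random.normal((batch, seqlen, n_heads, d_head))
A = mx.random.normal((batch, seqlen, n_heads))
B = mx.random.normal((batch, seqlen, n_heads, d_state))
C = mx.random.normal((batch, seqlen, n_heads, d_state))
y = selective_scan(x, A, B, C, chunk_size=chunk)
print(y.shape)  # expected: (batch, seqlen, n_heads, d_head)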
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
# Project input to get various components [z, x, B, C, dt]
projection_size = (2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads)
self.in_proj = nn.Linear(
args.hidden_size,
projection_size,
bias=args.use_bias
)
# Convolution layer
conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
kernel_size=args.conv_kernel,
groups=conv_dim,
padding=args.conv_kernel - 1,
bias=args.use_conv_bias
)
# SSM parameters
self.dt_bias = mx.zeros(args.num_heads)
self.A_log = mx.zeros(args.num_heads)
self.D = mx.ones(args.num_heads)
# Output projections
self.norm = nn.RMSNorm(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
def __call__(self, u: mx.array, cache=None) -> mx.array:
# return self.forward_training(x) if x.shape[1] > 1 else self.forward_inference(x, cache)
# def forward_training(self, u: mx.array) -> mx.array:
# # Reset cache during training
# self.cache = None
# # Input projection and splitting
# zxbcdt = self.in_proj(u)
# z, xBC, dt = mx.split(
# zxbcdt,
# [
# self.args.hidden_size,
# self.args.hidden_size + 2 * self.args.state_size
# ],
# axis=-1
# )
# # Time step processing
# dt = mx.clip(
# nn.softplus(dt + self.dt_bias),
# self.args.time_step_min,
# self.args.time_step_max
# )
# # Convolution processing
# xBC_t = mx.transpose(xBC, [0, 2, 1])
# conv_out = self.conv1d(xBC_t)
# xBC = mx.transpose(conv_out, [0, 2, 1])[:, :u.shape[1]]
# xBC = mx.sigmoid(xBC) * xBC # SiLU
# # Split states
# x, B, C = mx.split(
# xBC,
# [self.args.hidden_size, self.args.state_size],
# axis=-1
# )
# # Reshape for selective scan
# x = x.reshape((-1, x.shape[1], self.args.num_heads, self.args.head_dim))
# A = -mx.exp(self.A_log)
# # Apply selective scan
# y = selective_scan(
# x * dt[..., None],
# A * dt,
# B[..., None, :],
# C[..., None, :],
# self.args.chunk_size
# )
# # Output processing
# y = y + x * self.D[None, None, :, None]
# y = y.reshape((-1, y.shape[1], self.args.hidden_size))
# y = self.norm(y, z)
# y = self.out_proj(y)
# return y
# def forward_inference(self, u: mx.array, cache=None) -> mx.array:
# """
# u: (B, 1, D)
# cache: (h_cache, conv_cache)
# """
# """Single token processing during inference."""
# assert u.shape[1] == 1, "Inference mode expects single token"
# batch_size = u.shape[0]
# # Use provided cache or create new one
# self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None)
# # Project input
# zxbcdt = self.in_proj(u.squeeze(1)) # (B, 2D)
# d_mlp = (zxbcdt.shape[-1] - 2 * self.args.hidden_size - 2 * self.args.n_groups * self.args.state_size - self.args.num_heads) // 2
# # (1, 768) (1, 0) (1, 0) (1, 256) (1, 0) (1, 3328)
# y0, z0, x0, z, xBC, dt = mx.split(
# zxbcdt,
# [
# d_mlp,
# d_mlp,
# self.args.hidden_size,
# self.args.hidden_size + 2 * self.args.n_groups * self.args.state_size,
# self.args.num_heads
# ],
# axis=-1
# )
# # Update convolution state and apply
# conv_state = self.cache.update_conv_state(xBC)
# xBC = mx.sum(conv_state[:, :, -1] * mx.transpose(self.conv1d.weight, [1, 0, 2]), axis=-1) # (B, D) (4, 1792)
# if self.args.use_conv_bias:
# xBC = xBC + self.conv1d.bias
# xBC = mx.sigmoid(xBC) * xBC # SiLU (4, 1792)
# # Split states and ensure proper shapes
# a0, x, B, C = mx.split(
# xBC, # (4, 1792)
# [
# self.args.hidden_size,
# self.args.n_groups * self.args.state_size,
# self.args.n_groups * self.args.state_size
# ],
# axis=-1
# )
# # SSM step with explicit shapes
# A = -mx.exp(self.A_log) # (num_heads) (24,)
# print(A.shape) # (24,)
# print(dt.shape) # (1, 3328)
# dA = mx.exp(dt * A[None, :]) # Shape: (batch_size, num_heads) <------- here is the error
# # Reshape x considering intermediate size
# # x shape should be (batch_size * num_heads, head_dim)
# x = mx.reshape(x, (batch_size, self.args.num_heads, -1))
# assert x.shape[-1] == self.args.head_dim, f"Head dimension mismatch: {x.shape[-1]} vs {self.args.head_dim}"
# B = mx.reshape(B, (batch_size, -1)) # Should be (batch_size, state_size)
# C = mx.reshape(C, (batch_size, -1)) # Should be (batch_size, state_size)
# # Compute dBx with explicit shapes
# dBx = mx.einsum('bh,bs,bhd->bhds', dt, B, x)
# ssm_state = self.cache.update_ssm_state(dA, dBx)
# y = mx.einsum('bhds,bs->bhd', ssm_state, C)
# y = y + x * self.D[None, :, None]
# y = mx.reshape(y, (batch_size, self.args.hidden_size))
# # Output processing
# y = self.norm(y, z)
# if d_mlp > 0:
# y = mx.cat([nn.silu(z0) * x0, y], axis=-1)
# y = self.out_proj(y)
# return mx.expand_dims(y, 1)
assert u.shape[1] == 1, "Inference mode expects single token"
batch_size = u.shape[0]
# Use provided cache or create new one
self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None)
# Project input
zxbcdt = self.in_proj(u.squeeze(1)) # (B, projection_size)
# Calculate splits based on model dimensions
d_mlp = self.args.intermediate_size
d_state = self.args.state_size * self.args.n_groups
# Split the projection into its components
splits = [
d_mlp, # y0
d_mlp, # z0
self.args.hidden_size, # x0
self.args.hidden_size, # z
d_state * 2, # xBC (includes both B and C)
self.args.num_heads # dt
]
y0, z0, x0, z, xBC, dt = mx.split(zxbcdt, splits[:-1], axis=-1)
# Update convolution state and apply
conv_state = self.cache.update_conv_state(xBC)
xBC = mx.sum(conv_state[:, :, -1] * mx.transpose(self.conv1d.weight, [1, 0, 2]), axis=-1)
if self.args.use_conv_bias:
xBC = xBC + self.conv1d.bias
xBC = mx.sigmoid(xBC) * xBC # SiLU
# Split states and reshape
x, BC = mx.split(xBC, [self.args.intermediate_size], axis=-1)
B, C = mx.split(BC, [d_state], axis=-1)
# Reshape for SSM computation
x = mx.reshape(x, (batch_size, self.args.num_heads, -1)) # (B, H, head_dim)
B = mx.reshape(B, (batch_size, self.args.num_heads, -1)) # (B, H, state_per_head)
C = mx.reshape(C, (batch_size, self.args.num_heads, -1)) # (B, H, state_per_head)
# Process dt to match expected shape
dt = mx.reshape(dt, (batch_size, self.args.num_heads)) # (B, H)
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.args.time_step_min,
self.args.time_step_max
)
# SSM step
A = -mx.exp(self.A_log) # (H,)
dA = mx.exp(dt * A[None, :]) # (B, H)
# Compute dBx
dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, x)
# Update SSM state and compute output
ssm_state = self.cache.update_ssm_state(dA, dBx)
y = mx.einsum('bhds,bhs->bhd', ssm_state, C)
y = y + x * self.D[None, :, None]
# Reshape output
y = mx.reshape(y, (batch_size, self.args.hidden_size))
# Final output processing
y = self.norm(y, z)
if d_mlp > 0:
y = mx.concat([nn.silu(z0) * x0, y], axis=-1)
y = self.out_proj(y)
return mx.expand_dims(y, 1) # (B, 1, D)
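The decode step above boils down to the standard discretized SSM recurrence, h <- exp(dt*A)*h + dt*B*x and y = C.h + D*x, applied per head. A standalone sketch of that recurrence with made-up names and sizes (illustrative only, not taken from the commit):

import mlx.core as mx

batch, heads, head_dim, state = 2, 4, 8, 16
A = -mx.exp(mx.zeros((heads,)))                  # (H,), negative eigenvalues
dt = mx.ones((batch, heads)) * 0.1               # (B, H), per-head step size
x = mx.random.normal((batch, heads, head_dim))   # (B, H, head_dim)
Bmat = mx.random.normal((batch, heads, state))   # (B, H, state)
Cmat = mx.random.normal((batch, heads, state))   # (B, H, state)
D = mx.ones((heads,))
h = mx.zeros((batch, heads, head_dim, state))    # SSM state carried across tokens

dA = mx.exp(dt * A[None, :])                                 # (B, H)
dBx = mx.einsum('bh,bhs,bhd->bhds', dt, Bmat, x)             # (B, H, head_dim, state)
h = dA[..., None, None] * h + dBx                            # state update
y = mx.einsum('bhds,bhs->bhd', h, Cmat) + x * D[None, :, None]
print(y.shape)  # (2, 4, 8)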
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache=None) -> mx.array:
# x : (B, L, D)
return self.mixer(self.norm(x), cache) + x # (B, L, D)
class Mamba2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache=None) -> mx.array:
# x : (B, L)
x = self.embeddings(x)
# x : (B, L, D)
if cache is None:
cache = [None] * len(self.layers)
for layer, layer_cache in zip(self.layers, cache):
x = layer(x, layer_cache)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.backbone = Mamba2Model(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None) -> mx.array:
# inputs : (B, L)
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def make_cache(self, batch_size=1):
return [Mamba2Cache(
batch_size=batch_size,
hidden_size=self.args.hidden_size,
state_size=self.args.state_size,
conv_kernel=self.args.conv_kernel,
num_heads=self.args.num_heads,
head_dim=self.args.head_dim
) for _ in range(len(self.backbone.layers))]
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
weights[k] = v.moveaxis(2, 1)
return weights
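For context, a rough sketch of how the per-layer caches from make_cache are meant to be threaded through greedy decoding with the Model class above. This is hypothetical, intended usage only; args is assumed to be a populated ModelArgs, and the single-token constraint in Mamba2Block.__call__ applies:

import mlx.core as mx

model = Model(args)
cache = model.make_cache(batch_size=1)

tokens = mx.array([[1]])                         # (B=1, L=1) prompt token
for _ in range(16):
    logits = model(tokens, cache)                # (1, 1, vocab_size)
    tokens = mx.argmax(logits[:, -1:], axis=-1)  # greedy pick, keeps shape (1, 1)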

View File

@ -1,288 +0,0 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "mamba2"
num_heads: int = 128
head_dim: int = 64
vocab_size: int = 32768
hidden_size: int = 4096
state_size: int = 128
num_hidden_layers: int = 64
layer_norm_epsilon: float = 1e-5
expand: int = 2
conv_kernel: int = 4
n_groups: int = 8
use_bias: bool = False
use_conv_bias: bool = True
initializer_range: float = 0.1
residual_in_fp32: bool = True
time_step_rank: Union[int, str] = "auto"
time_step_min: float = 0.001
time_step_max: float = 0.1
time_step_floor: float = 1e-4
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
rescale_prenorm_residual: bool = False
use_cache: bool = True
rms_norm: bool = True
chunk_size: int = 256
tie_word_embeddings: bool = False
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
class Mamba2Cache:
def __init__(self):
self.cache = [None, None]
def __setitem__(self, idx, value):
self.cache[idx] = value
def __getitem__(self, idx):
return self.cache[idx]
@property
def state(self):
return self.cache
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.groups = groups if groups is not None else in_channels
# Ensure in_channels and out_channels are the same for depthwise conv
assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
# Ensure groups is equal to in_channels for depthwise conv
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
# Initialize weight with shape (out_channels, kernel_size, 1)
self.weight = mx.random.normal((out_channels, kernel_size, 1))
self.bias = mx.zeros((out_channels,)) if bias else None
def __call__(self, x, cache=None):
B, L, C = x.shape
_, K, _ = self.weight.shape
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=self.groups)
if self.bias is not None:
y = y + self.bias
return y, x[:, -K + 1 :, :]
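A quick usage sketch of the class above, showing the full-sequence pass followed by a single-step decode that reuses the returned cache of the last K-1 inputs (sizes are arbitrary and purely illustrative):

import mlx.core as mx

conv = DepthWiseConv1d(in_channels=8, out_channels=8, kernel_size=4)
x = mx.random.normal((2, 16, 8))         # (B, L, C)

y_full, cache = conv(x)                  # full prompt: zero left-padding inside
x_next = mx.random.normal((2, 1, 8))     # one new token
y_next, cache = conv(x_next, cache)      # decode step reuses the last K-1 inputs
print(y_full.shape, y_next.shape)        # (2, 16, 8) (2, 1, 8)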
class Mamba2Mixer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.intermediate_size = args.intermediate_size
self.time_step_rank = args.time_step_rank
self.conv_kernel_size = args.conv_kernel
self.hidden_size = args.hidden_size
self.state_size = args.state_size
self.num_heads = args.num_heads
self.head_dim = args.hidden_size // args.num_heads
self.n_groups = args.n_groups
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
self.conv1d = DepthWiseConv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
bias=args.use_conv_bias,
kernel_size=args.conv_kernel,
groups=self.conv_dim,
padding=args.conv_kernel - 1
)
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
self.in_proj = nn.Linear(
self.hidden_size,
projection_size,
bias=args.use_bias
)
self.dt_bias = mx.ones((self.num_heads,))
self.A_log = mx.log(mx.arange(1, self.num_heads + 1))
self.D = mx.ones((self.num_heads,))
self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
def ssm_step(self, x, state, dt_proj):
A = -mx.exp(self.A_log)
D = self.D
delta = nn.softplus(dt_proj + self.dt_bias)
B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
B = B.reshape(-1, self.n_groups, self.state_size)
C = C.reshape(-1, self.n_groups, self.state_size)
if state is None:
new_state = mx.expand_dims(delta, -1) * B
else:
new_state = mx.expand_dims(delta, -1) * (B + state * mx.exp(mx.expand_dims(delta, -1) * A))
y = mx.sum(new_state * C, axis=-1)
y = y + D * x[:, :self.num_heads]
return y, new_state
def __call__(self, x, cache):
B, T, D = x.shape
if cache is None:
cache = [None, None]
outputs = []
for t in range(T):
xt = x[:, t, :]
xz = self.in_proj(xt)
x_t, z_t, dt_proj = mx.split(
xz,
indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
axis=-1
)
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
x_t = conv_out.squeeze(1)
x_t = nn.silu(x_t)
y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj)
z_t = nn.silu(z_t)
# Print shapes for debugging
print(f"y_t shape: {y_t.shape}")
print(f"z_t shape: {z_t.shape}")
# Reshape y_t to (B, num_heads, head_dim)
y_t_reshaped = y_t.reshape(B, self.num_heads, -1)
# Reshape z_t to (B, num_heads, intermediate_size // num_heads)
z_t_reshaped = z_t.reshape(B, self.num_heads, -1)
print(f"y_t_reshaped shape: {y_t_reshaped.shape}")
print(f"z_t_reshaped shape: {z_t_reshaped.shape}")
# Element-wise multiplication (broadcasting across the last dimension)
output_t = y_t_reshaped * z_t_reshaped
# Reshape to match the expected input of out_proj
output_t = output_t.reshape(B, -1)
print(f"output_t shape before out_proj: {output_t.shape}")
print(f"out_proj weight shape: {self.out_proj.weight.shape}")
output_t = self.out_proj(output_t)
outputs.append(output_t)
output = mx.stack(outputs, axis=1)
return output
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = Mamba2Mixer(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
return self.mixer(self.norm(x), cache) + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [Mamba2Block(args) for idx in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(
self,
inputs: mx.array,
cache=None
):
hidden_states = self.embeddings(inputs)
if cache is None:
cache = Mamba2Cache(len(self.layers))
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, cache[i])
hidden_states = self.norm_f(hidden_states)
return hidden_states
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self, batch_size: int = 1):
return [Mamba2Cache() for _ in range(len(self.layers))]
@property
def layers(self):
return self.backbone.layers

View File

@ -0,0 +1,449 @@
# coding=utf-8
# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MAMBA2 model."""
import logging
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
logger = logging.getLogger(__name__)  # stdlib logging; the HF logging.get_logger util is not imported in this file
def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
"""
Padding x tensor with `pad_size` on the seq_len dim (dim=1)
Assumes that we only have tensors of either size 4 or 3
"""
pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
def reshape_into_chunks(input_tensor, pad_size, chunk_size):
"""
Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
simultaneously splitting it into chunk sequences.
Assumes that we only have tensors of either size 4 or 3
"""
# [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
input_tensor = pad_tensor_by_size(input_tensor, pad_size)
if len(input_tensor.shape) == 3:
# [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
else:
# [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
return input_tensor.reshape(
input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
)
def segment_sum(input_tensor):
"""
More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
"""
chunk_size = input_tensor.size(-1)
# 1. expand input tensor to have an additional dimension and repeat along that dimension
# [..., chunk_size] -> [..., chunk_size, chunk_size]
input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
# 2. create a lower triangular mask with the diagonal set to 0, to zero out elements above the diagonal
mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
input_tensor = input_tensor.masked_fill(~mask, 0)
# 3. compute actual cumsum
tensor_segsum = torch.cumsum(input_tensor, dim=-2)
# 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
return tensor_segsum
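As the docstring says, segment_sum replaces direct subtraction of cumulative sums with masking. A small check of that equivalence on the lower-triangular part (shapes are arbitrary, chosen only for the demo):

import torch

x = torch.randn(2, 3, 5)                       # any leading dims, chunk_size = 5
seg = segment_sum(x)                           # (2, 3, 5, 5)

# Direct-subtraction form it replaces: entry (i, j) = cumsum[i] - cumsum[j] for j <= i
cs = torch.cumsum(x, dim=-1)
naive = cs[..., :, None] - cs[..., None, :]
mask = torch.tril(torch.ones(5, 5, dtype=torch.bool))
assert torch.allclose(seg[..., mask], naive[..., mask], atol=1e-6)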
class Mamba2Cache:
"""
Arguments:
config: ModelArgs
batch_size: int
dtype: torch.dtype
device: torch.device
Attributes:
seqlen_offset: int
dtype: torch.dtype
conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel]
ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size]
"""
def __init__(
self, config: ModelArgs, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
):
self.seqlen_offset = 0
self.dtype = dtype
self.conv_kernel = config.conv_kernel
self.intermediate_size = int(config.expand * config.hidden_size)
self.conv_states = {
i: torch.zeros(
batch_size,
self.intermediate_size + 2 * config.n_groups * config.state_size,
self.conv_kernel,
device=device,
dtype=dtype,
)
for i in range(config.num_hidden_layers)
}
self.ssm_states = {
i: torch.zeros(
batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype
)
for i in range(config.num_hidden_layers)
}
def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel - 1)
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]
def reset(self):
self.conv_states.zero_()
self.ssm_states.zero_()
class MambaRMSNormGated(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states, gate=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states
if gate is not None:
hidden_states = hidden_states * nn.functional.silu(gate)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
class Mamba2Mixer(nn.Module):
    def __init__(self, config: ModelArgs, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.num_heads = config.num_heads
self.hidden_size = config.hidden_size
self.state_size = config.state_size
self.conv_kernel = config.conv_kernel
self.intermediate_size = int(config.expand * self.hidden_size)
self.time_step_rank = int(config.time_step_rank)
self.use_conv_bias = config.use_conv_bias
self.layer_norm_epsilon = config.layer_norm_epsilon
self.n_groups = config.n_groups
self.head_dim = config.head_dim
self.chunk_size = config.chunk_size
self.time_step_limit = config.time_step_limit
self.time_step_min = config.time_step_min
self.time_step_max = config.time_step_max
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
bias=config.use_conv_bias,
kernel_size=config.conv_kernel,
groups=self.conv_dim,
padding=config.conv_kernel - 1,
)
# projection of the input hidden states
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
self.in_proj = nn.Linear(
self.hidden_size,
projection_size,
bias=config.use_bias,
)
self.dt_bias = torch.ones(self.num_heads)
A = torch.arange(1, self.num_heads + 1)
self.A_log = torch.log(A)
self.D = torch.ones(self.num_heads)
self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
    def forward(self, input_states, cache: Optional[Mamba2Cache] = None, cache_position: Optional[torch.LongTensor] = None):
batch_size, seq_len, _ = input_states.shape
# Gated MLP's linear projection
projected_states = self.in_proj(input_states.squeeze(1))
d_mlp = (
projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.state_size- self.num_heads) // 2
_, _, gate, hidden_states, dt = projected_states.split(
[d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
)
# Convolution sequence transformation
ssm_state = cache.ssm_states[self.layer_idx].clone()
ssm_state = ssm_state.to(hidden_states.device)
if cache.seqlen_offset > 0:
conv_state = cache.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel]
conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
# handle batched generation - states are copied through
conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states
cache.conv_states[self.layer_idx].copy_(conv_state)
hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
if self.use_conv_bias:
hidden_states += self.conv1d.bias
hidden_states = nn.functional.silu(hidden_states)[:, None, ...]  # [batch, 1, intermediate_size] : decoding
else:
hidden_states = hidden_states.transpose(1,2)
conv_state = nn.functional.pad(
hidden_states,
(self.conv_kernel - hidden_states.shape[-1], 0)
)
cache.conv_states[self.layer_idx].copy_(conv_state)
hidden_states = nn.functional.silu(self.conv1d(hidden_states).transpose(1, 2))[:, :seq_len, :]  # [batch, seq_len, conv_dim]
hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.state_size, self.n_groups * self.state_size], dim=-1)
A = -torch.exp(self.A_log.float()) # [num_heads]
if cache is not None and cache.seqlen_offset > 0:
# Note: there is no need to pad parameter matrices here, as there is just one new token
# for batched generation
dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
# [num_heads] -> [num_heads, head_dim]
dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max)
A = A[..., None, None].expand(self.num_heads, self.head_dim, self.state_size).to(dtype=torch.float32)
# [bsz, num_heads, head_dim, state_size]
dA = torch.exp(dt[..., None] * A)
# Discretize B
# [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
# -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
B = B.reshape(batch_size, -1, B.shape[-1])
# [bsz, num_heads, head_dim, state_size]
dB = dt[..., None] * B[..., None, :]
# Discretize x into dB
# [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
dBx = dB * hidden_states[..., None]
# State calculation
cache.ssm_states[self.layer_idx].copy_(
cache.ssm_states[self.layer_idx] * dA + dBx
)
# Subsequent output
# [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
C = C.reshape(batch_size, -1, C.shape[-1])
# [bsz, num_heads, head_dim]
ssm_states = cache.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n]
# Reshape ssm_states to merge the first two dimensions
ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.state_size) # Shape: [b*h, d, n]
C_reshaped = C.view(batch_size * self.num_heads, self.state_size, 1) # Shape: [b*h, n, 1]
y = torch.bmm(ssm_states_reshaped, C_reshaped)
y = y.view(batch_size, self.num_heads, self.head_dim)
# D skip connection
# [num_heads] -> [num_heads, head_dim]
D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
y = (y + hidden_states * D).to(y.dtype)
# [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
y = y.reshape(batch_size, -1)[:, None, ...]
else:
# begin ssd naive implementation without einsums
dt = nn.functional.softplus(dt + self.dt_bias)
dt = torch.clamp(dt, self.time_step_min)
hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
B = B.reshape(batch_size, seq_len, -1, self.state_size).float()
C = C.reshape(batch_size, seq_len, -1, self.state_size).float()
B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
pad_size = self.chunk_size - (seq_len % self.chunk_size)
D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
# Discretize x and A
hidden_states = hidden_states * dt[..., None]
A = A.to(hidden_states.dtype) * dt
# Rearrange into blocks/chunks
hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
# [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
A = A.permute(0, 3, 1, 2)
A_cumsum = torch.cumsum(A, dim=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
# This is the analog of a causal mask
L = torch.exp(segment_sum(A))
# First, contraction of C and B to get G (attention-weights like)
G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n)
G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
# Step 2: Compute M, equivalent to applying attention mask to weights
M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
M = M_intermediate.sum(dim=-1)
# Step 3: Compute Y_diag (apply to values)
Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
# permute back B * decay states
states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
if cache is not None and cache.seqlen_offset > 0:
previous_states = cache.ssm_states[self.layer_idx][:, None, ...]
else:
previous_states = torch.zeros_like(states[:, :1])
states = torch.cat([previous_states, states], dim=1)
decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
states_permuted = states.permute(0, 2, 1, 3, 4)
result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
new_states = result.permute(0, 2, 1, 3, 4)
states, ssm_state = new_states[:, :-1], new_states[:, -1]
# Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
# compute Yoff
C_times_states = (C[..., None, :] * states[:, :, None, ...])
state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
y = Y_diag + Y_off
# [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
y = y + D_residual
# Cutting off padded chunks
if pad_size > 0:
y = y[:, :seq_len, :, :]
y = y.reshape(batch_size, seq_len, -1)
if ssm_state is not None and cache is not None:
cache.ssm_states[self.layer_idx].copy_(ssm_state)
scan_output = self.norm(y, gate)
# end ssd naive
# 4. Final linear projection
contextualized_states = self.out_proj(scan_output) # [batch, seq_len, hidden_size]
return contextualized_states
class Mamba2Block(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mixer = Mamba2Mixer(config, layer_idx=layer_idx)
def forward(
self,
hidden_states,
cache: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
):
x = self.mixer(
self.norm(hidden_states), cache=cache, cache_position=cache_position
)
return x + hidden_states
class Mamba2Model(nn.Module):
def __init__(self, config):
        super().__init__()
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
        self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
cache: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.embeddings(input_ids)
hidden_states = inputs_embeds
for mixer_block in self.layers:
hidden_states = mixer_block(
hidden_states,
cache=cache,
cache_position=cache_position,
)
        if cache is not None:
            cache.seqlen_offset += inputs_embeds.shape[1]
return self.norm_f(hidden_states), cache
class Mamba2ForCausalLM(nn.Module):
def __init__(self, config):
        super().__init__()
self.backbone = Mamba2Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
cache: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.Tensor] = None,
):
out, cache = self.backbone(
input_ids,
cache=cache,
cache_position=cache_position,
)
logits = self.lm_head(out)
return logits, cache

View File

@ -23,9 +23,42 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from .configuration_mamba2 import Mamba2Config
logger = logging.get_logger(__name__)
if is_mamba_2_ssm_available():
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
else:
selective_state_update = None
if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None
is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
_CHECKPOINT_FOR_DOC = "mistralai/mamba-codestral-7B-v0.1"
_CONFIG_FOR_DOC = "Mamba2Config"
# Helper methods for segment sum computation
def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
"""
Padding x tensor with `pad_size` on the seq_len dim (dim=1)
@ -80,7 +113,7 @@ def segment_sum(input_tensor):
class Mamba2Cache:
"""
Arguments:
config: ModelArgs
config: Mamba2Config
batch_size: int
dtype: torch.dtype
device: torch.device
@ -93,7 +126,7 @@ class Mamba2Cache:
"""
def __init__(
self, config: ModelArgs, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
self, config: Mamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
):
self.seqlen_offset = 0
self.dtype = dtype
@ -116,6 +149,8 @@ class Mamba2Cache:
)
for i in range(config.num_hidden_layers)
}
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]
def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
@ -142,18 +177,25 @@ class MambaRMSNormGated(torch.nn.Module):
def forward(self, hidden_states, gate=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states
hidden_states = hidden_states.to(torch.float32)
if gate is not None:
hidden_states = hidden_states * nn.functional.silu(gate)
hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
return self.weight * hidden_states.to(input_dtype)
class Mamba2Mixer(nn.Module):
def __init__(self, config: ModelArgs):
"""
Compute ∆, A, B, C, and D, the state space parameters, and compute the `contextualized_states`.
A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
and is why Mamba is called **selective** state spaces)
"""
def __init__(self, config: Mamba2Config, layer_idx: int):
super().__init__()
self.num_heads = config.num_heads
self.hidden_size = config.hidden_size
@ -161,8 +203,10 @@ class Mamba2Mixer(nn.Module):
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = int(config.expand * self.hidden_size)
self.time_step_rank = int(config.time_step_rank)
self.layer_idx = layer_idx
self.use_conv_bias = config.use_conv_bias
self.act = nn.silu
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]
self.layer_norm_epsilon = config.layer_norm_epsilon
self.rms_norm = config.rms_norm
@ -192,23 +236,178 @@ class Mamba2Mixer(nn.Module):
projection_size,
bias=config.use_bias,
)
# selective projection used to make dt, B and C input dependant
self.dt_bias = torch.ones(self.num_heads)
# time step projection (discretization)
# instantiate once and copy inv_dt in init_weights of PretrainedModel
self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
# S4D real initialization. These are not discretized!
# The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
A = torch.arange(1, self.num_heads + 1)
self.A_log = torch.log(A)
self.D = torch.ones(self.num_heads)
self.A_log = nn.Parameter(torch.log(A))
self.A_log._no_weight_decay = True
self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
self.D = nn.Parameter(torch.ones(self.num_heads))
self.D._no_weight_decay = True
def forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None):
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
self.use_bias = config.use_bias
if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
" is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d"
)
def cuda_kernels_forward(
self,
hidden_states: torch.Tensor,
cache_params: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
# set up dimensions for reshapes later
batch_size, seq_len, _ = hidden_states.shape
groups_time_state_size = self.n_groups * self.ssm_state_size
d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads
# getting projected states from cache if it exists
if cache_params is not None and cache_params.seqlen_offset > 0:
in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2
split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads]
_, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1)
hidden_states_B_C = causal_conv1d_update(
hidden_states_B_C,
cache_params.conv_states[self.layer_idx],
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
)
hidden_states, B, C = torch.split(
hidden_states_B_C,
[self.intermediate_size, groups_time_state_size, groups_time_state_size],
dim=-1,
)
A = -torch.exp(self.A_log.float()) # (nheads,)
A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D = self.D[:, None, ...].expand(-1, self.head_dim)
B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
hidden_states = selective_state_update(
cache_params.ssm_states[self.layer_idx],
hidden_states_reshaped,
dt,
A,
B,
C,
D,
z=None,
dt_bias=dt_bias,
dt_softplus=True,
)
hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
hidden_states = self.norm(hidden_states, gate)
out = self.out_proj(hidden_states)[:, None, ...]
# if no cache is found, calling the kernel
else:
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
dtype = hidden_states.dtype
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)
A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
if self.training and cache_params is None:
out, ssm_state = mamba_split_conv1d_scan_combined(
projected_states,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.dt_bias,
A,
D=self.D,
chunk_size=self.chunk_size,
seq_idx=None, # was seq_idx
activation=self.activation,
rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.variance_epsilon,
outproj_weight=self.out_proj.weight,
outproj_bias=self.out_proj.bias,
headdim=self.head_dim,
ngroups=self.n_groups,
norm_before_gate=False,
return_final_states=True,
**dt_limit_kwargs,
)
else:
gate, hidden_states_B_C, time_step = torch.split(
projected_states,
[self.intermediate_size, self.conv_dim, self.num_heads],
dim=-1,
)
time_step = nn.functional.softplus(time_step + self.dt_bias)
# 1D Convolution
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
hidden_states_B_C = self.act(
self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len]
) # (B, L, self.d_inner + 2 * ngroups * d_state)
else:
hidden_states_B_C = causal_conv1d_fn(
x=hidden_states_B_C.transpose(1, 2),
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
).transpose(1, 2)[:, :seq_len]
hidden_states, B, C = torch.split(
hidden_states_B_C,
[self.intermediate_size, groups_time_state_size, groups_time_state_size],
dim=-1,
)
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
dtype = hidden_states.dtype
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
scan_output, ssm_state = mamba_chunk_scan_combined(
hidden_states.view(batch_size, seq_len, -1, self.head_dim),
time_step,
A,
B.view(batch_size, seq_len, self.n_groups, -1),
C.view(batch_size, seq_len, self.n_groups, -1),
chunk_size=self.chunk_size,
D=self.D,
z=None,
seq_idx=None,
return_final_states=True,
**dt_limit_kwargs,
)
if ssm_state is not None and cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
scan_output = scan_output.view(batch_size, seq_len, -1)
# Multiply "gate" branch and apply extra normalization layer
scan_output = self.norm(scan_output, gate)
out = self.out_proj(scan_output)
return out
# fmt: off
def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# Gated MLP's linear projection
projected_states = self.in_proj(input_states.squeeze(1))
d_mlp = (
projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2
_, _, gate, hidden_states, dt = projected_states.split(
[d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
)
@ -223,10 +422,10 @@ class Mamba2Mixer(nn.Module):
# handle batched generation - states are copied through
conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states
cache_params.conv_states[self.layer_idx].copy_(conv_state)
hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1)
if self.use_conv_bias:
hidden_states += self.conv1d.bias
hidden_states = self.act(hidden_states)[:, None, ...] # [batch, 1, intermediate_size] : decoding
hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding
else:
hidden_states = hidden_states.transpose(1,2)
conv_state = nn.functional.pad(
@ -235,16 +434,18 @@ class Mamba2Mixer(nn.Module):
)
cache_params.conv_states[self.layer_idx].copy_(conv_state)
hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len]
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
dtype = hidden_states.dtype
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
else:
ssm_state = torch.zeros(
(batch_size, self.num_heads, self.head_dim, self.ssm_state_size),
device=hidden_states.device
device=hidden_states.device, dtype=dtype
)
hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2))
hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1)
A = -torch.exp(self.A_log.float()) # [num_heads]
if cache_params is not None and cache_params.seqlen_offset > 0:
# Note: there is no need to pad parameter matrices here, as there is just one new token
# for batched generation
@ -384,8 +585,25 @@ class Mamba2Mixer(nn.Module):
# end ssd naive
# 4. Final linear projection
contextualized_states = self.out_proj(scan_output) # [batch, seq_len, hidden_size]
contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
return contextualized_states
# fmt: on
def forward(
self,
hidden_states,
cache_params: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
dtype = hidden_states.dtype
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
class Mamba2RMSNorm(nn.Module):
@ -399,47 +617,258 @@ class Mamba2RMSNorm(nn.Module):
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
return self.weight * hidden_states.to(input_dtype)
class Mamba2Block(nn.Module):
def __init__(self, config):
def __init__(self, config, layer_idx):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.residual_in_fp32 = config.residual_in_fp32
self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mixer = Mamba2Mixer(config)
self.mixer = Mamba2Mixer(config, layer_idx=layer_idx)
def forward(
self,
hidden_states,
cache_params: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
x = self.mixer(
self.norm(hidden_states), cache_params=cache_params, cache_position=cache_position
residual = hidden_states
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
hidden_states = self.mixer(
hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
)
return x + hidden_states
hidden_states = residual + hidden_states
return hidden_states
class Mamba2Model(nn.Module):
class Mamba2PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Mamba2Config
base_model_prefix = "backbone"
_no_split_modules = ["Mamba2Block"]
supports_gradient_checkpointing = True
_is_stateful = True
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, Mamba2Mixer):
module.A_log._no_weight_decay = True
module.D._no_weight_decay = True
dt = torch.exp(
torch.rand(self.config.num_heads)
* (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
+ math.log(self.config.time_step_min)
).clamp(min=self.config.time_step_floor)
# # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
module.dt_bias.copy_(inv_dt)
module.dt_bias._no_reinit = True
if isinstance(module, nn.Linear):
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=self.config.initializer_range)
if self.config.rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
with torch.no_grad():
p /= math.sqrt(self.config.num_hidden_layers)
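The _init_weights code above relies on the inverse-softplus identity for dt_bias. A small standalone check of that identity using the default time-step range from the config (the documented defaults; the tensor size is arbitrary):

import math
import torch

time_step_min, time_step_max, time_step_floor = 0.001, 0.1, 1e-4  # config defaults
dt = torch.exp(
    torch.rand(8) * (math.log(time_step_max) - math.log(time_step_min))
    + math.log(time_step_min)
).clamp(min=time_step_floor)

inv_dt = dt + torch.log(-torch.expm1(-dt))      # softplus^{-1}(dt)
assert torch.allclose(torch.nn.functional.softplus(inv_dt), dt, atol=1e-6)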
@dataclass
# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2
class Mamba2Output(ModelOutput):
"""
Class for the MAMBA2 model outputs.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
cache_params (`Mamba2Cache`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
Includes both the State space model state matrices after the selective scan, and the Convolutional states
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: Optional[torch.FloatTensor] = None
cache_params: Optional[Mamba2Cache] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->Mamba2
class Mamba2CausalLMOutput(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
cache_params (`Mamba2Cache`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
Includes both the State space model state matrices after the selective scan, and the Convolutional states
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
cache_params: Optional[Mamba2Cache] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
MAMBA2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Mamba2Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
MAMBA2_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
Indices of input sequence tokens in the vocabulary.
If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as
`input_ids`.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
cache_params (`Mamba2Cache`, *optional*):
If passed along, the model uses the previous state in all the blocks (which will give the output for the
`input_ids` provided as if the model adds `state_input_ids + input_ids` as context).
use_cache (`bool`, *optional*):
If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare MAMBA2 Model transformer outputting raw hidden-states without any specific head on top.",
MAMBA2_START_DOCSTRING,
)
class Mamba2Model(Mamba2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
# Initialize weights and apply final processing
self._register_load_state_dict_pre_hook(self.load_hook)
self.post_init()
def load_hook(self, state_dict, prefix, *args):
for k in state_dict:
if "embedding." in k:
state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
break
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, new_embeddings):
self.embeddings = new_embeddings
@add_start_docstrings_to_model_forward(MAMBA2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Mamba2Output,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
cache_params: Optional[Mamba2Cache] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.embeddings(input_ids)
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[Tuple, Mamba2Output]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
if self.gradient_checkpointing and self.training and use_cache:
use_cache = False
if use_cache:
if cache_params is None:
@@ -447,44 +876,206 @@ class Mamba2Model(nn.Module):
self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
)
cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
elif cache_position is None:
# Covers the case where the user calls forward manually instead of going through
# `model.generate`, which initializes `cache_position` and guarantees it is not None;
# raise an error here instead of guessing the current cache position.
raise ValueError(
"You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
"you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
"be initialized for you automatically"
)
else:
cache_params = None
hidden_states = inputs_embeds
all_hidden_states = () if output_hidden_states else None
for mixer_block in self.layers:
hidden_states = mixer_block(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
)
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
)
else:
hidden_states = mixer_block(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if use_cache:
cache_params.seqlen_offset += inputs_embeds.shape[1]
return self.norm_f(hidden_states), cache_params if use_cache else None
hidden_states = self.norm_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
return Mamba2Output(
last_hidden_state=hidden_states,
cache_params=cache_params if use_cache else None,
hidden_states=all_hidden_states,
)
@add_start_docstrings(
"""
The MAMBA2 Model transformer with a language modeling head on top (linear layer with weights not tied to the input
embeddings).
""",
MAMBA2_START_DOCSTRING,
)
class Mamba2ForCausalLM(Mamba2PreTrainedModel):
_tied_weights_keys = []
class Mamba2ForCausalLM(nn.Module):
def __init__(self, config):
super().__init__(config)
self.backbone = Mamba2Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_input_embeddings(self):
return self.backbone.get_input_embeddings()
def set_input_embeddings(self, new_embeddings):
return self.backbone.set_input_embeddings(new_embeddings)
def prepare_inputs_for_generation(
self,
input_ids,
inputs_embeds=None,
use_cache=None,
cache_params: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
):
if inputs_embeds is not None:
past_len = inputs_embeds.shape[1] + input_ids.shape[1]
else:
past_len = input_ids.shape[1]
if use_cache:
# `cache_position` should have been initialized in `generate`
if cache_position is None:
raise ValueError(
"`cache_position` should not be None as it should have been initialized in "
"`model.generate`, you are responsible for passing in a valid `cache_position` if "
"you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
)
# how do we detect that we are in decoding without cache?
if cache_position[0] > 0:
input_ids = input_ids[:, -1][..., None]
attention_mask = attention_mask[:, -1][..., None]
else:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
# will be applied when it is longer, so it will be equivalent to always have it match
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
cache_position = torch.arange(0, past_len, device=input_ids.device)
# if the cache is not used, we also do have to extend the attention mask here
# TODO there is likely a cleverer way to do this
extended_mask = torch.ones(
attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device
)
attention_mask = torch.cat([attention_mask, extended_mask], dim=1)
cache_params = None
if attention_mask.shape[1] < past_len:
# we have to update manually the attention mask if
# we are in decoding without cache
# and we don't have position_ids here
# TODO but we should be able to use cache_position though at a later time
extended_mask = torch.ones(
attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device
)
attention_mask = torch.cat([attention_mask, extended_mask], dim=1)
if inputs_embeds is not None and cache_params is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"attention_mask": attention_mask,
"cache_params": cache_params,
"use_cache": use_cache,
"cache_position": cache_position,
}
)
return model_inputs
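# Descriptive note on `prepare_inputs_for_generation` above: during cached
# decoding (`cache_position[0] > 0`) only the newest token and its mask column
# are kept, since all earlier context already lives in `cache_params`; otherwise
# `cache_position` is initialized to cover the full prompt and the attention
# mask is right-padded with ones to match that length.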
@add_start_docstrings_to_model_forward(MAMBA2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Mamba2CausalLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_params: Optional[Mamba2Cache] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.Tensor] = None,
):
attention_mask: Optional[torch.Tensor] = None,
**kwargs, # for now we need this for generation
) -> Union[Tuple, Mamba2CausalLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
mamba2_outputs = self.backbone(
input_ids,
cache_params=cache_params,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
cache_position=cache_position,
attention_mask=attention_mask,
)
hidden_states = mamba2_outputs[0]
logits = self.lm_head(hidden_states)
return logits, mamba2_outputs.cache_params, mamba2_outputs.hidden_states
logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (logits,) + mamba2_outputs[1:]
return ((loss,) + output) if loss is not None else output
return Mamba2CausalLMOutput(
loss=loss,
logits=logits,
cache_params=mamba2_outputs.cache_params,
hidden_states=mamba2_outputs.hidden_states,
)

View File

@@ -0,0 +1,337 @@
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import Mamba2Cache
@dataclass
class ModelArgs(BaseModelArgs):
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
use_cache: bool = True
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
def silu(x):
return x * mx.sigmoid(x)
def ssd(x, A, B, C, chunk_size):
# Replace einsum operations with explicit reshape and matrix multiply
batch, seqlen, nheads, dim = x.shape
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)
state = mx.zeros((batch, nheads, dim, B.shape[-1]))
outputs = []
for i in range(0, seqlen, chunk_size):
chunk = slice(i, min(i + chunk_size, seqlen))
dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
# Replace einsum with explicit operations
x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim]
x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size]
B_chunk = B[:, chunk] # [batch, chunk_size, state_size]
dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size]
state = state * mx.expand_dims(dA, axis=-1) + dBx
# Replace einsum with explicit operations
C_chunk = C[:, chunk] # [batch, chunk_size, state_size]
y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size]
y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim]
outputs.append(y)
return mx.concatenate(outputs, axis=1), state
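# A slow but readable sketch of the per-timestep recurrence that the chunked
# computation above processes chunk by chunk (illustrative only, not used by the
# model; the shapes are the ones assumed here, not enforced anywhere): x is
# (batch, seqlen, nheads, dim), A is (seqlen, nheads), B and C are
# (batch, seqlen, state_size).
def _ssd_reference_scan(x, A, B, C):
    batch, seqlen, nheads, dim = x.shape
    state = mx.zeros((batch, nheads, dim, B.shape[-1]))
    ys = []
    for t in range(seqlen):
        dA = mx.exp(A[t])  # per-head decay for this step, (nheads,)
        # Outer product of the input with B, accumulated into the state.
        dBx = x[:, t, :, :, None] * B[:, t, None, None, :]  # (batch, nheads, dim, state)
        state = state * dA[None, :, None, None] + dBx
        # Read the state out through C.
        ys.append(mx.sum(state * C[:, t, None, None, :], axis=-1))  # (batch, nheads, dim)
    return mx.stack(ys, axis=1), state  # (batch, seqlen, nheads, dim), final state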
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.groups = groups if groups is not None else in_channels
assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
self.weight = mx.random.normal((in_channels, 1, kernel_size))
self.bias = mx.zeros((out_channels,)) if bias else None
def __call__(self, x: mx.array, cache=None) -> mx.array:
B, L, C = x.shape
K = self.kernel_size
assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
if cache is not None and cache.conv_states[0] is not None:
# Convert None to proper array if needed
if isinstance(cache.conv_states[0], type(None)):
cache.conv_states[0] = mx.zeros((B, K-1, C))
x = mx.concatenate([cache.conv_states[0], x], axis=1)
# Process each channel independently
outputs = []
for c in range(C):
x_c = x[:, :, c]
x_c = mx.expand_dims(x_c, axis=1)
w_c = self.weight[c]
if w_c.ndim == 2:
w_c = mx.expand_dims(w_c, axis=0)
elif w_c.ndim == 1:
w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
# Apply convolution
y_c = mx.conv_general(
x_c,
w_c,
stride=1,
padding=0
)
if self.bias is not None:
y_c = y_c + self.bias[c]
outputs.append(mx.squeeze(y_c, axis=1))
y = mx.stack(outputs, axis=-1)
# Update cache
if cache is not None:
cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
return y
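# Illustrative only (the names below are made up for this sketch): with a cache
# of the last kernel_size - 1 inputs, the depthwise convolution above reduces,
# for a single new token, to one dot product per channel between that channel's
# kernel and its most recent kernel_size values, with the kernel applied in its
# stored order (no flip), matching the convolution above.
def _depthwise_conv_step(conv_state, new_x, weight, bias=None):
    # conv_state: (batch, K-1, channels) past inputs; new_x: (batch, channels)
    # weight: (channels, 1, K) depthwise kernels, laid out as in DepthWiseConv1d
    window = mx.concatenate([conv_state, new_x[:, None, :]], axis=1)  # (batch, K, channels)
    y = mx.sum(mx.transpose(window, (0, 2, 1)) * weight[:, 0, :], axis=-1)  # (batch, channels)
    if bias is not None:
        y = y + bias
    return y, window[:, 1:, :]  # output plus the refreshed K-1 element cache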
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
conv_dim = args.intermediate_size + 2 * args.state_size
self.conv1d = DepthWiseConv1d(
in_channels=conv_dim,
out_channels=conv_dim,
kernel_size=args.conv_kernel,
groups=conv_dim,
bias=args.use_conv_bias,
padding=args.conv_kernel - 1
)
self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
if args.rescale_prenorm_residual:
layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
self.out_proj.weight = self.out_proj.weight * layer_scale
def __call__(self, u: mx.array, cache=None):
batch_size = u.shape[0]
seq_len = u.shape[1]
outputs = []
# Initialize states if needed
if cache.conv_states[0] is None:
conv_dim = self.args.intermediate_size + 2 * self.args.state_size
cache.conv_states[0] = mx.zeros((
batch_size,
self.args.conv_kernel - 1,
conv_dim
))
if cache.ssm_states[0] is None:
cache.ssm_states[0] = mx.zeros((
batch_size,
self.args.num_heads,
self.args.head_dim,
self.args.state_size
))
for pos in range(seq_len):
u_t = u[:, pos:pos+1, :]
zxbcdt = self.in_proj(u_t)
n_heads = self.args.num_heads
z = zxbcdt[:, :, :self.args.intermediate_size]
xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
dt = zxbcdt[:, :, -(n_heads):]
dt = mx.reshape(dt, (batch_size, n_heads))
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.args.time_step_min,
self.args.time_step_max
)
dt = mx.maximum(dt, self.args.time_step_floor)
xBC = self.conv1d(xBC, cache=cache)
xBC = silu(xBC)
x = xBC[:, :, :self.args.intermediate_size]
B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
C = xBC[:, :, -self.args.state_size:]
x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
x = mx.squeeze(x, axis=1)
B = mx.reshape(B, (batch_size, 1, self.args.state_size))
B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
B = mx.expand_dims(B, axis=2)
C = mx.reshape(C, (batch_size, 1, self.args.state_size))
C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
C = mx.expand_dims(C, axis=3)
A = -mx.exp(self.A_log)
dA = mx.exp(dt * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
x = mx.expand_dims(x, axis=3)
dBx = mx.matmul(x, B)
cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx
y = mx.matmul(cache.ssm_states[0], C)
y = mx.squeeze(y, axis=-1)
y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
y = self.norm(y + z)
y = self.out_proj(y)
outputs.append(y)
return mx.concatenate(outputs, axis=1)
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.residual_in_fp32 = args.residual_in_fp32
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
if self.residual_in_fp32:
x = x.astype(mx.float32)
return self.mixer(self.norm(x), cache) + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def make_cache(self, batch_size=1):
return [Mamba2Cache(batch_size, self.args.conv_kernel) for _ in range(len(self.layers))]
def sanitize(self, weights):
sanitized = {}
for k, v in weights.items():
if "conv1d.weight" in k:
# Ensure weights are in correct shape (channels, 1, kernel_size)
if v.ndim == 2:
v = mx.expand_dims(v, axis=1)
elif v.ndim == 1:
v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0)
sanitized[k] = v
else:
sanitized[k] = v
return sanitized
@property
def layers(self):
return self.backbone.layers

View File

@@ -1,7 +1,6 @@
import math
from dataclasses import dataclass, field
from typing import Optional, Tuple, Union
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -28,18 +27,17 @@ class ModelArgs(BaseModelArgs):
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
use_cache: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
intermediate_size: int = None
use_cache: bool = True
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
self.intermediate_size = int(self.expand * self.hidden_size) # E*D = ED
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
@@ -49,256 +47,241 @@ class ModelArgs(BaseModelArgs):
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones(hidden_size)
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def forward(self, hidden_states, gate=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(mx.float32)
def __call__(self, hidden_states, gate=None):
if gate is not None:
hidden_states = hidden_states * nn.functional.silu(gate.to(mx.float32))
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * math.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
hidden_states = hidden_states * nn.silu(gate)
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
class Mamba2Mixer(nn.Module):
def silu(x):
return x * mx.sigmoid(x)
def ssd(x, A, B, C, chunk_size):
# Replace einsum operations with explicit reshape and matrix multiply
batch, seqlen, nheads, dim = x.shape
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)
state = mx.zeros((batch, nheads, dim, B.shape[-1]))
outputs = []
for i in range(0, seqlen, chunk_size):
chunk = slice(i, min(i + chunk_size, seqlen))
dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
# Replace einsum with explicit operations
x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim]
x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size]
B_chunk = B[:, chunk] # [batch, chunk_size, state_size]
dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size]
state = state * mx.expand_dims(dA, axis=-1) + dBx
# Replace einsum with explicit operations
C_chunk = C[:, chunk] # [batch, chunk_size, state_size]
y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size]
y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim]
outputs.append(y)
return mx.concatenate(outputs, axis=1), state
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.groups = groups if groups is not None else in_channels
assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
# Initialize weight with correct shape [C_out, 1, kernel_size]
self.weight = mx.random.normal((out_channels, 1, kernel_size))
self.bias = mx.zeros((out_channels,)) if bias else None
def __call__(self, x: mx.array, cache=None) -> mx.array:
B, L, C = x.shape
K = self.kernel_size
assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
# Handle caching for sequential processing
if cache is not None and cache.conv_states[0] is not None:
if isinstance(cache.conv_states[0], type(None)):
cache.conv_states[0] = mx.zeros((B, K-1, C))
x = mx.concatenate([cache.conv_states[0], x], axis=1)
# Process each channel independently
outputs = []
for c in range(C):
# Extract and reshape the channel
x_c = x[:, :, c] # [B, L]
x_c = mx.expand_dims(x_c, axis=1) # [B, 1, L]
# Get weight for this channel - already in correct shape [1, 1, K]
w_c = mx.expand_dims(self.weight[c], axis=0) # Ensure [1, 1, K]
# Apply convolution
y_c = mx.conv_general(
x_c,
w_c,
stride=1,
padding=self.padding
)
if self.bias is not None:
y_c = y_c + self.bias[c]
outputs.append(mx.squeeze(y_c, axis=1))
y = mx.stack(outputs, axis=-1)
# Update cache
if cache is not None:
cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
return y
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
# Model dimensions
self.hidden_size = args.hidden_size
self.num_heads = args.num_heads
self.head_dim = args.head_dim
self.ssm_state_size = args.state_size
self.n_groups = args.n_groups
self.intermediate_size = int(args.expand * args.hidden_size)
# Convolution parameters
self.conv_kernel = args.conv_kernel
self.use_conv_bias = args.use_conv_bias
# Time step parameters
self.time_step_rank = int(args.time_step_rank)
self.time_step_min = args.time_step_min
self.time_step_max = args.time_step_max
# Processing parameters
self.args = args
self.chunk_size = args.chunk_size
self.layer_norm_epsilon = args.layer_norm_epsilon
# Calculate dimensions
self.conv_dim = (self.intermediate_size +
2 * self.n_groups * self.ssm_state_size)
projection_size = (self.intermediate_size +
self.conv_dim +
self.num_heads)
# Initialize layers
self.in_proj = nn.Linear(
self.hidden_size,
projection_size,
bias=args.use_bias
)
self.conv1d = nn.Conv1d(
d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
self.conv_dim = args.intermediate_size + 2 * args.state_size
self.conv1d = DepthWiseConv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
kernel_size=self.conv_kernel,
kernel_size=args.conv_kernel,
groups=self.conv_dim,
padding=self.conv_kernel - 1,
bias=self.use_conv_bias
bias=args.use_conv_bias,
padding=args.conv_kernel - 1
)
# Initialize parameters
self.dt_bias = mx.ones(self.num_heads)
A = mx.arange(1, self.num_heads + 1)
self.A_log = mx.log(A)
self.D = mx.ones(self.num_heads)
# Output layers
self.norm = MambaRMSNormGated(
self.intermediate_size,
eps=self.layer_norm_epsilon
)
self.out_proj = nn.Linear(
self.intermediate_size,
self.hidden_size,
bias=args.use_bias
)
self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
def reshape_into_chunks(self, tensor, pad_size, chunk_size):
if pad_size > 0:
pad_shape = list(tensor.shape)
pad_shape[1] = pad_size
padding = mx.zeros(pad_shape, dtype=tensor.dtype)
tensor = mx.concatenate([tensor, padding], axis=1)
chunk_shape = list(tensor.shape)
chunk_shape[1] = -1
chunk_shape.insert(2, chunk_size)
return tensor.reshape(chunk_shape)
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
def segment_sum(self, x):
return mx.cumsum(x, axis=-1)
if args.rescale_prenorm_residual:
layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
self.out_proj.weight = self.out_proj.weight * layer_scale
def process_single_token(self, hidden_states, B, C, dt, cache):
batch_size = hidden_states.shape[0]
# Process convolution state
if cache is not None:
conv_state = cache.conv_states
# Roll the conv state and update the last position
conv_state = mx.roll(conv_state, shift=-1, axis=-1)
# Create new conv state with updated last position
new_conv_state = mx.array(conv_state)
new_conv_state = new_conv_state.at[:, :, -1].add(hidden_states)
conv_state = new_conv_state
# Compute convolution
conv_out = mx.sum(conv_state * self.conv1d.weight[:, 0, :], axis=-1)
if self.use_conv_bias:
conv_out = conv_out + self.conv1d.bias
# Apply SiLU activation
conv_out = mx.sigmoid(conv_out) * conv_out
else:
# Initialize new cache
conv_state = mx.zeros((batch_size, self.conv_dim, self.conv_kernel - 1))
conv_out = self.conv1d(hidden_states)
conv_out = mx.sigmoid(conv_out) * conv_out
# Process SSM
def __call__(self, u: mx.array, cache=None):
# Expect input shape: [batch_size, 1, hidden_size]
batch_size, seq_len, _ = u.shape
pad_size = self.chunk_size - (seq_len % self.chunk_size)
# Initialize states if needed
if cache.conv_states[0] is None:
cache.conv_states[0] = mx.zeros((
batch_size,
self.args.conv_kernel - 1,
self.conv_dim
))
if cache.ssm_states[0] is None:
cache.ssm_states[0] = mx.zeros((
batch_size,
self.args.num_heads,
self.args.head_dim,
self.args.state_size
))
# Project input
zxbcdt = self.in_proj(u)
# Split projections
z = zxbcdt[:, :, :self.args.intermediate_size]
xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
dt = zxbcdt[:, :, -(self.args.num_heads):]
# Process delta time
dt = mx.reshape(dt, (batch_size, seq_len, self.args.num_heads))
dt = mx.squeeze(dt, axis=0) # Remove sequence dimension for single token
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.time_step_min,
self.time_step_max
self.args.time_step_min,
self.args.time_step_max
)
A = -mx.exp(self.A_log)
dA = mx.exp(dt * A[None, :])
if cache is not None:
ssm_state = cache.ssm_states
else:
ssm_state = mx.zeros(
(batch_size, self.num_heads, self.head_dim, self.ssm_state_size)
)
# Compute SSM updates
dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, hidden_states)
next_state = ssm_state * dA[:, :, None, None] + dBx
y = mx.einsum('bhds,bhs->bhd', next_state, C)
# Add skip connection
y = y + hidden_states * self.D[None, :, None]
return y, conv_state, next_state
dt = mx.maximum(dt, self.args.time_step_floor)
def process_long_sequence(self, hidden_states, B, C, dt, ssm_state):
batch_size, seq_len = hidden_states.shape[:2]
pad_size = self.chunk_size - (seq_len % self.chunk_size)
# Reshape into chunks
x_chunks = self.reshape_into_chunks(hidden_states, pad_size, self.chunk_size)
B_chunks = self.reshape_into_chunks(B, pad_size, self.chunk_size)
C_chunks = self.reshape_into_chunks(C, pad_size, self.chunk_size)
# Process time steps
dt = nn.softplus(dt + self.dt_bias)
dt = mx.clip(dt, self.time_step_min)
# Prepare matrices
# Convolution step
xBC = self.conv1d(xBC, cache=cache)
xBC = silu(xBC)
# Split conv output
x = xBC[:, :, :self.args.intermediate_size]
B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
C = xBC[:, :, -self.args.state_size:]
# Reshape for SSM
x = mx.reshape(x, (batch_size, 1, self.args.num_heads, self.args.head_dim))
x = mx.squeeze(x, axis=1)
B = mx.reshape(B, (batch_size, 1, self.args.state_size))
B = mx.broadcast_to(B, (batch_size, self.args.num_heads, self.args.state_size))
B = mx.expand_dims(B, axis=2)
C = mx.reshape(C, (batch_size, 1, self.args.state_size))
C = mx.broadcast_to(C, (batch_size, self.args.num_heads, self.args.state_size))
C = mx.expand_dims(C, axis=3)
# SSM state update
A = -mx.exp(self.A_log)
A = A * dt[:, None]
# Process chunks
A_chunks = self.reshape_into_chunks(
mx.broadcast_to(A, (batch_size, seq_len + pad_size, self.num_heads)),
pad_size,
self.chunk_size
)
# Compute cumulative sums
A_cumsum = mx.cumsum(A_chunks, axis=-1)
L = mx.exp(self.segment_sum(A_chunks))
# Process diagonal blocks
G = mx.einsum('...lhn,...shn->...lsh', C_chunks, B_chunks)
M = G * L[..., None, :]
Y_diag = mx.einsum('...lsh,...sh->...lh', M, x_chunks)
# Process off-diagonal blocks
decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum)
B_decay = B_chunks * decay_states[..., None]
states = mx.einsum('...shn,...sh->...hn', B_decay, x_chunks)
# Combine results
y = Y_diag + states
# Remove padding if necessary
dA = mx.exp(dt * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
x = mx.expand_dims(x, axis=3)
dBx = mx.matmul(x, B)
cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx
# Output computation
y = mx.matmul(cache.ssm_states[0], C)
y = mx.squeeze(y, axis=-1)
# y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
if pad_size > 0:
y = y[:, :seq_len]
return y, ssm_state
y = y[:, :seq_len, :, :]
# Final reshape and projections
y = mx.reshape(y, (batch_size, 1, self.args.num_heads * self.args.head_dim))
y = self.norm(y + z)
def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array:
batch_size, seq_len, _ = x.shape
# Project input
projected_states = self.in_proj(x.squeeze(1))
# Calculate d_mlp based on projection size
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 *
self.n_groups * self.ssm_state_size - self.num_heads) // 2
# Split projections with corrected dimensions
splits = [
d_mlp, # z0
d_mlp, # x0
self.intermediate_size, # gate
self.conv_dim, # hidden_states
self.num_heads # dt
]
z0, x0, x1, gate, hidden_states, dt = projected_states.split(splits, axis=-1)
# Split hidden states into components
x_conv, BC = mx.split(hidden_states, [self.intermediate_size], axis=-1)
B, C = mx.split(BC, [self.n_groups * self.ssm_state_size], axis=-1)
# Process based on sequence length
if seq_len > 1 and cache is None:
y, next_state = self.process_long_sequence(
x_conv, B, C, dt,
mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size))
)
else:
# Reshape for single token processing
x_conv = x_conv.reshape(batch_size, -1, self.head_dim)
B = B.reshape(batch_size, self.num_heads, -1)
C = C.reshape(batch_size, self.num_heads, -1)
y, conv_state, next_state = self.process_single_token(x_conv, B, C, dt, cache)
if cache is not None:
cache.update(conv_state, next_state)
# Apply normalization and final projection
y = self.norm(y) * gate
return self.out_proj(y)
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = Mamba2Mixer(args)
self.residual_in_fp32 = args.residual_in_fp32
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array:
def __call__(self, x: mx.array, cache):
if self.residual_in_fp32:
x = x.astype(mx.float32)
return self.mixer(self.norm(x), cache) + x
class Mamba2Model(nn.Module):
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
@@ -306,26 +289,27 @@ class Mamba2Model(nn.Module):
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache=None) -> mx.array:
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, layer_cache in zip(self.layers, cache):
x = layer(x, layer_cache)
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.backbone = Mamba2Model(args)
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None) -> mx.array:
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
@@ -336,26 +320,24 @@ class Model(nn.Module):
logits = self.lm_head(x)
return logits
def make_cache(self, batch_size=1):
return [
Mamba2Cache(
batch_size=batch_size,
conv_dim=self.args.intermediate_size + 2 * self.args.n_groups * self.args.state_size,
kernel_size=self.args.conv_kernel,
num_heads=self.args.num_heads,
head_dim=self.args.head_dim,
state_size=self.args.state_size
)
for _ in range(len(self.backbone.layers))
]
return [Mamba2Cache(batch_size, self.args.conv_kernel) for _ in range(len(self.layers))]
def sanitize(self, weights):
sanitized = {}
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
weights[k] = v.moveaxis(2, 1)
return weights
if "conv1d.weight" in k:
# Ensure weights are in correct shape (channels, 1, kernel_size)
if v.ndim == 2:
v = mx.expand_dims(v, axis=1)
elif v.ndim == 1:
v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0)
sanitized[k] = v
else:
sanitized[k] = v
return sanitized
@property
def layers(self):
return self.backbone.layers

View File

@@ -0,0 +1,316 @@
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import Mamba2Cache
@dataclass
class ModelArgs(BaseModelArgs):
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
use_cache: bool = True
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
def silu(x):
return x * mx.sigmoid(x)
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
# Fuse operations where possible
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
# Compute variance in fp32 for better numerical stability
hidden_states_fp32 = hidden_states.astype(mx.float32)
variance = mx.mean(hidden_states_fp32 * hidden_states_fp32, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
def ssd_optimized(x, A, B, C, chunk_size):
batch, seqlen, nheads, dim = x.shape
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)
output = mx.zeros((batch, seqlen, nheads, dim))
state = mx.zeros((batch, nheads, dim, B.shape[-1]))
for i in range(0, seqlen, chunk_size):
chunk = slice(i, min(i + chunk_size, seqlen))
chunk_size_actual = min(chunk_size, seqlen - i)
dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
x_chunk = mx.transpose(x[:, chunk], [0, 2, 3, 1])
dBx = mx.matmul(x_chunk, B[:, chunk])
state = state * mx.expand_dims(dA, axis=-1) + dBx
y = mx.matmul(state, mx.transpose(C[:, chunk], [0, 2, 1]))
output[:, i:i+chunk_size_actual] = mx.transpose(y, [0, 3, 1, 2])
return output, state
def update_conv_cache(x: mx.array, cache, kernel_size: int) -> Tuple[mx.array, mx.array]:
"""Update convolution cache for sequential processing."""
B, L, C = x.shape
if cache is None:
# Initialize cache with zeros
cache = mx.zeros((B, kernel_size - 1, C))
# Concatenate cache with current input
x_with_cache = mx.concatenate([cache, x], axis=1)
# Update cache with the last (kernel_size - 1) elements
new_cache = x_with_cache[:, -kernel_size+1:] if x_with_cache.shape[1] >= kernel_size else x_with_cache
return x_with_cache, new_cache
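# Illustrative usage of the helper above (sizes made up for the sketch): each
# decoding step passes one new token plus the previous cache and gets back the
# token prefixed by the last kernel_size - 1 inputs, together with the refreshed
# cache, e.g.
#
#   x_t = mx.zeros((1, 1, 8))                                    # one token, 8 channels
#   padded, cache = update_conv_cache(x_t, None, kernel_size=4)  # padded: (1, 4, 8)
#   padded, cache = update_conv_cache(x_t, cache, kernel_size=4)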
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.intermediate_size = int(args.expand * args.hidden_size)
self.state_size = args.state_size
self.num_heads = args.num_heads
self.head_dim = args.head_dim
self.ssm_state_size = args.state_size
self.n_groups = args.n_groups
self.conv_kernel = args.conv_kernel
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
self.in_proj = nn.Linear(args.hidden_size, projection_size, bias=args.use_bias)
# Using built-in Conv1d instead of custom DepthwiseConv1d
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
kernel_size=args.conv_kernel,
groups=self.conv_dim, # For depthwise convolution
padding=0, # We'll handle padding manually with the cache
bias=args.use_conv_bias
)
self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
if args.rescale_prenorm_residual:
layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
self.out_proj.weight = self.out_proj.weight * layer_scale
def __call__(self, u: mx.array, cache):
batch_size, seq_len, _ = u.shape
projected = self.in_proj(u)
d_conv = self.conv_dim
z = projected[..., :self.intermediate_size]
xBC = projected[..., self.intermediate_size:self.intermediate_size + d_conv]
dt = projected[..., -self.num_heads:]
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.args.time_step_min,
self.args.time_step_max
)
dt = mx.maximum(dt, self.args.time_step_floor)
# Handle convolution with separate cache update
if cache is not None:
# Update cache and get padded input
xBC_padded, new_cache = update_conv_cache(xBC, cache.conv_states, self.conv_kernel)
cache.conv_states = new_cache
# Prepare input for conv1d: [B, L, C] -> [B, C, L]
xBC_conv = mx.transpose(xBC_padded, [0, 2, 1])
# Apply convolution
xBC = self.conv1d(xBC_conv)
# Transform back: [B, C, L] -> [B, L, C]
xBC = mx.transpose(xBC, [0, 2, 1])
# Take only the relevant part corresponding to input length
xBC = xBC[:, :seq_len]
else:
# For training, use regular convolution with padding
xBC = mx.transpose(xBC, [0, 2, 1])
xBC = self.conv1d(xBC)
xBC = mx.transpose(xBC, [0, 2, 1])
xBC = silu(xBC)
x = xBC[..., :self.intermediate_size]
BC = xBC[..., self.intermediate_size:]
B = BC[..., :self.state_size]
C = BC[..., self.state_size:]
x = mx.reshape(x, (-1, seq_len, self.num_heads, self.intermediate_size // self.num_heads))
A = -mx.exp(self.A_log)
D_expanded = mx.expand_dims(self.D, -1)
if cache is not None and cache.ssm_state is None:
cache.ssm_state = mx.zeros((
batch_size,
self.num_heads,
self.intermediate_size // self.num_heads,
self.state_size
))
if cache is not None:
output = mx.zeros((batch_size, seq_len, self.args.hidden_size))
for pos in range(seq_len):
x_t = x[:, pos:pos+1]
dA = mx.exp(dt[:, pos:pos+1] * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
x_expanded = mx.expand_dims(x_t, axis=3)
dBx = mx.matmul(x_expanded, mx.expand_dims(B[:, pos:pos+1], axis=2))
cache.ssm_state = cache.ssm_state * dA + dBx
y = mx.matmul(cache.ssm_state, mx.expand_dims(C[:, pos:pos+1], axis=3))
y = mx.squeeze(y, axis=-1)
y = y + x_t * D_expanded
y = mx.reshape(y, (batch_size, 1, -1))
y = self.norm(y + z[:, pos:pos+1])
y = self.out_proj(y)
if self.args.residual_in_fp32:
y = y.astype(mx.float32)
output = output.at[:, pos:pos+1].set(y)
else:
y, ssm_state = ssd_optimized(
x * mx.expand_dims(dt, -1),
-mx.exp(self.A_log) * dt,
B, C,
self.args.chunk_size
)
y = mx.reshape(
y + x * mx.expand_dims(self.D, -1),
(batch_size, seq_len, -1)
)
y = self.norm(y + z)
output = self.out_proj(y)
if self.args.residual_in_fp32:
output = output.astype(mx.float32)
return output
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
return self.mixer(self.norm(x), cache) + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def make_cache(self, batch_size=1):
return [Mamba2Cache() for _ in range(len(self.layers))]
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
weights[k] = v.moveaxis(2, 1)
return weights
@property
def layers(self):
return self.backbone.layers

View File

@@ -0,0 +1,357 @@
import math
from dataclasses import dataclass, field
from typing import Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import Mamba2Cache
@dataclass
class ModelArgs(BaseModelArgs):
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
use_cache: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
intermediate_size: int = None
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
self.intermediate_size = int(self.expand * self.hidden_size) # E*D = ED
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaRMSNormGated(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = mx.ones((hidden_size,))
        self.variance_epsilon = eps
    def __call__(self, hidden_states, gate=None):
        # Gated RMS norm, computed in float32 for numerical stability
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.astype(mx.float32)
        if gate is not None:
            hidden_states = hidden_states * nn.silu(gate.astype(mx.float32))
        variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
        hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).astype(input_dtype)
class Mamba2Mixer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
# Model dimensions
self.hidden_size = args.hidden_size
self.num_heads = args.num_heads
self.head_dim = args.head_dim
self.ssm_state_size = args.state_size
self.n_groups = args.n_groups
self.intermediate_size = int(args.expand * args.hidden_size)
# Convolution parameters
self.conv_kernel = args.conv_kernel
self.use_conv_bias = args.use_conv_bias
# Time step parameters
self.time_step_rank = int(args.time_step_rank)
self.time_step_min = args.time_step_min
self.time_step_max = args.time_step_max
# Processing parameters
self.chunk_size = args.chunk_size
self.layer_norm_epsilon = args.layer_norm_epsilon
# Calculate dimensions
self.conv_dim = (self.intermediate_size +
2 * self.n_groups * self.ssm_state_size)
projection_size = (self.intermediate_size +
self.conv_dim +
self.num_heads)
# Initialize layers
self.in_proj = nn.Linear(
self.hidden_size,
projection_size,
bias=args.use_bias
)
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
kernel_size=self.conv_kernel,
groups=self.conv_dim,
padding=self.conv_kernel - 1,
bias=self.use_conv_bias
)
# Initialize parameters
self.dt_bias = mx.ones(self.num_heads)
A = mx.arange(1, self.num_heads + 1)
self.A_log = mx.log(A)
self.D = mx.ones(self.num_heads)
# Output layers
self.norm = MambaRMSNormGated(
self.intermediate_size,
eps=self.layer_norm_epsilon
)
self.out_proj = nn.Linear(
self.intermediate_size,
self.hidden_size,
bias=args.use_bias
)
def reshape_into_chunks(self, tensor, pad_size, chunk_size):
if pad_size > 0:
pad_shape = list(tensor.shape)
pad_shape[1] = pad_size
padding = mx.zeros(pad_shape, dtype=tensor.dtype)
tensor = mx.concatenate([tensor, padding], axis=1)
chunk_shape = list(tensor.shape)
chunk_shape[1] = -1
chunk_shape.insert(2, chunk_size)
return tensor.reshape(chunk_shape)
def segment_sum(self, x):
return mx.cumsum(x, axis=-1)
def process_single_token(self, hidden_states, B, C, dt, cache):
batch_size = hidden_states.shape[0]
# Process convolution state
if cache is not None and cache.conv_states is not None:
conv_state = cache.conv_states
# Roll the conv state and update the last position
conv_state = mx.roll(conv_state, shift=-1, axis=-1)
# Create new conv state with updated last position
new_conv_state = mx.array(conv_state)
new_conv_state = new_conv_state.at[:, :, -1].add(hidden_states)
conv_state = new_conv_state
# Compute convolution
conv_out = mx.sum(conv_state * self.conv1d.weight[:, 0, :], axis=-1)
if self.use_conv_bias:
conv_out = conv_out + self.conv1d.bias
# Apply SiLU activation
conv_out = mx.sigmoid(conv_out) * conv_out
else:
# Initialize new cache and process convolution
conv_state = mx.zeros((batch_size, self.conv_dim, self.conv_kernel - 1))
# Reshape hidden_states for conv1d
hidden_states_reshaped = hidden_states.reshape(batch_size, -1, 1)
conv_out = self.conv1d(hidden_states_reshaped)
conv_out = mx.squeeze(conv_out, axis=-1) # Remove the last dimension
conv_out = mx.sigmoid(conv_out) * conv_out
# Process SSM
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.time_step_min,
self.time_step_max
)
A = -mx.exp(self.A_log)
dA = mx.exp(dt[:, None] * A[None, :])
if cache is not None and cache.ssm_states is not None:
ssm_state = cache.ssm_states
else:
ssm_state = mx.zeros(
(batch_size, self.num_heads, self.head_dim, self.ssm_state_size)
)
# Compute SSM updates
dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, hidden_states)
next_state = ssm_state * dA[:, :, None, None] + dBx
y = mx.einsum('bhds,bhs->bhd', next_state, C)
# Add skip connection
y = y + hidden_states * self.D[None, :, None]
return y, conv_state, next_state
def process_long_sequence(self, hidden_states, B, C, dt, ssm_state):
batch_size, seq_len = hidden_states.shape[:2]
pad_size = self.chunk_size - (seq_len % self.chunk_size)
# Reshape into chunks
x_chunks = self.reshape_into_chunks(hidden_states, pad_size, self.chunk_size)
B_chunks = self.reshape_into_chunks(B, pad_size, self.chunk_size)
C_chunks = self.reshape_into_chunks(C, pad_size, self.chunk_size)
# Process time steps
dt = nn.softplus(dt + self.dt_bias)
dt = mx.clip(dt, self.time_step_min)
# Prepare matrices
A = -mx.exp(self.A_log)
A = A * dt[:, None]
# Process chunks
A_chunks = self.reshape_into_chunks(
mx.broadcast_to(A, (batch_size, seq_len + pad_size, self.num_heads)),
pad_size,
self.chunk_size
)
# Compute cumulative sums
A_cumsum = mx.cumsum(A_chunks, axis=-1)
L = mx.exp(self.segment_sum(A_chunks))
# Process diagonal blocks
G = mx.einsum('...lhn,...shn->...lsh', C_chunks, B_chunks)
M = G * L[..., None, :]
Y_diag = mx.einsum('...lsh,...sh->...lh', M, x_chunks)
# Process off-diagonal blocks
decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum)
B_decay = B_chunks * decay_states[..., None]
states = mx.einsum('...shn,...sh->...hn', B_decay, x_chunks)
# Combine results
y = Y_diag + states
# Remove padding if necessary
if pad_size > 0:
y = y[:, :seq_len]
return y, ssm_state
def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array:
batch_size, seq_len, _ = x.shape
# Project input
projected_states = self.in_proj(x)
# Calculate d_mlp based on projection size
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 *
self.n_groups * self.ssm_state_size - self.num_heads) // 2
# Split projections with corrected dimensions
splits = [
d_mlp, # z0
d_mlp, # x0
self.intermediate_size, # gate
self.conv_dim, # hidden_states
self.num_heads # dt
]
z0, x0, x1, gate, hidden_states, dt = projected_states.split(splits, axis=-1)
# Split hidden states into components
x_conv, BC = mx.split(hidden_states, [self.intermediate_size], axis=-1)
B, C = mx.split(BC, [self.n_groups * self.ssm_state_size], axis=-1)
# Process based on sequence length
if seq_len > 1 and cache is None:
y, next_state = self.process_long_sequence(
x_conv, B, C, dt,
mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size))
)
else:
# Reshape for single token processing
x_conv = x_conv.reshape(batch_size, -1, self.head_dim)
B = B.reshape(batch_size, self.num_heads, -1)
C = C.reshape(batch_size, self.num_heads, -1)
y, conv_state, next_state = self.process_single_token(x_conv, B, C, dt, cache)
if cache is not None:
cache.update(conv_state, next_state)
# Apply normalization and final projection
y = self.norm(y) * gate
return self.out_proj(y)
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = Mamba2Mixer(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array:
return self.mixer(self.norm(x), cache) + x
class Mamba2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache=None) -> mx.array:
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, layer_cache in zip(self.layers, cache):
x = layer(x, layer_cache)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.backbone = Mamba2Model(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None) -> mx.array:
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def make_cache(self, batch_size=1):
return [Mamba2Cache() for _ in range(len(self.backbone.layers))]
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
weights[k] = v.moveaxis(2, 1)
return weights
@property
def layers(self):
return self.backbone.layers

View File

@@ -0,0 +1,430 @@
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import Mamba2Cache
@dataclass
class ModelArgs(BaseModelArgs):
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
use_cache: bool = True
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
def pad_tensor_by_size(input_tensor: mx.array, pad_size: int):
"""
    Pads the input tensor with `pad_size` zeros on the seq_len axis (axis=1).
    Assumes the tensor is either 3-D or 4-D.
    """
    if len(input_tensor.shape) == 4:
        pad_width = [(0, 0), (0, pad_size), (0, 0), (0, 0)]
    else:
        pad_width = [(0, 0), (0, pad_size), (0, 0)]
    return mx.pad(input_tensor, pad_width)
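# Illustrative example (not part of the original file): for a hypothetical
# tensor x of shape (2, 10, 8, 64) and pad_size=6, pad_tensor_by_size(x, 6)
# returns an array of shape (2, 16, 8, 64), zero-padded at the end of the
# seq_len axis.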
def reshape_into_chunks(input_tensor, pad_size, chunk_size):
"""
Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
simultaneously splitting it into chunk sequences.
    Assumes the tensor is either 3-D or 4-D.
"""
# [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
input_tensor = pad_tensor_by_size(input_tensor, pad_size)
if len(input_tensor.shape) == 3:
# [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
else:
# [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
return input_tensor.reshape(
input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
)
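# Illustrative example (assumed shapes, continuing the case above): with
# chunk_size=8, reshape_into_chunks(x, 6, 8) first pads (2, 10, 8, 64) to
# (2, 16, 8, 64) and then reshapes it to (2, 2, 8, 8, 64), i.e.
# [bsz, num_chunks, chunk_size, num_heads, head_dim].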
def segment_sum(input_tensor):
"""
More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
"""
    chunk_size = input_tensor.shape[-1]
    # 1. expand the input to an additional trailing dimension:
    #    [..., chunk_size] -> [..., chunk_size, chunk_size]
    input_tensor = mx.broadcast_to(
        input_tensor[..., None], (*input_tensor.shape, chunk_size)
    )
    # 2. zero out elements on and above the diagonal with a strictly lower
    #    triangular mask
    mask = mx.tril(mx.ones((chunk_size, chunk_size)), k=-1).astype(mx.bool_)
    input_tensor = mx.where(mask, input_tensor, 0)
    # 3. compute the actual cumsum
    tensor_segsum = mx.cumsum(input_tensor, axis=-2)
    # 4. keep only the lower triangular part of the cumulative sum (incl. the
    #    diagonal this time); everything above it becomes -inf
    mask = mx.tril(mx.ones((chunk_size, chunk_size)), k=0).astype(mx.bool_)
    tensor_segsum = mx.where(mask, tensor_segsum, -mx.inf)
    return tensor_segsum
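# Sketch of what segment_sum computes (illustrative only): for a 1-D row
# A = [a0, a1, a2] the output S is lower triangular with
#   S[i, j] = a_{j+1} + ... + a_i  for j < i,
#   S[i, i] = 0, and -inf above the diagonal,
# so mx.exp(segment_sum(A)) is the causal decay matrix L used for the
# intra-chunk (diagonal block) computation below.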
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.layer_idx = layer_idx
self.args = args
self.hidden_size = args.hidden_size
self.num_heads = args.num_heads
self.head_dim = args.head_dim
self.state_size = args.state_size
self.n_groups = args.n_groups
self.conv_kernel = args.conv_kernel
self.intermediate_size = int(args.expand * args.hidden_size)
self.time_step_rank = int(args.time_step_rank)
self.time_step_min = args.time_step_min
self.time_step_max = args.time_step_max
self.chunk_size = args.chunk_size
# Convolution dimension includes both intermediate sizes
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
bias=args.use_conv_bias,
kernel_size=args.conv_kernel,
groups=self.conv_dim,
padding=args.conv_kernel - 1
)
# Compute input projection dimension
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
self.in_proj = nn.Linear(args.hidden_size, projection_size, bias=args.use_bias)
self.dt_bias = mx.ones(self.num_heads)
        A = mx.arange(1, self.num_heads + 1, dtype=mx.float32)
self.A_log = mx.log(A)
self.D = mx.ones(self.num_heads)
self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
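    # Worked example of the dimension bookkeeping (hypothetical config, roughly
    # mamba2-130m sized): hidden_size=768, expand=2, n_groups=1, state_size=128,
    # num_heads=24, head_dim=64 gives
    #   intermediate_size = 2 * 768          = 1536
    #   conv_dim          = 1536 + 2*1*128   = 1792
    #   projection_size   = 1536 + 1792 + 24 = 3352
    # so d_mlp below works out to (3352 - 2*1536 - 2*128 - 24) // 2 = 0 and the
    # first two chunks of the input projection are empty.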
def __call__(self, input_states: mx.array, cache):
batch_size, seq_len, _ = input_states.shape
# Gated MLP's linear projection
projected_states = self.in_proj(input_states) # [1, 1, projection_size]
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size -
2 * self.n_groups * self.state_size - self.num_heads) // 2
        # Split projected states (mx.split takes offsets along the last axis,
        # so convert the chunk sizes to cumulative indices)
        sizes = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads]
        offsets = [sum(sizes[: i + 1]) for i in range(len(sizes) - 1)]
        _, _, gate, hidden_states, dt = projected_states.split(offsets, axis=-1)
        # hidden_states shape: [batch, seq_len, conv_dim]
        # Get the SSM state from the cache when one is provided
        ssm_state = cache.ssm_states[self.layer_idx] if cache is not None else None
        if cache is not None and cache.seqlen_offset > 0:
# Handle cached generation case
conv_state = cache.conv_states[self.layer_idx] # [batch, conv_dim, conv_kernel]
            conv_state = mx.roll(conv_state, -1, axis=-1)
            # Handle batched generation - states are copied through
            # Write the newest token into the last position of the conv state
            conv_state[:, :, -1] = hidden_states[:, 0, :]
            cache.conv_states[self.layer_idx] = conv_state
            # Compute the convolution output; MLX conv weights are laid out as
            # [out_channels, kernel_size, in_channels / groups]
            hidden_states = mx.sum(conv_state * self.conv1d.weight[:, :, 0], axis=-1)
if self.args.use_conv_bias:
hidden_states += self.conv1d.bias
hidden_states = nn.silu(hidden_states)[:, None, ...] # [batch, 1, conv_dim] : decoding
else:
            # Handle the prefill / normal forward pass
            # Cache the last `conv_kernel` inputs in channels-first layout
            # [batch, conv_dim, conv_kernel] so single-token decoding can resume
            conv_state = mx.pad(
                hidden_states.transpose(0, 2, 1),
                ((0, 0), (0, 0), (self.conv_kernel - 1, 0)),
            )[:, :, -self.conv_kernel:]
            if cache is not None:
                cache.conv_states[self.layer_idx] = conv_state
            # MLX Conv1d is channels-last ([batch, seq_len, conv_dim]); with
            # padding = conv_kernel - 1 the output is longer than the input,
            # so keep only the first seq_len (causal) positions
            hidden_states = self.conv1d(hidden_states)[:, :seq_len, :]
            hidden_states = nn.silu(hidden_states)
# Split hidden states for SSM computation
        hidden_states, B, C = mx.split(
            hidden_states,
            [self.intermediate_size, self.intermediate_size + self.n_groups * self.state_size],
            axis=-1,
        )
# Compute A matrix
A = -mx.exp(self.A_log)
if cache is not None and cache.seqlen_offset > 0:
# Note: there is no need to pad parameter matrices here, as there is just one new token
# for batched generation
            dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
            dt = mx.broadcast_to(
                dt.transpose(0, 2, 1), (batch_size, self.num_heads, self.head_dim)
            )
            # [num_heads] -> [num_heads, head_dim]
            dt_bias = mx.broadcast_to(self.dt_bias[..., None], (self.num_heads, self.head_dim))
            dt = nn.softplus(dt + dt_bias)
            dt = mx.clip(dt, self.time_step_min, None)  # optionally also clip at self.time_step_max
            A = mx.broadcast_to(A[..., None, None], (self.num_heads, self.head_dim, self.state_size))
# [bsz, num_heads, head_dim, state_size]
dA = mx.exp(dt[..., None] * A)
# Discretize B
# [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
# -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
            B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
            B = mx.broadcast_to(
                B, (batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1])
            )
            B = B.reshape(batch_size, -1, B.shape[-1])
# [bsz, num_heads, head_dim, state_size]
dB = dt[..., None] * B[..., None, :]
# Discretize x into dB
# [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
dBx = dB * hidden_states[..., None]
# State calculation
            cache.ssm_states[self.layer_idx] = (
                cache.ssm_states[self.layer_idx] * dA + dBx
            )
# Subsequent output
# [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
            C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
            C = mx.broadcast_to(
                C, (batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1])
            )
            C = C.reshape(batch_size, -1, C.shape[-1])
# [bsz, num_heads, head_dim]
ssm_states = cache.ssm_states[self.layer_idx] # Shape: [b, h, d, n]
            # Reshape ssm_states to merge the first two dimensions
            ssm_states_reshaped = ssm_states.reshape(
                batch_size * self.num_heads, self.head_dim, self.state_size
            )  # Shape: [b*h, d, n]
            C_reshaped = C.reshape(batch_size * self.num_heads, self.state_size, 1)  # Shape: [b*h, n, 1]
            y = ssm_states_reshaped @ C_reshaped
            y = y.reshape(batch_size, self.num_heads, self.head_dim)
# D skip connection
# [num_heads] -> [num_heads, head_dim]
            D = mx.broadcast_to(self.D[..., None], (self.num_heads, self.head_dim))
y = (y + hidden_states * D)
# [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
y = y.reshape(batch_size, -1)[:, None, ...]
else:
# begin ssd naive implementation without einsums
            dt = nn.softplus(dt + self.dt_bias)
            dt = mx.clip(dt, self.time_step_min, None)
            hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim)
            B = B.reshape(batch_size, seq_len, -1, self.state_size)
            C = C.reshape(batch_size, seq_len, -1, self.state_size)
            B = mx.tile(B, (1, 1, self.num_heads // self.n_groups, 1))
            C = mx.tile(C, (1, 1, self.num_heads // self.n_groups, 1))
            pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
# Discretize x and A
hidden_states = hidden_states * dt[..., None]
A = A * dt
# Rearrange into blocks/chunks
hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
# [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
            A = A.transpose(0, 3, 1, 2)
            A_cumsum = mx.cumsum(A, axis=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
# This is the analog of a causal mask
L = mx.exp(segment_sum(A))
# First, contraction of C and B to get G (attention-weights like)
            G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]  # shape: (b, c, l, s, h, n)
            G = G_intermediate.sum(axis=-1)  # shape: (b, c, l, s, h)
            # Step 2: Compute M, equivalent to applying attention mask to weights
            M_intermediate = G[..., None] * L.transpose(0, 2, 3, 4, 1)[..., None]
            M = M_intermediate.sum(axis=-1)
            # Step 3: Compute Y_diag (apply to values)
            Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(axis=3)
            # (right term of low-rank factorization of off-diagonal blocks; B terms)
            decay_states = mx.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
            B_decay_contraction = B * decay_states.transpose(0, 2, 3, 1)[..., None]
            # transpose back B * decay states
            states = (
                B_decay_contraction.transpose(0, 1, 3, 2, 4)[..., None]
                * hidden_states.transpose(0, 1, 3, 2, 4)[..., None, :]
            ).sum(axis=3).transpose(0, 1, 2, 4, 3)
if cache is not None and cache.seqlen_offset > 0:
previous_states = cache.ssm_states[self.layer_idx][:, None, ...]
else:
previous_states = mx.zeros_like(states[:, :1])
            states = mx.concatenate([previous_states, states], axis=1)
            decay_chunk = mx.exp(
                segment_sum(mx.pad(A_cumsum[:, :, :, -1], ((0, 0), (0, 0), (1, 0))))
            )
            states_permuted = states.transpose(0, 2, 1, 3, 4)
            result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(axis=2)
            new_states = result.transpose(0, 2, 1, 3, 4)
states, ssm_state = new_states[:, :-1], new_states[:, -1]
# Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = mx.exp(A_cumsum)
# compute Yoff
C_times_states = (C[..., None, :] * states[:, :, None, ...])
            state_decay_out_permuted = state_decay_out.transpose(0, 2, 3, 1)
Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
y = Y_diag + Y_off
# [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
y = y + D_residual
# Cutting off padded chunks
if pad_size > 0:
y = y[:, :seq_len, :, :]
y = y.reshape(batch_size, seq_len, -1)
if ssm_state is not None and cache is not None:
cache.ssm_states[self.layer_idx] = ssm_state
scan_output = self.norm(y, gate)
# end ssd naive
# 4. Final linear projection
return self.out_proj(scan_output) # [batch, seq_len, hidden_size]
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.residual_in_fp32 = args.residual_in_fp32
self.mixer = Mamba2Block(args, layer_idx)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
return self.mixer(self.norm(x), cache) + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args, idx) for idx in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def make_cache(self, batch_size=1):
return [Mamba2Cache(
batch_size,
self.args.intermediate_size,
self.args.conv_kernel,
self.args.head_dim,
self.args.num_heads,
self.args.n_groups,
self.args.state_size
) for _ in range(len(self.layers))]
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
weights[k] = v.moveaxis(2, 1)
return weights
@property
def layers(self):
return self.backbone.layers
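# Hedged usage sketch (not part of this file): assuming an mlx-lm style
# `load()` helper and a compatible checkpoint path (both hypothetical here),
# token-by-token generation with the per-layer Mamba2Cache would look roughly
# like this:
#
#   model, tokenizer = load("path/to/mamba2-checkpoint")
#   cache = model.make_cache(batch_size=1)
#   tokens = mx.array([tokenizer.encode("Hello")])
#   logits = model(tokens, cache)                         # prefill
#   next_token = mx.argmax(logits[:, -1, :], axis=-1)
#   for _ in range(16):                                   # decode loop
#       logits = model(next_token[:, None], cache)
#       next_token = mx.argmax(logits[:, -1, :], axis=-1)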