removing some files

Goekdeniz-Guelmez 2024-11-21 22:05:42 +01:00
parent e22b2dbf27
commit 117ffd3909
6 changed files with 0 additions and 2948 deletions


@@ -1,437 +0,0 @@
"""
mamba2-minimal
==============
A minimal, single-file implementation of the Mamba-2 model in PyTorch.
> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality**
> Authors: Tri Dao, Albert Gu
> Paper: https://arxiv.org/abs/2405.21060
"""
import json
from dataclasses import dataclass
from typing import Iterable, NamedTuple, TypeAlias, cast
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import LongTensor, Tensor, nn
Device: TypeAlias = str | torch.device | None
@dataclass
class Mamba2Config:
d_model: int # model dimension (D)
n_layer: int = 24 # number of Mamba-2 layers in the language model
d_state: int = 128 # state dimension (N)
d_conv: int = 4 # convolution kernel size
expand: int = 2 # expansion factor (E)
headdim: int = 64 # head dimension (P)
chunk_size: int = 64 # matrix partition size (Q)
vocab_size: int = 50277
pad_vocab_size_multiple: int = 16
def __post_init__(self):
self.d_inner = self.expand * self.d_model
assert self.d_inner % self.headdim == 0
self.nheads = self.d_inner // self.headdim
if self.vocab_size % self.pad_vocab_size_multiple != 0:
self.vocab_size += (
self.pad_vocab_size_multiple
- self.vocab_size % self.pad_vocab_size_multiple
)
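# Worked example (hypothetical values): with d_model=768, expand=2 and headdim=64,
# __post_init__ gives d_inner = 2 * 768 = 1536 and nheads = 1536 // 64 = 24; the
# default vocab_size=50277 is padded up to the next multiple of 16, i.e. 50288.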
class InferenceCache(NamedTuple):
conv_state: Tensor # (batch, d_inner + 2 * d_state, d_conv)
ssm_state: Tensor # (batch, nheads, headdim, d_state)
@staticmethod
def alloc(batch_size: int, args: Mamba2Config, device: Device = None):
return InferenceCache(
torch.zeros(
batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device
),
torch.zeros(
batch_size, args.nheads, args.headdim, args.d_state, device=device
),
)
class Mamba2LMHeadModel(nn.Module):
def __init__(self, args: Mamba2Config, device: Device = None):
super().__init__()
self.args = args
self.device = device
self.backbone = nn.ModuleDict(
dict(
embedding=nn.Embedding(args.vocab_size, args.d_model, device=device),
layers=nn.ModuleList(
[
nn.ModuleDict(
dict(
mixer=Mamba2(args, device=device),
norm=RMSNorm(args.d_model, device=device),
)
)
for _ in range(args.n_layer)
]
),
norm_f=RMSNorm(args.d_model, device=device),
)
)
self.lm_head = nn.Linear(
args.d_model, args.vocab_size, bias=False, device=device
)
self.lm_head.weight = self.backbone.embedding.weight
@staticmethod
def from_pretrained(huggingface_model_id: str, device: Device = None):
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils.hub import cached_file
config_path = cached_file(huggingface_model_id, CONFIG_NAME)
assert config_path, "Failed to get huggingface config file"
state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME)
assert state_dict_path, "Failed to get huggingface state dict file"
config = json.load(open(config_path))
args = Mamba2Config(
d_model=config["d_model"],
n_layer=config["n_layer"],
vocab_size=config["vocab_size"],
pad_vocab_size_multiple=config["pad_vocab_size_multiple"],
)
map_location = "cpu" if device is None else device
state_dict = torch.load(
state_dict_path, weights_only=True, map_location=map_location, mmap=True
)
model = Mamba2LMHeadModel(args, device=device)
model.load_state_dict(state_dict)
model.eval()
return model
def forward(
self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None
) -> tuple[LongTensor, list[InferenceCache]]:
"""
Arguments
input_ids: (batch, seqlen) tokens from `EleutherAI/gpt-neox-20b` tokenizer
h: hidden states for inference step. If present, the constant-time
(w.r.t. sequence length) inference path is taken; input_ids should
then have shape (batch, 1), containing the next token.
Return (logits, h)
logits: (batch, seqlen, vocab_size)
h: updated inference cache after processing `input_ids`
"""
seqlen = input_ids.shape[1]
if h is None:
h = [None for _ in range(self.args.n_layer)]
x = self.backbone.embedding(input_ids)
for i, layer in enumerate(self.backbone.layers):
y, h[i] = layer.mixer(layer.norm(x), h[i])
x = y + x
x = self.backbone.norm_f(x)
logits = self.lm_head(x)
return logits[:, :seqlen], cast(list[InferenceCache], h)
def generate(
self,
input_ids: LongTensor,
max_new_length: int = 20,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 1.0,
eos_token_id: int = 0,
) -> Iterable[tuple[int, list[InferenceCache]]]:
prefix, tokens = input_ids[:-1], input_ids[-1:].unsqueeze(0)
# Process prompt
# The input sequence to forward (non-inference path) must have a length that is a multiple of chunk_size.
# We split off the excess tokens so that the first n_chunked tokens can be processed in one forward call
# and the rest are processed in individual inference steps.
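# Example: with chunk_size=64 and a 151-token prompt, prefix has 150 tokens;
# the first n_chunked = 128 go through a single forward call and the remaining
# 22 are fed one at a time through the inference path below.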
n_chunked = (prefix.shape[0] // self.args.chunk_size) * self.args.chunk_size
if n_chunked > 0:
_, h = self(prefix[:n_chunked].unsqueeze(0), None)
else:
h = [
InferenceCache.alloc(1, self.args, device=self.device)
for _ in range(self.args.n_layer)
]
for i in range(n_chunked, prefix.shape[0]):
_, h = self(prefix[i : i + 1].unsqueeze(0), h)
# Generate
for _ in range(max_new_length):
with torch.no_grad():
out, h = self(tokens, h)
logits = out[0, -1]
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, k=top_k)[0][-1]
logits[indices_to_remove] = -torch.inf
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
sorted_indices_to_remove[0] = False
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = -torch.inf
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
if next_token.item() == eos_token_id:
return
tokens = next_token.unsqueeze(0)
yield cast(int, next_token.item()), h
class Mamba2(nn.Module):
def __init__(self, args: Mamba2Config, device: Device = None):
super().__init__()
self.args = args
self.device = device
# Order: (z, x, B, C, dt)
d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads
self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device)
conv_dim = args.d_inner + 2 * args.d_state
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
kernel_size=args.d_conv,
groups=conv_dim,
padding=args.d_conv - 1,
device=device,
)
self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
self.D = nn.Parameter(torch.empty(args.nheads, device=device))
self.norm = RMSNorm(args.d_inner, device=device)
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
def forward(self, u: Tensor, h: InferenceCache | None = None):
"""
Arguments
u: (batch, seqlen, d_model) input. seqlen should be a multiple of chunk_size.
h: hidden states for inference step. Initialized to 0s if not present.
Return (y, h)
y: (batch, seqlen, d_model) output
h: updated inference cache after processing `u`
"""
if h:
return self.step(u, h)
A = -torch.exp(self.A_log) # (nheads,)
zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads)
# Pad or truncate xBC seqlen to d_conv
conv_state = F.pad(
rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0)
)
xBC = silu(
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :]
) # (batch, seqlen, d_inner + 2 * d_state)
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)
y, ssm_state = ssd(
x * dt.unsqueeze(-1),
A * dt,
rearrange(B, "b l n -> b l 1 n"),
rearrange(C, "b l n -> b l 1 n"),
self.args.chunk_size,
device=self.device,
)
y = y + x * self.D.unsqueeze(-1)
y = rearrange(y, "b l h p -> b l (h p)")
y = self.norm(y, z)
y = self.out_proj(y)
h = InferenceCache(conv_state, ssm_state)
return y, h
def step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]:
"""Take a single inference step for the current input and hidden state
Unlike attention-based models, RNN-based models (e.g. Mamba) do not need
to look back at all the past tokens to generate a new token. Instead, a
hidden state (initialized to zeros) is updated for each input and
passed to the next inference step. This means that the total inference
time is linear with respect to the sequence length, instead of quadratic
as in attention's case.
Arguments
u: (batch, 1, d_model)
h: initial/running hidden state
Return (y, h)
y: (batch, 1, d_model)
h: updated hidden state
"""
assert u.shape[1] == 1, "Only one token can be decoded per inference step"
zxbcdt = self.in_proj(u.squeeze(1)) # (batch, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
# Advance convolution input
h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1))
h.conv_state[:, :, -1] = xBC
# Convolution step
xBC = torch.sum(
h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
)
xBC += self.conv1d.bias
xBC = silu(xBC)
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
A = -torch.exp(self.A_log) # (nheads,)
# SSM step
dt = F.softplus(dt + self.dt_bias) # (batch, nheads)
dA = torch.exp(dt * A) # (batch, nheads)
x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)
dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)
y = y + rearrange(self.D, "h -> h 1") * x
y = rearrange(y, "b h p -> b (h p)")
y = self.norm(y, z)
y = self.out_proj(y)
return y.unsqueeze(1), h
def segsum(x: Tensor, device: Device = None) -> Tensor:
"""Stable segment sum calculation.
`exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.
Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32
"""
T = x.size(-1)
x = repeat(x, "... d -> ... d e", e=T)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
x = x.masked_fill(~mask, 0)
x_segsum = torch.cumsum(x, dim=-2)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
"""Structed State Space Duality (SSD) - the core of Mamba-2
This is almost the exact same minimal SSD code from the blog post.
Arguments
x: (batch, seqlen, n_heads, d_head)
A: (batch, seqlen, n_heads)
B: (batch, seqlen, n_heads, d_state)
C: (batch, seqlen, n_heads, d_state)
Return
y: (batch, seqlen, n_heads, d_head)
Source
1. https://tridao.me/blog/2024/mamba2-part3-algorithm/
2. https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78
"""
assert x.shape[1] % chunk_size == 0
# Rearrange into chunks
# Steps 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel).
# This is not implemented and left as an exercise for the reader 😜
x, A, B, C = [
rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)
]
A = rearrange(A, "b c l h -> b h c l")
A_cumsum = torch.cumsum(A, dim=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
L = torch.exp(segsum(A, device=device))
Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))
new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
return Y, final_state
class RMSNorm(nn.Module):
def __init__(self, d: int, eps: float = 1e-5, device: Device = None):
"""Gated Root Mean Square Layer Normalization
Paper: https://arxiv.org/abs/1910.07467
"""
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d, device=device))
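# The forward pass below computes, with optional gate z,
#   y = weight * (x * silu(z)) / sqrt(mean((x * silu(z))^2, dim=-1) + eps)
# i.e. the gated RMSNorm applied to the SSD output in Mamba-2.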
def forward(self, x, z=None):
if z is not None:
x = x * silu(z)
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
def silu(x):
"""Applies the Sigmoid Linear Unit (SiLU), element-wise.
Define this manually since torch's version doesn't seem to work on MPS.
"""
return x * F.sigmoid(x)
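# --------------------------------------------------------------------------
# Usage sketch (not part of the reference implementation above): a minimal,
# hedged example of driving the model end to end. It assumes the Hugging Face
# checkpoint "state-spaces/mamba2-130m" and the EleutherAI/gpt-neox-20b
# tokenizer mentioned in the forward() docstring; substitute whatever
# checkpoint and tokenizer you actually use.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    model = Mamba2LMHeadModel.from_pretrained("state-spaces/mamba2-130m")

    prompt = "Mamba is a new state space model architecture"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids[0]

    print(prompt, end="", flush=True)
    for token_id, _h in model.generate(input_ids, max_new_length=40):
        print(tokenizer.decode([token_id]), end="", flush=True)
    print()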

File diff suppressed because it is too large


@@ -1,300 +0,0 @@
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
use_cache: bool = True
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
def silu(x):
return x * mx.sigmoid(x)
def ssd(x, A, B, C, chunk_size):
batch, seqlen, nheads, dim = x.shape
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)
state = mx.zeros((batch, nheads, dim, B.shape[-1]))
outputs = []
for i in range(0, seqlen, chunk_size):
chunk = slice(i, min(i + chunk_size, seqlen))
dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim]
x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size]
B_chunk = B[:, chunk] # [batch, chunk_size, state_size]
dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size]
state = state * mx.expand_dims(dA, axis=-1) + dBx
C_chunk = C[:, chunk] # [batch, chunk_size, state_size]
y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size]
y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim]
outputs.append(y)
return mx.concatenate(outputs, axis=1), state
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.groups = groups if groups is not None else in_channels
assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
self.weight = mx.random.normal((in_channels, 1, kernel_size))
self.bias = mx.zeros((out_channels,)) if bias else None
def __call__(self, x: mx.array, cache=None) -> mx.array:
B, L, C = x.shape
K = self.kernel_size
assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
if cache is not None:
# Access conv_state directly from cache[0]
if cache[0] is None:
cache[0] = mx.zeros((B, K-1, C))
x = mx.concatenate([cache[0], x], axis=1)
outputs = []
for c in range(C):
x_c = x[:, :, c]
x_c = mx.expand_dims(x_c, axis=1)
w_c = self.weight[c]
if w_c.ndim == 2:
w_c = mx.expand_dims(w_c, axis=0)
elif w_c.ndim == 1:
w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
y_c = mx.conv_general(
x_c,
w_c,
stride=1,
padding=0
)
if self.bias is not None:
y_c = y_c + self.bias[c]
y_c = mx.squeeze(y_c, axis=1)
outputs.append(y_c)
y = mx.stack(outputs, axis=-1)
# Update cache directly using cache[0]
if cache is not None:
cache[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
return y
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
conv_dim = args.intermediate_size + 2 * args.state_size
self.conv1d = DepthWiseConv1d(
in_channels=conv_dim,
out_channels=conv_dim,
kernel_size=args.conv_kernel,
groups=conv_dim,
bias=args.use_conv_bias,
padding=args.conv_kernel - 1
)
self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
if args.rescale_prenorm_residual:
layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
self.out_proj.weight = self.out_proj.weight * layer_scale
def __call__(self, u: mx.array, cache=None):
batch_size, seq_len, dimension = u.shape
assert seq_len == 1, "Input should be a single token"
# Initialize cache states directly using indices
if cache[0] is None: # conv state
conv_dim = self.args.intermediate_size + 2 * self.args.state_size
cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
if cache[1] is None: # ssm state
cache[1] = mx.zeros((
batch_size,
self.args.num_heads,
self.args.head_dim,
self.args.state_size
))
zxbcdt = self.in_proj(u)
n_heads = self.args.num_heads
z = zxbcdt[:, :, :self.args.intermediate_size]
xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
dt = zxbcdt[:, :, -(n_heads):]
dt = mx.reshape(dt, (batch_size, n_heads))
dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
dt = mx.maximum(dt, self.args.time_step_floor)
xBC = self.conv1d(xBC, cache=cache)
xBC = silu(xBC)
x = xBC[:, :, :self.args.intermediate_size]
B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
C = xBC[:, :, -self.args.state_size:]
x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
x = mx.squeeze(x, axis=1)
B = mx.reshape(B, (batch_size, 1, self.args.state_size))
B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
B = mx.expand_dims(B, axis=2)
C = mx.reshape(C, (batch_size, 1, self.args.state_size))
C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
C = mx.expand_dims(C, axis=3)
A = -mx.exp(self.A_log)
dA = mx.exp(dt * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
x = mx.expand_dims(x, axis=3)
dBx = mx.matmul(x, B)
# Update ssm state directly using cache[1]
cache[1] = cache[1] * dA + dBx
y = mx.matmul(cache[1], C)
y = mx.squeeze(y, axis=-1)
y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
y = self.norm(y, z)
return self.out_proj(y)
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.residual_in_fp32 = args.residual_in_fp32
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
if self.residual_in_fp32:
x = x.astype(mx.float32)
normed = self.norm(x)
output = self.mixer(normed, cache)
return output + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
hidden = x
for layer, c in zip(self.layers, cache):
hidden = layer(hidden, c)
return self.norm_f(hidden)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
hidden = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(hidden)
else:
logits = self.lm_head(hidden)
return logits
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
def layers(self):
return self.backbone.layers
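# Usage note (hedged): this module follows the mlx-lm model interface
# (BaseModelArgs, MambaCache, make_cache/sanitize), so in practice it would
# normally be exercised through mlx-lm's loading utilities rather than built
# by hand, e.g. (hypothetical checkpoint path):
#   from mlx_lm import load, generate
#   model, tokenizer = load("path/to/converted-mamba2-checkpoint")
#   print(generate(model, tokenizer, prompt="Mamba is", max_tokens=30))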


@@ -1,357 +0,0 @@
import math
from dataclasses import dataclass, field
from typing import Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import Mamba2Cache
@dataclass
class ModelArgs(BaseModelArgs):
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
use_cache: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
intermediate_size: int = None
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
self.intermediate_size = int(self.expand * self.hidden_size) # E*D = ED
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones(hidden_size)
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.astype(mx.float32)
if gate is not None:
hidden_states = hidden_states * nn.silu(gate.astype(mx.float32))
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.astype(input_dtype)
class Mamba2Mixer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
# Model dimensions
self.hidden_size = args.hidden_size
self.num_heads = args.num_heads
self.head_dim = args.head_dim
self.ssm_state_size = args.state_size
self.n_groups = args.n_groups
self.intermediate_size = int(args.expand * args.hidden_size)
# Convolution parameters
self.conv_kernel = args.conv_kernel
self.use_conv_bias = args.use_conv_bias
# Time step parameters
self.time_step_rank = int(args.time_step_rank)
self.time_step_min = args.time_step_min
self.time_step_max = args.time_step_max
# Processing parameters
self.chunk_size = args.chunk_size
self.layer_norm_epsilon = args.layer_norm_epsilon
# Calculate dimensions
self.conv_dim = (self.intermediate_size +
2 * self.n_groups * self.ssm_state_size)
projection_size = (self.intermediate_size +
self.conv_dim +
self.num_heads)
# Initialize layers
self.in_proj = nn.Linear(
self.hidden_size,
projection_size,
bias=args.use_bias
)
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
kernel_size=self.conv_kernel,
groups=self.conv_dim,
padding=self.conv_kernel - 1,
bias=self.use_conv_bias
)
# Initialize parameters
self.dt_bias = mx.ones(self.num_heads)
A = mx.arange(1, self.num_heads + 1)
self.A_log = mx.log(A)
self.D = mx.ones(self.num_heads)
# Output layers
self.norm = MambaRMSNormGated(
self.intermediate_size,
eps=self.layer_norm_epsilon
)
self.out_proj = nn.Linear(
self.intermediate_size,
self.hidden_size,
bias=args.use_bias
)
def reshape_into_chunks(self, tensor, pad_size, chunk_size):
if pad_size > 0:
pad_shape = list(tensor.shape)
pad_shape[1] = pad_size
padding = mx.zeros(pad_shape, dtype=tensor.dtype)
tensor = mx.concatenate([tensor, padding], axis=1)
chunk_shape = list(tensor.shape)
chunk_shape[1] = -1
chunk_shape.insert(2, chunk_size)
return tensor.reshape(chunk_shape)
def segment_sum(self, x):
return mx.cumsum(x, axis=-1)
def process_single_token(self, hidden_states, B, C, dt, cache):
batch_size = hidden_states.shape[0]
# Process convolution state
if cache is not None and cache.conv_states is not None:
conv_state = cache.conv_states
# Roll the conv state and update the last position
conv_state = mx.roll(conv_state, shift=-1, axis=-1)
# Create new conv state with updated last position
new_conv_state = mx.array(conv_state)
new_conv_state = new_conv_state.at[:, :, -1].add(hidden_states)
conv_state = new_conv_state
# Compute convolution
conv_out = mx.sum(conv_state * self.conv1d.weight[:, 0, :], axis=-1)
if self.use_conv_bias:
conv_out = conv_out + self.conv1d.bias
# Apply SiLU activation
conv_out = mx.sigmoid(conv_out) * conv_out
else:
# Initialize new cache and process convolution
conv_state = mx.zeros((batch_size, self.conv_dim, self.conv_kernel - 1))
# Reshape hidden_states for conv1d
hidden_states_reshaped = hidden_states.reshape(batch_size, -1, 1)
conv_out = self.conv1d(hidden_states_reshaped)
conv_out = mx.squeeze(conv_out, axis=-1) # Remove the last dimension
conv_out = mx.sigmoid(conv_out) * conv_out
# Process SSM
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.time_step_min,
self.time_step_max
)
A = -mx.exp(self.A_log)
dA = mx.exp(dt[:, None] * A[None, :])
if cache is not None and cache.ssm_states is not None:
ssm_state = cache.ssm_states
else:
ssm_state = mx.zeros(
(batch_size, self.num_heads, self.head_dim, self.ssm_state_size)
)
# Compute SSM updates
dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, hidden_states)
next_state = ssm_state * dA[:, :, None, None] + dBx
y = mx.einsum('bhds,bhs->bhd', next_state, C)
# Add skip connection
y = y + hidden_states * self.D[None, :, None]
return y, conv_state, next_state
def process_long_sequence(self, hidden_states, B, C, dt, ssm_state):
batch_size, seq_len = hidden_states.shape[:2]
pad_size = self.chunk_size - (seq_len % self.chunk_size)
# Reshape into chunks
x_chunks = self.reshape_into_chunks(hidden_states, pad_size, self.chunk_size)
B_chunks = self.reshape_into_chunks(B, pad_size, self.chunk_size)
C_chunks = self.reshape_into_chunks(C, pad_size, self.chunk_size)
# Process time steps
dt = nn.softplus(dt + self.dt_bias)
dt = mx.clip(dt, self.time_step_min, None)
# Prepare matrices
A = -mx.exp(self.A_log)
A = A * dt[:, None]
# Process chunks
A_chunks = self.reshape_into_chunks(
mx.broadcast_to(A, (batch_size, seq_len + pad_size, self.num_heads)),
pad_size,
self.chunk_size
)
# Compute cumulative sums
A_cumsum = mx.cumsum(A_chunks, axis=-1)
L = mx.exp(self.segment_sum(A_chunks))
# Process diagonal blocks
G = mx.einsum('...lhn,...shn->...lsh', C_chunks, B_chunks)
M = G * L[..., None, :]
Y_diag = mx.einsum('...lsh,...sh->...lh', M, x_chunks)
# Process off-diagonal blocks
decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum)
B_decay = B_chunks * decay_states[..., None]
states = mx.einsum('...shn,...sh->...hn', B_decay, x_chunks)
# Combine results
y = Y_diag + states
# Remove padding if necessary
if pad_size > 0:
y = y[:, :seq_len]
return y, ssm_state
def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array:
batch_size, seq_len, _ = x.shape
# Project input
projected_states = self.in_proj(x)
# Calculate d_mlp based on projection size
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 *
self.n_groups * self.ssm_state_size - self.num_heads) // 2
# Split projections with corrected dimensions
splits = [
d_mlp, # z0
d_mlp, # x0
self.intermediate_size, # gate
self.conv_dim, # hidden_states
self.num_heads # dt
]
z0, x0, x1, gate, hidden_states, dt = projected_states.split(splits, axis=-1)
# Split hidden states into components
x_conv, BC = mx.split(hidden_states, [self.intermediate_size], axis=-1)
B, C = mx.split(BC, [self.n_groups * self.ssm_state_size], axis=-1)
# Process based on sequence length
if seq_len > 1 and cache is None:
y, next_state = self.process_long_sequence(
x_conv, B, C, dt,
mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size))
)
else:
# Reshape for single token processing
x_conv = x_conv.reshape(batch_size, -1, self.head_dim)
B = B.reshape(batch_size, self.num_heads, -1)
C = C.reshape(batch_size, self.num_heads, -1)
y, conv_state, next_state = self.process_single_token(x_conv, B, C, dt, cache)
if cache is not None:
cache.update(conv_state, next_state)
# Apply normalization and final projection
y = self.norm(y) * gate
return self.out_proj(y)
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = Mamba2Mixer(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array:
return self.mixer(self.norm(x), cache) + x
class Mamba2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache=None) -> mx.array:
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, layer_cache in zip(self.layers, cache):
x = layer(x, layer_cache)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.backbone = Mamba2Model(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None) -> mx.array:
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def make_cache(self, batch_size=1):
return [Mamba2Cache() for _ in range(len(self.backbone.layers))]
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
weights[k] = v.moveaxis(2, 1)
return weights
@property
def layers(self):
return self.backbone.layers


@@ -1,430 +0,0 @@
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import Mamba2Cache
@dataclass
class ModelArgs(BaseModelArgs):
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
use_cache: bool = True
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
def pad_tensor_by_size(input_tensor: mx.array, pad_size: int):
"""
Pads the input tensor with `pad_size` on the seq_len dim (dim=1).
Assumes that we only have tensors of rank 3 or 4.
"""
pad_width = [(0, 0), (0, pad_size)] + [(0, 0)] * (len(input_tensor.shape) - 2)
return mx.pad(input_tensor, pad_width)
def reshape_into_chunks(input_tensor, pad_size, chunk_size):
"""
Pads input_tensor with `pad_size` on the seq_len dim (dim=1) and
simultaneously splits it into chunk sequences.
Assumes that we only have tensors of rank 3 or 4.
"""
# [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
input_tensor = pad_tensor_by_size(input_tensor, pad_size)
if len(input_tensor.shape) == 3:
# [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
else:
# [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
return input_tensor.reshape(
input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
)
def segment_sum(input_tensor):
"""
More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
"""
chunk_size = input_tensor.shape[-1]
# 1. expand the input tensor to have an additional dimension and repeat along that dimension
# [..., chunk_size] -> [..., chunk_size, chunk_size]
input_tensor = mx.repeat(input_tensor[..., None], chunk_size, axis=-1)
# 2. create a lower triangular mask with the diagonal set to 0, to zero out elements above the diagonal
mask = mx.tril(mx.ones((chunk_size, chunk_size)), k=-1).astype(mx.bool_)
input_tensor = mx.where(mask, input_tensor, 0)
# 3. compute the actual cumulative sum
tensor_segsum = mx.cumsum(input_tensor, axis=-2)
# 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl. diagonal this time)
mask = mx.tril(mx.ones((chunk_size, chunk_size)), k=0).astype(mx.bool_)
tensor_segsum = mx.where(mask, tensor_segsum, -mx.inf)
return tensor_segsum
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.layer_idx = layer_idx
self.args = args
self.hidden_size = args.hidden_size
self.num_heads = args.num_heads
self.head_dim = args.head_dim
self.state_size = args.state_size
self.n_groups = args.n_groups
self.conv_kernel = args.conv_kernel
self.intermediate_size = int(args.expand * args.hidden_size)
self.time_step_rank = int(args.time_step_rank)
self.time_step_min = args.time_step_min
self.time_step_max = args.time_step_max
self.chunk_size = args.chunk_size
# Convolution dimension includes both intermediate sizes
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
bias=args.use_conv_bias,
kernel_size=args.conv_kernel,
groups=self.conv_dim,
padding=args.conv_kernel - 1
)
# Compute input projection dimension
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
self.in_proj = nn.Linear(args.hidden_size, projection_size, bias=args.use_bias)
self.dt_bias = mx.ones(self.num_heads)
A = mx.arange(1, self.num_heads + 1)
self.A_log = mx.log(A)
self.D = mx.ones(self.num_heads)
self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
def __call__(self, input_states: mx.array, cache):
batch_size, seq_len, _ = input_states.shape
# Gated MLP's linear projection
projected_states = self.in_proj(input_states) # [1, 1, projection_size]
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size -
2 * self.n_groups * self.state_size - self.num_heads) // 2
# Split projected states
*_, gate, hidden_states, dt = projected_states.split(
[d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads],
axis=-1
)
# hidden_states shape: [1, 1, conv_dim]
# Get SSM state from cache
ssm_state = cache.ssm_states[self.layer_idx]
if cache.seqlen_offset > 0:
# Handle cached generation case
conv_state = cache.conv_states[self.layer_idx] # [batch, conv_dim, conv_kernel]
conv_state = mx.roll(conv_state, shifts=-1, axis=-1)
# Handle batched generation - states are copied through
# Properly reshape hidden_states for the conv_state update
conv_state = conv_state.at[:, :, -1].set(hidden_states[:, 0, :])
cache.conv_states[self.layer_idx] = conv_state
# Compute convolution output
hidden_states = mx.sum(conv_state * self.conv1d.weight[:, 0, :], axis=-1)
if self.args.use_conv_bias:
hidden_states += self.conv1d.bias
hidden_states = nn.silu(hidden_states)[:, None, ...] # [batch, 1, conv_dim] : decoding
else:
# Handle normal forward pass
# Properly transpose while preserving the sequence dimension
hidden_states = hidden_states.transpose(0, 2, 1) # [1, conv_dim, 1]
# Pad the convolution state
padding_size = self.conv_kernel - 1
conv_state = mx.pad(
hidden_states,
((0, 0), (0, 0), (padding_size, 0))
)
# Store in cache
cache.conv_states[self.layer_idx] = conv_state
# Apply convolution with proper padding
hidden_states = self.conv1d(hidden_states) # [1, conv_dim, 1]
hidden_states = hidden_states.transpose(0, 2, 1) # [1, 1, conv_dim]
hidden_states = nn.silu(hidden_states)
# Split hidden states for SSM computation
hidden_states, B, C = mx.split(
hidden_states,
[self.intermediate_size, self.n_groups * self.state_size, self.n_groups * self.state_size],
axis=-1
)
# Compute A matrix
A = -mx.exp(self.A_log)
if cache is not None and cache.seqlen_offset > 0:
# Note: there is no need to pad parameter matrices here, as there is just one new token
# for batched generation
dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
dt = dt.transpose(0, 2, 1).expand(batch_size, dt.shape[-1], self.head_dim)
# [num_heads] -> [num_heads, head_dim]
dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
dt = nn.softplus(dt + dt_bias)
dt = mx.clip(dt, self.time_step_min, None)  # optionally also clip at self.time_step_max
A = A[..., None, None].expand(self.num_heads, self.head_dim, self.state_size)
# [bsz, num_heads, head_dim, state_size]
dA = mx.exp(dt[..., None] * A)
# Discretize B
# [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
# -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
B = B.reshape(batch_size, -1, B.shape[-1])
# [bsz, num_heads, head_dim, state_size]
dB = dt[..., None] * B[..., None, :]
# Discretize x into dB
# [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
dBx = dB * hidden_states[..., None]
# State calculation
cache.ssm_states[self.layer_idx] = (
cache.ssm_states[self.layer_idx] * dA + dBx
)
# Subsequent output
# [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
C = C.reshape(batch_size, -1, C.shape[-1])
# [bsz, num_heads, head_dim]
ssm_states = cache.ssm_states[self.layer_idx] # Shape: [b, h, d, n]
# Reshape ssm_states to merge the first two dimensions
ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.state_size) # Shape: [b*h, d, n]
C_reshaped = C.view(batch_size * self.num_heads, self.state_size, 1) # Shape: [b*h, n, 1]
y = ssm_states_reshaped @ C_reshaped
y = y.view(batch_size, self.num_heads, self.head_dim)
# D skip connection
# [num_heads] -> [num_heads, head_dim]
D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
y = (y + hidden_states * D)
# [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
y = y.reshape(batch_size, -1)[:, None, ...]
else:
# begin ssd naive implementation without einsums
dt = nn.softplus(dt + self.dt_bias)
dt = mx.clip(dt, self.time_step_min, None)
hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim)
B = B.reshape(batch_size, seq_len, -1, self.state_size)
C = C.reshape(batch_size, seq_len, -1, self.state_size)
B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
pad_size = self.chunk_size - (seq_len % self.chunk_size)
D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
# Discretize x and A
hidden_states = hidden_states * dt[..., None]
A = A * dt
# Rearrange into blocks/chunks
hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
# [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
A = A.permute(0, 3, 1, 2)
A_cumsum = mx.cumsum(A, dim=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
# This is the analog of a causal mask
L = mx.exp(segment_sum(A))
# First, contraction of C and B to get G (attention-weights like)
G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n)
G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
# Step 2: Compute M, equivalent to applying attention mask to weights
M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
M = M_intermediate.sum(dim=-1)
# Step 3: Compute Y_diag (apply to values)
Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = mx.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
# permute back B * decay states
states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
if cache is not None and cache.seqlen_offset > 0:
previous_states = cache.ssm_states[self.layer_idx][:, None, ...]
else:
previous_states = mx.zeros_like(states[:, :1])
states = mx.concat([previous_states, states], dim=1)
decay_chunk = mx.exp(segment_sum(mx.pad(A_cumsum[:, :, :, -1], [(0, 0), (0, 0), (1, 0)])))
states_permuted = states.permute(0, 2, 1, 3, 4)
result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
new_states = result.permute(0, 2, 1, 3, 4)
states, ssm_state = new_states[:, :-1], new_states[:, -1]
# Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = mx.exp(A_cumsum)
# compute Yoff
C_times_states = (C[..., None, :] * states[:, :, None, ...])
state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
y = Y_diag + Y_off
# [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
y = y + D_residual
# Cutting off padded chunks
if pad_size > 0:
y = y[:, :seq_len, :, :]
y = y.reshape(batch_size, seq_len, -1)
if ssm_state is not None and cache is not None:
cache.ssm_states[self.layer_idx] = ssm_state
scan_output = self.norm(y, gate)
# end ssd naive
# 4. Final linear projection
return self.out_proj(scan_output) # [batch, seq_len, hidden_size]
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.residual_in_fp32 = args.residual_in_fp32
self.mixer = Mamba2Block(args, layer_idx)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
return self.mixer(self.norm(x), cache) + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args, idx) for idx in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def make_cache(self, batch_size=1):
return [Mamba2Cache(
batch_size,
self.args.intermediate_size,
self.args.conv_kernel,
self.args.head_dim,
self.args.num_heads,
self.args.n_groups,
self.args.state_size
) for _ in range(len(self.layers))]
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
weights[k] = v.moveaxis(2, 1)
return weights
@property
def layers(self):
return self.backbone.layers


@@ -1,343 +0,0 @@
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import Mamba2Cache
@dataclass
class ModelArgs(BaseModelArgs):
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
use_cache: bool = True
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
def silu(x):
return x * mx.sigmoid(x)
def ssd(x, A, B, C, chunk_size):
# Replace einsum operations with explicit reshape and matrix multiply
batch, seqlen, nheads, dim = x.shape
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)
state = mx.zeros((batch, nheads, dim, B.shape[-1]))
outputs = []
for i in range(0, seqlen, chunk_size):
chunk = slice(i, min(i + chunk_size, seqlen))
dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
# Replace einsum with explicit operations
x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim]
x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size]
B_chunk = B[:, chunk] # [batch, chunk_size, state_size]
dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size]
state = state * mx.expand_dims(dA, axis=-1) + dBx
# Replace einsum with explicit operations
C_chunk = C[:, chunk] # [batch, chunk_size, state_size]
y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size]
y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim]
outputs.append(y)
return mx.concatenate(outputs, axis=1), state
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.groups = groups if groups is not None else in_channels
assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
# Initialize weight with correct shape [C_out, 1, kernel_size]
self.weight = mx.random.normal((out_channels, 1, kernel_size))
self.bias = mx.zeros((out_channels,)) if bias else None
def __call__(self, x: mx.array, cache=None) -> mx.array:
B, L, C = x.shape
K = self.kernel_size
assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
# Handle caching for sequential processing
if cache is not None:
    if cache.conv_states[0] is None:
        cache.conv_states[0] = mx.zeros((B, K-1, C))
    x = mx.concatenate([cache.conv_states[0], x], axis=1)
# Process each channel independently
outputs = []
for c in range(C):
# Extract and reshape the channel
x_c = x[:, :, c] # [B, L]
x_c = mx.expand_dims(x_c, axis=1) # [B, 1, L]
# Get weight for this channel - already in correct shape [1, 1, K]
w_c = mx.expand_dims(self.weight[c], axis=0) # Ensure [1, 1, K]
# Apply convolution
y_c = mx.conv_general(
x_c,
w_c,
stride=1,
padding=self.padding
)
if self.bias is not None:
y_c = y_c + self.bias[c]
outputs.append(mx.squeeze(y_c, axis=1))
y = mx.stack(outputs, axis=-1)
# Update cache
if cache is not None:
cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
return y
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.chunk_size = args.chunk_size
d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
self.conv_dim = args.intermediate_size + 2 * args.state_size
self.conv1d = DepthWiseConv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
kernel_size=args.conv_kernel,
groups=self.conv_dim,
bias=args.use_conv_bias,
padding=args.conv_kernel - 1
)
self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
if args.rescale_prenorm_residual:
layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
self.out_proj.weight = self.out_proj.weight * layer_scale
def __call__(self, u: mx.array, cache=None):
# Expect input shape: [batch_size, 1, hidden_size]
batch_size, seq_len, _ = u.shape
pad_size = self.chunk_size - (seq_len % self.chunk_size)
# Initialize states if needed
if cache.conv_states[0] is None:
cache.conv_states[0] = mx.zeros((
batch_size,
self.args.conv_kernel - 1,
self.conv_dim
))
if cache.ssm_states[0] is None:
cache.ssm_states[0] = mx.zeros((
batch_size,
self.args.num_heads,
self.args.head_dim,
self.args.state_size
))
# Project input
zxbcdt = self.in_proj(u)
# Split projections
z = zxbcdt[:, :, :self.args.intermediate_size]
xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
dt = zxbcdt[:, :, -(self.args.num_heads):]
# Process delta time
dt = mx.reshape(dt, (batch_size, seq_len, self.args.num_heads))
dt = mx.squeeze(dt, axis=1) # Remove the sequence dimension (single token)
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.args.time_step_min,
self.args.time_step_max
)
dt = mx.maximum(dt, self.args.time_step_floor)
# Convolution step
xBC = self.conv1d(xBC, cache=cache)
xBC = silu(xBC)
# Split conv output
x = xBC[:, :, :self.args.intermediate_size]
B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
C = xBC[:, :, -self.args.state_size:]
# Reshape for SSM
x = mx.reshape(x, (batch_size, 1, self.args.num_heads, self.args.head_dim))
x = mx.squeeze(x, axis=1)
B = mx.reshape(B, (batch_size, 1, self.args.state_size))
B = mx.broadcast_to(B, (batch_size, self.args.num_heads, self.args.state_size))
B = mx.expand_dims(B, axis=2)
C = mx.reshape(C, (batch_size, 1, self.args.state_size))
C = mx.broadcast_to(C, (batch_size, self.args.num_heads, self.args.state_size))
C = mx.expand_dims(C, axis=3)
# SSM state update
A = -mx.exp(self.A_log)
dA = mx.exp(dt * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
x = mx.expand_dims(x, axis=3)
dBx = mx.matmul(x, B)
cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx
# Output computation
y = mx.matmul(cache.ssm_states[0], C)
y = mx.squeeze(y, axis=-1)
y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
# Final reshape and projections
y = mx.reshape(y, (batch_size, 1, self.args.num_heads * self.args.head_dim))
y = self.norm(y, z)
return self.out_proj(y)
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.residual_in_fp32 = args.residual_in_fp32
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
if self.residual_in_fp32:
x = x.astype(mx.float32)
return self.mixer(self.norm(x), cache) + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def make_cache(self, batch_size=1):
return [Mamba2Cache(batch_size, self.args.conv_kernel) for _ in range(len(self.layers))]
def sanitize(self, weights):
sanitized = {}
for k, v in weights.items():
if "conv1d.weight" in k:
# Ensure weights are in correct shape (channels, 1, kernel_size)
if v.ndim == 2:
v = mx.expand_dims(v, axis=1)
elif v.ndim == 1:
v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0)
sanitized[k] = v
else:
sanitized[k] = v
return sanitized
@property
def layers(self):
return self.backbone.layers