Fixed streaming generation and got rid of the gibberish output, but it is still a little slow: 0.222 tokens-per-sec

Goekdeniz-Guelmez
2024-11-21 22:01:28 +01:00
parent e4eae973e8
commit e22b2dbf27
6 changed files with 884 additions and 1015 deletions

View File

@@ -419,11 +419,9 @@ class RotatingKVCache(_BaseCache):
raise NotImplementedError("RotatingKVCache Quantization NYI")
class MambaCache:
class MambaCache(_BaseCache):
def __init__(self):
# [conv_state, ssm_state]
self.cache = [None, None]
self.offset = 0 # Sliding window caching
def __setitem__(self, idx, value):
self.cache[idx] = value
@@ -438,15 +436,3 @@ class MambaCache:
@state.setter
def state(self, v):
self.cache = v
@property
def conv_states(self):
return [self.cache[0]]
@property
def ssm_states(self):
return [self.cache[1]]
def reset(self):
self.cache = [None, None]
self.offset = 0
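For orientation, the simplified MambaCache is just two lazily allocated slots: index 0 holds the depthwise-conv state and index 1 the SSM state. A minimal usage sketch (shapes are made up; it assumes the class also exposes __getitem__, which the model code later in this commit relies on):

import mlx.core as mx

cache = MambaCache()
if cache[0] is None:
    cache[0] = mx.zeros((1, 3, 8))        # (batch, conv_kernel - 1, conv_dim)
if cache[1] is None:
    cache[1] = mx.zeros((1, 4, 2, 16))    # (batch, num_heads, head_dim, state_size)
cache[1] = cache[1] * 0.5                 # each decode step overwrites the slots in place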

View File

@@ -1,449 +1,437 @@
# coding=utf-8
# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MAMBA2 model."""
"""
mamba2-minimal
==============
import math
A minimal, single-file implementation of the Mamba-2 model in PyTorch.
> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality**
> Authors: Tri Dao, Albert Gu
> Paper: https://arxiv.org/abs/2405.21060
"""
import json
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import Iterable, NamedTuple, TypeAlias, cast
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import LongTensor, Tensor, nn
logger = logging.get_logger(__name__)
Device: TypeAlias = str | torch.device | None
def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
"""
Pads the input tensor with `pad_size` on the seq_len dim (dim=1)
@dataclass
class Mamba2Config:
d_model: int # model dimension (D)
n_layer: int = 24 # number of Mamba-2 layers in the language model
d_state: int = 128 # state dimension (N)
d_conv: int = 4 # convolution kernel size
expand: int = 2 # expansion factor (E)
headdim: int = 64 # head dimension (P)
chunk_size: int = 64 # matrix partition size (Q)
vocab_size: int = 50277
pad_vocab_size_multiple: int = 16
Assumes the input tensor has either 4 or 3 dimensions
"""
pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
def __post_init__(self):
self.d_inner = self.expand * self.d_model
assert self.d_inner % self.headdim == 0
self.nheads = self.d_inner // self.headdim
if self.vocab_size % self.pad_vocab_size_multiple != 0:
self.vocab_size += (
self.pad_vocab_size_multiple
- self.vocab_size % self.pad_vocab_size_multiple
)
def reshape_into_chunks(input_tensor, pad_size, chunk_size):
"""
Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
simultaneously splitting it into chunk sequences.
class InferenceCache(NamedTuple):
conv_state: Tensor # (batch, d_inner + 2 * d_state, d_conv)
ssm_state: Tensor # (batch, nheads, headdim, d_state)
Assumes the input tensor has either 4 or 3 dimensions
"""
# [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
input_tensor = pad_tensor_by_size(input_tensor, pad_size)
if len(input_tensor.shape) == 3:
# [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
else:
# [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
return input_tensor.reshape(
input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
@staticmethod
def alloc(batch_size: int, args: Mamba2Config, device: Device = None):
return InferenceCache(
torch.zeros(
batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device
),
torch.zeros(
batch_size, args.nheads, args.headdim, args.d_state, device=device
),
)
def segment_sum(input_tensor):
"""
More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
"""
chunk_size = input_tensor.size(-1)
# 1. expand input tensor to have an additional dimension and repeat along that dimension
# [..., chunk_size] -> [..., chunk_size, chunk_size]
input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
# 2. create a strictly lower-triangular mask (diagonal excluded) to zero out elements on and above the diagonal
mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
input_tensor = input_tensor.masked_fill(~mask, 0)
# 3. compute actual cumsum
tensor_segsum = torch.cumsum(input_tensor, dim=-2)
# 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
return tensor_segsum
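As a quick sanity check of what segment_sum computes (illustrative, assuming the function above is in scope): for a 1-D input a of length T, the result S is lower triangular with S[i, j] = a[j+1] + ... + a[i] and -inf above the diagonal, so exp(S) gives the causal decay matrix used in the chunked scan.

import torch

a = torch.tensor([0.1, 0.2, 0.3, 0.4])              # one chunk of per-step log-decays
S = segment_sum(a)                                   # (4, 4), lower triangular
assert torch.isclose(S[2, 0], torch.tensor(0.5))     # a[1] + a[2]
assert S[0, 1] == -torch.inf                         # above the diagonal -> exp() == 0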
class Mamba2Cache:
"""
Arguments:
config: ModelArgs
batch_size: int
dtype: torch.dtype
device: torch.device
Attributes:
seqlen_offset: int
dtype: torch.dtype
conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel]
ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size]
"""
def __init__(
self, config: ModelArgs, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
):
self.seqlen_offset = 0
self.dtype = dtype
self.conv_kernel = config.conv_kernel
self.intermediate_size = int(config.expand * config.hidden_size)
self.conv_states = {
i: torch.zeros(
batch_size,
self.intermediate_size + 2 * config.n_groups * config.state_size,
self.conv_kernel,
device=device,
dtype=dtype,
)
for i in range(config.num_hidden_layers)
}
self.ssm_states = {
i: torch.zeros(
batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype
)
for i in range(config.num_hidden_layers)
}
def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel - 1)
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]
def reset(self):
self.conv_states.zero_()
self.ssm_states.zero_()
class MambaRMSNormGated(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6):
class Mamba2LMHeadModel(nn.Module):
def __init__(self, args: Mamba2Config, device: Device = None):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.args = args
self.device = device
def forward(self, hidden_states, gate=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states
self.backbone = nn.ModuleDict(
dict(
embedding=nn.Embedding(args.vocab_size, args.d_model, device=device),
layers=nn.ModuleList(
[
nn.ModuleDict(
dict(
mixer=Mamba2(args, device=device),
norm=RMSNorm(args.d_model, device=device),
)
)
for _ in range(args.n_layer)
]
),
norm_f=RMSNorm(args.d_model, device=device),
)
)
self.lm_head = nn.Linear(
args.d_model, args.vocab_size, bias=False, device=device
)
self.lm_head.weight = self.backbone.embedding.weight
if gate is not None:
hidden_states = hidden_states * nn.functional.silu(gate)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
@staticmethod
def from_pretrained(huggingface_model_id: str, device: Device = None):
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils.hub import cached_file
return self.weight * hidden_states
config_path = cached_file(huggingface_model_id, CONFIG_NAME)
assert config_path, "Failed to get huggingface config file"
state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME)
assert state_dict_path, "Failed to get huggingface state dict file"
config = json.load(open(config_path))
args = Mamba2Config(
d_model=config["d_model"],
n_layer=config["n_layer"],
vocab_size=config["vocab_size"],
pad_vocab_size_multiple=config["pad_vocab_size_multiple"],
)
map_location = "cpu" if device is None else device
state_dict = torch.load(
state_dict_path, weights_only=True, map_location=map_location, mmap=True
)
model = Mamba2LMHeadModel(args, device=device)
model.load_state_dict(state_dict)
model.eval()
return model
def forward(
self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None
) -> tuple[LongTensor, list[InferenceCache]]:
"""
Arguments
input_ids: (batch, seqlen) tokens from `EleutherAI/gpt-neox-20b` tokenizer
h: hidden states for the inference step. If present, the constant-time
(w.r.t. sequence length) inference path is taken; input_ids
should then have shape (batch, 1) and contain the next token for
each sequence in the batch.
Return (logits, h)
logits: (batch, seqlen, vocab_size)
h: updated inference cache after processing `input_ids`
"""
seqlen = input_ids.shape[1]
if h is None:
h = [None for _ in range(self.args.n_layer)]
x = self.backbone.embedding(input_ids)
for i, layer in enumerate(self.backbone.layers):
y, h[i] = layer.mixer(layer.norm(x), h[i])
x = y + x
x = self.backbone.norm_f(x)
logits = self.lm_head(x)
return logits[:, :seqlen], cast(list[InferenceCache], h)
def generate(
self,
input_ids: LongTensor,
max_new_length: int = 20,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 1.0,
eos_token_id: int = 0,
) -> Iterable[tuple[int, list[InferenceCache]]]:
prefix, tokens = input_ids[:-1], input_ids[-1:].unsqueeze(0)
# Process prompt
# The input sequence to forward (non-inference path) must have a length that is a multiple of chunk_size.
# We split out excess tokens so that n_chunked tokens can be processed by one forward call and
# process the rest in multiple inference steps.
n_chunked = (prefix.shape[0] // self.args.chunk_size) * self.args.chunk_size
if n_chunked > 0:
_, h = self(prefix[:n_chunked].unsqueeze(0), None)
else:
h = [
InferenceCache.alloc(1, self.args, device=self.device)
for _ in range(self.args.n_layer)
]
for i in range(n_chunked, prefix.shape[0]):
_, h = self(prefix[i : i + 1].unsqueeze(0), h)
# Generate
for _ in range(max_new_length):
with torch.no_grad():
out, h = self(tokens, h)
logits = out[0, -1]
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, k=top_k)[0][-1]
logits[indices_to_remove] = -torch.inf
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
sorted_indices_to_remove[0] = False
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = -torch.inf
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
if next_token.item() == eos_token_id:
return
tokens = next_token.unsqueeze(0)
yield cast(int, next_token.item()), h
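A rough streaming-generation sketch using the generate method above. The checkpoint id is an assumption (any Mamba-2 checkpoint whose weights match this module layout); the tokenizer follows the forward docstring:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = Mamba2LMHeadModel.from_pretrained("state-spaces/mamba2-130m")   # assumed checkpoint id

input_ids = torch.tensor(tokenizer.encode("Mamba is a"), dtype=torch.long)
for token_id, _h in model.generate(input_ids, max_new_length=30, temperature=0.8, top_p=0.95):
    print(tokenizer.decode([token_id]), end="", flush=True)   # tokens stream out one at a time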
class Mamba2Mixer(nn.Module):
def __init__(self, config: ModelArgs):
class Mamba2(nn.Module):
def __init__(self, args: Mamba2Config, device: Device = None):
super().__init__()
self.num_heads = config.num_heads
self.hidden_size = config.hidden_size
self.state_size = config.state_size
self.conv_kernel = config.conv_kernel
self.intermediate_size = int(config.expand * self.hidden_size)
self.time_step_rank = int(config.time_step_rank)
self.use_conv_bias = config.use_conv_bias
self.args = args
self.device = device
self.layer_norm_epsilon = config.layer_norm_epsilon
# Order: (z, x, B, C, dt)
d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads
self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device)
self.n_groups = config.n_groups
self.head_dim = config.head_dim
self.chunk_size = config.chunk_size
self.time_step_limit = config.time_step_limit
self.time_step_min = config.time_step_min
self.time_step_max = config.time_step_max
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
conv_dim = args.d_inner + 2 * args.d_state
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
bias=config.use_conv_bias,
kernel_size=config.conv_kernel,
groups=self.conv_dim,
padding=config.conv_kernel - 1,
in_channels=conv_dim,
out_channels=conv_dim,
kernel_size=args.d_conv,
groups=conv_dim,
padding=args.d_conv - 1,
device=device,
)
# projection of the input hidden states
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
self.in_proj = nn.Linear(
self.hidden_size,
projection_size,
bias=config.use_bias,
self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
self.D = nn.Parameter(torch.empty(args.nheads, device=device))
self.norm = RMSNorm(args.d_inner, device=device)
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
def forward(self, u: Tensor, h: InferenceCache | None = None):
"""
Arguments
u: (batch, seqlen, d_model) input. seqlen should be a multiple of chunk_size.
h: hidden states for inference step. Initialized to 0s if not present.
Return (y, h)
y: (batch, seqlen, d_model) output
h: updated inference cache after processing `u`
"""
if h:
return self.step(u, h)
A = -torch.exp(self.A_log) # (nheads,)
zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads)
# Pad or truncate xBC seqlen to d_conv
conv_state = F.pad(
rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0)
)
self.dt_bias = torch.ones(self.num_heads)
A = torch.arange(1, self.num_heads + 1)
self.A_log = torch.log(A)
self.D = torch.ones(self.num_heads)
xBC = silu(
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :]
) # (batch, seqlen, d_inner + 2 * d_state))
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)
y, ssm_state = ssd(
x * dt.unsqueeze(-1),
A * dt,
rearrange(B, "b l n -> b l 1 n"),
rearrange(C, "b l n -> b l 1 n"),
self.args.chunk_size,
device=self.device,
)
y = y + x * self.D.unsqueeze(-1)
y = rearrange(y, "b l h p -> b l (h p)")
y = self.norm(y, z)
y = self.out_proj(y)
self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
h = InferenceCache(conv_state, ssm_state)
return y, h
def forward(self, input_states, cache: Optional[Mamba2Cache]=None):
batch_size, seq_len, _ = input_states.shape
# Gated MLP's linear projection
projected_states = self.in_proj(input_states.squeeze(1))
d_mlp = (
projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.state_size- self.num_heads) // 2
_, _, gate, hidden_states, dt = projected_states.split(
[d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
def step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]:
"""Take a single inference step for the current input and hidden state
Unlike attention-based models, RNN-based models (e.g. Mamba) do not need
to look back at all the past tokens to generate a new token. Instead, a
hidden state (initialized to zeros) is updated for each input and
passed to the next inference step. This means that the total inference
time is linear with respect to the sequence length, rather than quadratic
as in attention's case.
Arguments
u: (batch, 1, d_model)
h: initial/running hidden state
Return (y, h)
y: (batch, 1, d_model)
h: updated hidden state
"""
assert u.shape[1] == 1, "Only one token can be decoded per inference step"
zxbcdt = self.in_proj(u.squeeze(1)) # (batch, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
# Convolution sequence transformation
ssm_state = cache.ssm_states[self.layer_idx].clone()
ssm_state = ssm_state.to(hidden_states.device)
# Advance convolution input
h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1))
h.conv_state[:, :, -1] = xBC
# Convolution step
xBC = torch.sum(
h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
)
xBC += self.conv1d.bias
xBC = silu(xBC)
if cache.seqlen_offset > 0:
conv_state = cache.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel]
conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
A = -torch.exp(self.A_log) # (nheads,)
# handle batched generation - states are copied through
conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states
cache.conv_states[self.layer_idx].copy_(conv_state)
hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
# SSM step
dt = F.softplus(dt + self.dt_bias) # (batch, nheads)
dA = torch.exp(dt * A) # (batch, nheads)
x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)
dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)
y = y + rearrange(self.D, "h -> h 1") * x
y = rearrange(y, "b h p -> b (h p)")
y = self.norm(y, z)
y = self.out_proj(y)
if self.use_conv_bias:
hidden_states += self.conv1d.bias
hidden_states = nn.silu(hidden_states)[:, None, ...] # [batch, 1, intermediate_size] : decoding
else:
hidden_states = hidden_states.transpose(1,2)
conv_state = nn.functional.pad(
hidden_states,
(self.conv_kernel - hidden_states.shape[-1], 0)
)
cache.conv_states[self.layer_idx].copy_(conv_state)
hidden_states = nn.silu(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len]
hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.state_size, self.n_groups * self.state_size], dim=-1)
A = -torch.exp(self.A_log.float()) # [num_heads]
if cache is not None and cache.seqlen_offset > 0:
# Note: there is no need to pad parameter matrices here, as there is just one new token
# for batched generation
dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
# [num_heads] -> [num_heads, head_dim]
dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max)
A = A[..., None, None].expand(self.num_heads, self.head_dim, self.state_size).to(dtype=torch.float32)
# [bsz, num_heads, head_dim, state_size]
dA = torch.exp(dt[..., None] * A)
# Discretize B
# [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
# -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
B = B.reshape(batch_size, -1, B.shape[-1])
# [bsz, num_heads, head_dim, state_size]
dB = dt[..., None] * B[..., None, :]
# Discretize x into dB
# [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
dBx = dB * hidden_states[..., None]
# State calculation
cache.ssm_states[self.layer_idx].copy_(
cache.ssm_states[self.layer_idx] * dA + dBx
)
# Subsequent output
# [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
C = C.reshape(batch_size, -1, C.shape[-1])
# [bsz, num_heads, head_dim]
ssm_states = cache.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n]
# Reshape ssm_states to merge the first two dimensions
ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.state_size) # Shape: [b*h, d, n]
C_reshaped = C.view(batch_size * self.num_heads, self.state_size, 1) # Shape: [b*h, n, 1]
y = torch.bmm(ssm_states_reshaped, C_reshaped)
y = y.view(batch_size, self.num_heads, self.head_dim)
# D skip connection
# [num_heads] -> [num_heads, head_dim]
D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
y = (y + hidden_states * D).to(y.dtype)
# [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
y = y.reshape(batch_size, -1)[:, None, ...]
else:
# begin ssd naive implementation without einsums
dt = nn.functional.softplus(dt + self.dt_bias)
dt = torch.clamp(dt, self.time_step_min)
hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
B = B.reshape(batch_size, seq_len, -1, self.state_size).float()
C = C.reshape(batch_size, seq_len, -1, self.state_size).float()
B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
pad_size = self.chunk_size - (seq_len % self.chunk_size)
D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
# Discretize x and A
hidden_states = hidden_states * dt[..., None]
A = A.to(hidden_states.dtype) * dt
# Rearrange into blocks/chunks
hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
return y.unsqueeze(1), h
# [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
A = A.permute(0, 3, 1, 2)
A_cumsum = torch.cumsum(A, dim=-1)
def segsum(x: Tensor, device: Device = None) -> Tensor:
"""Stable segment sum calculation.
# 1. Compute the output for each intra-chunk (diagonal blocks)
# This is the analog of a causal mask
L = torch.exp(segment_sum(A))
`exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.
# First, contraction of C and B to get G (attention-weights like)
G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n)
G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32
"""
T = x.size(-1)
x = repeat(x, "... d -> ... d e", e=T)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
x = x.masked_fill(~mask, 0)
x_segsum = torch.cumsum(x, dim=-2)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
# Step 2: Compute M, equivalent to applying attention mask to weights
M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
M = M_intermediate.sum(dim=-1)
def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
"""Structed State Space Duality (SSD) - the core of Mamba-2
# Step 3: Compute Y_diag (apply to values)
Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)
This is almost the exact same minimal SSD code from the blog post.
# (right term of low-rank factorization of off-diagonal blocks; B terms)
Arguments
x: (batch, seqlen, n_heads, d_head)
A: (batch, seqlen, n_heads)
B: (batch, seqlen, n_heads, d_state)
C: (batch, seqlen, n_heads, d_state)
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
# permute back B * decay states
states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
if cache is not None and cache.seqlen_offset > 0:
previous_states = cache.ssm_states[self.layer_idx][:, None, ...]
else:
previous_states = torch.zeros_like(states[:, :1])
states = torch.cat([previous_states, states], dim=1)
decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
Return
y: (batch, seqlen, n_heads, d_head)
states_permuted = states.permute(0, 2, 1, 3, 4)
result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
new_states = result.permute(0, 2, 1, 3, 4)
states, ssm_state = new_states[:, :-1], new_states[:, -1]
Source
1. https://tridao.me/blog/2024/mamba2-part3-algorithm/
2. https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78
"""
assert x.shape[1] % chunk_size == 0
# Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
# compute Yoff
C_times_states = (C[..., None, :] * states[:, :, None, ...])
state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
# Rearrange into chunks
# Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel)
# This is not implemented and left as an exercise for the reader 😜
x, A, B, C = [
rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)
]
y = Y_diag + Y_off
# [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
A = rearrange(A, "b c l h -> b h c l")
A_cumsum = torch.cumsum(A, dim=-1)
y = y + D_residual
# Cutting off padded chunks
if pad_size > 0:
y = y[:, :seq_len, :, :]
y = y.reshape(batch_size, seq_len, -1)
if ssm_state is not None and cache is not None:
cache.ssm_states[self.layer_idx].copy_(ssm_state)
# 1. Compute the output for each intra-chunk (diagonal blocks)
L = torch.exp(segsum(A, device=device))
Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)
scan_output = self.norm(y, gate)
# end ssd naive
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)
# 4. Final linear projection
contextualized_states = self.out_proj(scan_output) # [batch, seq_len, hidden_size]
return contextualized_states
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))
new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
return Y, final_state
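A shape-level sketch of calling the ssd above (illustrative sizes, random tensors; seqlen must be a multiple of chunk_size, and B/C carry a head dimension of 1 exactly as Mamba2.forward passes them):

import torch

batch, seqlen, nheads, headdim, d_state, chunk = 1, 128, 4, 64, 32, 64
x = torch.randn(batch, seqlen, nheads, headdim)
A = -torch.rand(batch, seqlen, nheads)         # per-step, per-head log-decay (negative)
B = torch.randn(batch, seqlen, 1, d_state)     # single group, broadcast over heads
C = torch.randn(batch, seqlen, 1, d_state)
y, final_state = ssd(x, A, B, C, chunk_size=chunk, device="cpu")
assert y.shape == (batch, seqlen, nheads, headdim)
assert final_state.shape == (batch, nheads, headdim, d_state)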
class Mamba2Block(nn.Module):
def __init__(self, config):
class RMSNorm(nn.Module):
def __init__(self, d: int, eps: float = 1e-5, device: Device = None):
"""Gated Root Mean Square Layer Normalization
Paper: https://arxiv.org/abs/1910.07467
"""
super().__init__()
self.config = config
self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mixer = Mamba2Mixer(config)
self.eps = eps
self.weight = nn.Parameter(torch.ones(d, device=device))
def forward(
self,
hidden_states,
cache: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
):
x = self.mixer(
self.norm(hidden_states), cache=cache, cache_position=cache_position
)
return x + hidden_states
def forward(self, x, z=None):
if z is not None:
x = x * silu(z)
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
class Mamba2Model(nn.Module):
def __init__(self, config):
super().__init__(config)
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
def silu(x):
"""Applies the Sigmoid Linear Unit (SiLU), element-wise.
self.gradient_checkpointing = False
self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
cache: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.embeddings(input_ids)
hidden_states = inputs_embeds
for mixer_block in self.layers:
hidden_states = mixer_block(
hidden_states,
cache=cache,
cache_position=cache_position,
)
cache.seqlen_offset += inputs_embeds.shape[1]
return self.norm_f(hidden_states), cache
class Mamba2ForCausalLM(nn.Module):
def __init__(self, config):
super().__init__(config)
self.backbone = Mamba2Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
cache: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.Tensor] = None,
):
out, cache = self.backbone(
input_ids,
cache=cache,
cache_position=cache_position,
)
logits = self.lm_head(out)
return logits, cache
Define this manually since torch's version doesn't seem to work on MPS.
"""
return x * F.sigmoid(x)
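To make the constant-time decode path described in Mamba2.step concrete, a minimal standalone sketch of the per-token SSM recurrence it implements (shapes and values are made up):

import torch

# One decode step per head: h' = exp(dt * A) * h + dt * (x outer B), y = h' . C + D * x.
nheads, headdim, d_state = 2, 4, 8
h = torch.zeros(1, nheads, headdim, d_state)        # running SSM state
x = torch.randn(1, nheads, headdim)                 # current token, split per head
B = torch.randn(1, d_state)
C = torch.randn(1, d_state)
A = -torch.rand(nheads)                             # negative, so exp(dt * A) is in (0, 1]
dt = torch.nn.functional.softplus(torch.randn(1, nheads))
D = torch.ones(nheads)

dA = torch.exp(dt * A)                              # (1, nheads) decay per head
dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)     # write the input into the state
h = h * dA[..., None, None] + dBx                   # state update: O(1) in sequence length
y = torch.einsum("bhpn,bn->bhp", h, C) + D[:, None] * x
print(y.shape)                                      # torch.Size([1, 2, 4])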

View File

@@ -0,0 +1,319 @@
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
use_cache: bool = True
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
def silu(x):
return x * mx.sigmoid(x)
def ssd(x, A, B, C, chunk_size):
batch, seqlen, nheads, dim = x.shape
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)
state = mx.zeros((batch, nheads, dim, B.shape[-1]))
outputs = []
for i in range(0, seqlen, chunk_size):
chunk = slice(i, min(i + chunk_size, seqlen))
dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim]
x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size]
B_chunk = B[:, chunk] # [batch, chunk_size, state_size]
dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size]
state = state * mx.expand_dims(dA, axis=-1) + dBx
C_chunk = C[:, chunk] # [batch, chunk_size, state_size]
y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size]
y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim]
outputs.append(y)
return mx.concatenate(outputs, axis=1), state
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.groups = groups if groups is not None else in_channels
assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
self.weight = mx.random.normal((in_channels, 1, kernel_size))
self.bias = mx.zeros((out_channels,)) if bias else None
def __call__(self, x: mx.array, cache=None) -> mx.array:
B, L, C = x.shape
K = self.kernel_size
assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
if cache is not None:
# Access conv_state directly from cache[0]
if cache[0] is None:
cache[0] = mx.zeros((B, K-1, C))
x = mx.concatenate([cache[0], x], axis=1)
outputs = []
for c in range(C):
x_c = x[:, :, c]
x_c = mx.expand_dims(x_c, axis=1)
w_c = self.weight[c]
if w_c.ndim == 2:
w_c = mx.expand_dims(w_c, axis=0)
elif w_c.ndim == 1:
w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
y_c = mx.conv_general(
x_c,
w_c,
stride=1,
padding=0
)
if self.bias is not None:
y_c = y_c + self.bias[c]
y_c = mx.squeeze(y_c, axis=1)
outputs.append(y_c)
y = mx.stack(outputs, axis=-1)
# Update cache directly using cache[0]
if cache is not None:
cache[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
return y
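A minimal sketch of the conv-cache contract above (sizes are made up): cache[0] keeps the last kernel_size - 1 input frames, so a single streamed token still sees a full convolution window.

import mlx.core as mx

B, K, C = 1, 4, 8
conv = DepthWiseConv1d(in_channels=C, out_channels=C, kernel_size=K, padding=K - 1)
cache = [None, None]                     # slot 0 is the conv state

x_t = mx.random.normal((B, 1, C))        # one streamed token
y_t = conv(x_t, cache=cache)             # cache[0] is lazily created as zeros (B, K - 1, C)
print(y_t.shape, cache[0].shape)         # (1, 1, 8) (1, 3, 8)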
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
conv_dim = args.intermediate_size + 2 * args.state_size
self.conv1d = DepthWiseConv1d(
in_channels=conv_dim,
out_channels=conv_dim,
kernel_size=args.conv_kernel,
groups=conv_dim,
bias=args.use_conv_bias,
padding=args.conv_kernel - 1
)
self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
if args.rescale_prenorm_residual:
layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
self.out_proj.weight = self.out_proj.weight * layer_scale
def __call__(self, u: mx.array, cache=None):
batch_size, seq_len, dimension = u.shape
# Negative per-head decay rates
A = -mx.exp(self.A_log)
# Process the sequence one token at a time (streaming path)
outputs = []
current_cache = cache
for i in range(seq_len):
# Extract current token
current_input = u[:, i:i+1, :]
# Initialize cache states if needed
if current_cache[0] is None: # conv state
conv_dim = self.args.intermediate_size + 2 * self.args.state_size
current_cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
if current_cache[1] is None: # ssm state
current_cache[1] = mx.zeros((
batch_size,
self.args.num_heads,
self.args.head_dim,
self.args.state_size
))
# Project input
zxbcdt = self.in_proj(current_input)
n_heads = self.args.num_heads
z = zxbcdt[:, :, :self.args.intermediate_size]
xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
dt = zxbcdt[:, :, -(n_heads):]
# Process time steps
dt = mx.reshape(dt, (batch_size, n_heads))
dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
dt = mx.maximum(dt, self.args.time_step_floor)
# Apply convolution
xBC = self.conv1d(xBC, cache=current_cache)
xBC = silu(xBC)
# Split states
x = xBC[:, :, :self.args.intermediate_size]
B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
C = xBC[:, :, -self.args.state_size:]
# Reshape for SSM
x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
x = mx.squeeze(x, axis=1)
B = mx.reshape(B, (batch_size, 1, self.args.state_size))
B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
B = mx.expand_dims(B, axis=2)
C = mx.reshape(C, (batch_size, 1, self.args.state_size))
C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
C = mx.expand_dims(C, axis=3)
# SSM updates
dA = mx.exp(dt * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
# Update state
x = mx.expand_dims(x, axis=3)
dBx = mx.matmul(x, B)
current_cache[1] = current_cache[1] * dA + dBx
# Compute output
y = mx.matmul(current_cache[1], C)
y = mx.squeeze(y, axis=-1)
y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
y = self.norm(y + z)
outputs.append(self.out_proj(y))
# Concatenate all outputs
return mx.concatenate(outputs, axis=1)
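For reference, the split widths of the in_proj output used in the block above, worked through with hypothetical config values (hidden_size=768, expand=2, state_size=128, num_heads=24):

intermediate_size = 2 * 768                                      # expand * hidden_size = 1536
state_size, num_heads = 128, 24
d_in_proj = 2 * intermediate_size + 2 * state_size + num_heads   # 3352 columns in total
z_width = intermediate_size                                      # gate z            (1536)
xBC_width = intermediate_size + 2 * state_size                   # x, B and C        (1792)
dt_width = num_heads                                             # per-head dt         (24)
assert z_width + xBC_width + dt_width == d_in_proj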
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.residual_in_fp32 = args.residual_in_fp32
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
if self.residual_in_fp32:
x = x.astype(mx.float32)
normed = self.norm(x)
output = self.mixer(normed, cache)
return output + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
hidden = x
for layer, c in zip(self.layers, cache):
hidden = layer(hidden, c)
return self.norm_f(hidden)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
hidden = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(hidden)
else:
logits = self.lm_head(hidden)
return logits
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
def layers(self):
return self.backbone.layers
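A rough end-to-end decoding sketch for the MLX Model above, driving it one token at a time with a persistent per-layer cache. The config values and the greedy sampling loop are illustrative only (random weights, tiny config), not taken from this commit:

import mlx.core as mx

args = ModelArgs(
    num_heads=24, head_dim=64, vocab_size=50288, hidden_size=768, state_size=128,
    num_hidden_layers=2, layer_norm_epsilon=1e-5, expand=2, conv_kernel=4, n_groups=1,
    use_bias=False, use_conv_bias=True, initializer_range=0.1, residual_in_fp32=True,
    time_step_min=0.001, time_step_max=0.1, time_step_floor=1e-4,
    rescale_prenorm_residual=False, rms_norm=True, chunk_size=256, tie_word_embeddings=True,
)
model = Model(args)
cache = model.make_cache()                    # one MambaCache per layer

tokens = mx.array([[1, 2, 3]])                # already-tokenized prompt, shape (1, T)
for t in range(tokens.shape[1]):              # feed the prompt token by token
    logits = model(tokens[:, t:t + 1], cache)

generated = []
for _ in range(10):                           # greedy decoding of 10 new tokens
    next_token = mx.argmax(logits[:, -1, :], axis=-1)
    generated.append(next_token.item())
    logits = model(next_token[None], cache)
print(generated)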

View File

@@ -5,7 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import Mamba2Cache
from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
@@ -62,7 +62,6 @@ def silu(x):
return x * mx.sigmoid(x)
def ssd(x, A, B, C, chunk_size):
# Replace einsum operations with explicit reshape and matrix multiply
batch, seqlen, nheads, dim = x.shape
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)
@@ -74,7 +73,6 @@ def ssd(x, A, B, C, chunk_size):
chunk = slice(i, min(i + chunk_size, seqlen))
dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
# Replace einsum with explicit operations
x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim]
x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size]
B_chunk = B[:, chunk] # [batch, chunk_size, state_size]
@@ -82,7 +80,6 @@ def ssd(x, A, B, C, chunk_size):
state = state * mx.expand_dims(dA, axis=-1) + dBx
# Replace einsum with explicit operations
C_chunk = C[:, chunk] # [batch, chunk_size, state_size]
y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size]
y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim]
@@ -112,13 +109,13 @@ class DepthWiseConv1d(nn.Module):
assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
if cache is not None and cache.conv_states[0] is not None:
# Convert None to proper array if needed
if isinstance(cache.conv_states[0], type(None)):
cache.conv_states[0] = mx.zeros((B, K-1, C))
x = mx.concatenate([cache.conv_states[0], x], axis=1)
if cache is not None:
# Access conv_state directly from cache[0]
if cache[0] is None:
cache[0] = mx.zeros((B, K-1, C))
x = mx.concatenate([cache[0], x], axis=1)
# Process each channel independently
outputs = []
for c in range(C):
x_c = x[:, :, c]
@@ -130,24 +127,23 @@ class DepthWiseConv1d(nn.Module):
elif w_c.ndim == 1:
w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
# Apply convolution
y_c = mx.conv_general(
x_c,
w_c,
stride=1,
padding=0
)
if self.bias is not None:
y_c = y_c + self.bias[c]
outputs.append(mx.squeeze(y_c, axis=1))
y_c = mx.squeeze(y_c, axis=1)
outputs.append(y_c)
y = mx.stack(outputs, axis=-1)
# Update cache
# Update cache directly using cache[0]
if cache is not None:
cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
cache[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
return y
@@ -182,98 +178,80 @@ class Mamba2Block(nn.Module):
self.out_proj.weight = self.out_proj.weight * layer_scale
def __call__(self, u: mx.array, cache=None):
batch_size = u.shape[0]
seq_len = u.shape[1]
outputs = []
batch_size, seq_len, dimension = u.shape
assert seq_len == 1, "Input should be a single token"
# Initialize states if needed
if cache.conv_states[0] is None:
# Initialize cache states directly using indices
if cache[0] is None: # conv state
conv_dim = self.args.intermediate_size + 2 * self.args.state_size
cache.conv_states[0] = mx.zeros((
batch_size,
self.args.conv_kernel - 1,
conv_dim
))
cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
if cache.ssm_states[0] is None:
cache.ssm_states[0] = mx.zeros((
if cache[1] is None: # ssm state
cache[1] = mx.zeros((
batch_size,
self.args.num_heads,
self.args.head_dim,
self.args.state_size
))
for pos in range(seq_len):
u_t = u[:, pos:pos+1, :]
zxbcdt = self.in_proj(u_t)
zxbcdt = self.in_proj(u)
n_heads = self.args.num_heads
n_heads = self.args.num_heads
z = zxbcdt[:, :, :self.args.intermediate_size]
xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
dt = zxbcdt[:, :, -(n_heads):]
z = zxbcdt[:, :, :self.args.intermediate_size]
xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
dt = zxbcdt[:, :, -(n_heads):]
dt = mx.reshape(dt, (batch_size, n_heads))
dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
dt = mx.maximum(dt, self.args.time_step_floor)
dt = mx.reshape(dt, (batch_size, n_heads))
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.args.time_step_min,
self.args.time_step_max
)
dt = mx.maximum(dt, self.args.time_step_floor)
xBC = self.conv1d(xBC, cache=cache)
xBC = silu(xBC)
xBC = self.conv1d(xBC, cache=cache)
xBC = silu(xBC)
x = xBC[:, :, :self.args.intermediate_size]
B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
C = xBC[:, :, -self.args.state_size:]
x = xBC[:, :, :self.args.intermediate_size]
B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
C = xBC[:, :, -self.args.state_size:]
x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
x = mx.squeeze(x, axis=1)
B = mx.reshape(B, (batch_size, 1, self.args.state_size))
B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
B = mx.expand_dims(B, axis=2)
C = mx.reshape(C, (batch_size, 1, self.args.state_size))
C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
C = mx.expand_dims(C, axis=3)
x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
x = mx.squeeze(x, axis=1)
A = -mx.exp(self.A_log)
dA = mx.exp(dt * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
B = mx.reshape(B, (batch_size, 1, self.args.state_size))
B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
B = mx.expand_dims(B, axis=2)
x = mx.expand_dims(x, axis=3)
dBx = mx.matmul(x, B)
# Update ssm state directly using cache[1]
cache[1] = cache[1] * dA + dBx
C = mx.reshape(C, (batch_size, 1, self.args.state_size))
C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
C = mx.expand_dims(C, axis=3)
y = mx.matmul(cache[1], C)
y = mx.squeeze(y, axis=-1)
y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
y = self.norm(y + z)
A = -mx.exp(self.A_log)
dA = mx.exp(dt * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
x = mx.expand_dims(x, axis=3)
dBx = mx.matmul(x, B)
cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx
y = mx.matmul(cache.ssm_states[0], C)
y = mx.squeeze(y, axis=-1)
y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
y = self.norm(y + z)
y = self.out_proj(y)
outputs.append(y)
return mx.concatenate(outputs, axis=1)
return self.out_proj(y)
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.residual_in_fp32 = args.residual_in_fp32
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
if self.residual_in_fp32:
x = x.astype(mx.float32)
return self.mixer(self.norm(x), cache) + x
normed = self.norm(x)
output = self.mixer(normed, cache)
return output + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
@@ -287,9 +265,11 @@ class Mamba2(nn.Module):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
hidden = x
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
hidden = layer(hidden, c)
return self.norm_f(hidden)
class Model(nn.Module):
@@ -297,40 +277,23 @@ class Model(nn.Module):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
hidden = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
logits = self.backbone.embeddings.as_linear(hidden)
else:
logits = self.lm_head(x)
logits = self.lm_head(hidden)
return logits
def make_cache(self, batch_size=1):
return [Mamba2Cache(batch_size, self.args.conv_kernel) for _ in range(len(self.layers))]
def sanitize(self, weights):
sanitized = {}
for k, v in weights.items():
if "conv1d.weight" in k:
# Ensure weights are in correct shape (channels, 1, kernel_size)
if v.ndim == 2:
v = mx.expand_dims(v, axis=1)
elif v.ndim == 1:
v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0)
sanitized[k] = v
else:
sanitized[k] = v
return sanitized
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
def layers(self):

View File

@@ -148,202 +148,131 @@ class DepthWiseConv1d(nn.Module):
return y
# class Mamba2Block(nn.Module):
# def __init__(self, args: ModelArgs):
# super().__init__()
# self.args = args
# d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
# self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
# conv_dim = args.intermediate_size + 2 * args.state_size
# self.conv1d = DepthWiseConv1d(
# in_channels=conv_dim,
# out_channels=conv_dim,
# kernel_size=args.conv_kernel,
# groups=conv_dim,
# bias=args.use_conv_bias,
# padding=args.conv_kernel - 1
# )
# self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
# self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
# self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
# self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
# self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
# if args.rescale_prenorm_residual:
# layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
# self.out_proj.weight = self.out_proj.weight * layer_scale
# def __call__(self, u: mx.array, cache=None):
# batch_size, seq_len, dimension = u.shape
# assert seq_len == 1, "Input should be a single token"
# # Initialize cache states directly using indices
# if cache[0] is None: # conv state
# conv_dim = self.args.intermediate_size + 2 * self.args.state_size
# cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
# if cache[1] is None: # ssm state
# cache[1] = mx.zeros((
# batch_size,
# self.args.num_heads,
# self.args.head_dim,
# self.args.state_size
# ))
# zxbcdt = self.in_proj(u)
# n_heads = self.args.num_heads
# z = zxbcdt[:, :, :self.args.intermediate_size]
# xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
# dt = zxbcdt[:, :, -(n_heads):]
# dt = mx.reshape(dt, (batch_size, n_heads))
# dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
# dt = mx.maximum(dt, self.args.time_step_floor)
# xBC = self.conv1d(xBC, cache=cache)
# xBC = silu(xBC)
# x = xBC[:, :, :self.args.intermediate_size]
# B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
# C = xBC[:, :, -self.args.state_size:]
# x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
# x = mx.squeeze(x, axis=1)
# B = mx.reshape(B, (batch_size, 1, self.args.state_size))
# B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
# B = mx.expand_dims(B, axis=2)
# C = mx.reshape(C, (batch_size, 1, self.args.state_size))
# C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
# C = mx.expand_dims(C, axis=3)
# A = -mx.exp(self.A_log)
# dA = mx.exp(dt * mx.expand_dims(A, 0))
# dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
# x = mx.expand_dims(x, axis=3)
# dBx = mx.matmul(x, B)
# # Update ssm state directly using cache[1]
# cache[1] = cache[1] * dA + dBx
# y = mx.matmul(cache[1], C)
# y = mx.squeeze(y, axis=-1)
# y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
# y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
# y = self.norm(y + z)
# return self.out_proj(y)
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
# Calculate dimensions
self.d_model = args.hidden_size
self.d_state = args.state_size
self.d_conv = args.conv_kernel
self.expand = args.expand
self.d_inner = int(self.expand * self.d_model)
self.n_heads = args.num_heads
self.d_head = self.d_inner // self.n_heads
conv_dim = args.intermediate_size + 2 * args.state_size
# Input projection
d_in_proj = self.d_inner * 2 + self.d_state * 2 + self.n_heads
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
# Convolution
conv_dim = self.d_inner + 2 * self.d_state
self.conv1d = DepthWiseConv1d(
in_channels=conv_dim,
out_channels=conv_dim,
kernel_size=args.conv_kernel,
groups=conv_dim,
kernel_size=self.d_conv,
bias=args.use_conv_bias,
padding=args.conv_kernel - 1
groups=conv_dim
)
self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
# SSM parameters
self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
# Output projection
self.norm = MambaRMSNormGated(self.d_inner, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)
if args.rescale_prenorm_residual:
layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
self.out_proj.weight = self.out_proj.weight * layer_scale
def __call__(self, u: mx.array, cache=None):
batch_size, seq_len, dimension = u.shape
batch_size, seq_len, _ = u.shape
# Process sequence in chunks if needed
# Project input
proj = self.in_proj(u) # [batch, seq_len, d_in_proj]
# Calculate split indices and slice tensors
z = proj[..., :self.d_inner]
x_conv = proj[..., self.d_inner:self.d_inner + (self.d_inner + 2 * self.d_state)]
dt = proj[..., -self.n_heads:]
# Process time steps
dt = nn.softplus(dt + self.dt_bias)
dt = mx.clip(dt, self.args.time_step_min, self.args.time_step_max)
dt = mx.maximum(dt, self.args.time_step_floor)
# Convolution and activation
x_conv = self.conv1d(x_conv, cache=[cache[0] if cache else None])
x_conv = silu(x_conv)
# Split conv output
x = x_conv[..., :self.d_inner]
B = x_conv[..., self.d_inner:self.d_inner + self.d_state]
C = x_conv[..., -self.d_state:]
# Reshape x for SSM
x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
# Process B and C without reshaping heads
B = mx.expand_dims(B, axis=2) # [batch, seq_len, 1, d_state]
B = mx.broadcast_to(B, (batch_size, seq_len, self.n_heads, self.d_state))
C = mx.expand_dims(C, axis=2) # [batch, seq_len, 1, d_state]
C = mx.broadcast_to(C, (batch_size, seq_len, self.n_heads, self.d_state))
# Initialize or get previous state
if cache and cache[1] is not None:
prev_state = cache[1]
else:
prev_state = mx.zeros((batch_size, self.n_heads, self.d_head, self.d_state))
# Compute dA
dA = -mx.exp(self.A_log) # [n_heads]
dt = mx.reshape(dt, (batch_size, seq_len, self.n_heads)) # Ensure correct shape
dA = mx.exp(mx.expand_dims(dt * mx.expand_dims(dA, 0), -1)) # [batch, seq_len, n_heads, 1]
dA = mx.expand_dims(dA, -1) # [batch, seq_len, n_heads, 1, 1]
        # Recurrent scan over the sequence (one state update per token), using
        # the x, B, C, dA tensors computed above for the whole sequence
        next_state = prev_state
        outputs = []
        for t in range(seq_len):
            xt = x[:, t]    # [batch, n_heads, d_head]
            Bt = B[:, t]    # [batch, n_heads, d_state]
            Ct = C[:, t]    # [batch, n_heads, d_state]
            dAt = dA[:, t]  # [batch, n_heads, 1, 1]
            dt_t = dt[:, t]  # [batch, n_heads]

            # State update: decay the previous state and add the discretized
            # input term dt_t * (x_t B_t^T), the standard Mamba2 discretization of B
            dBx = mx.matmul(
                mx.expand_dims(xt, -1),  # [batch, n_heads, d_head, 1]
                mx.expand_dims(Bt, -2),  # [batch, n_heads, 1, d_state]
            )
            dBx = dBx * mx.expand_dims(mx.expand_dims(dt_t, -1), -1)
            next_state = next_state * dAt + dBx

            # Output: contract the state with C_t and add the skip term D * x_t
            yt = mx.matmul(
                next_state,              # [batch, n_heads, d_head, d_state]
                mx.expand_dims(Ct, -1),  # [batch, n_heads, d_state, 1]
            )
            yt = mx.squeeze(yt, -1)      # [batch, n_heads, d_head]
            yt = yt + xt * mx.expand_dims(self.D, -1)

            # Gated RMS norm and output projection for this token
            yt = mx.reshape(yt, (batch_size, 1, self.d_inner))
            yt = self.norm(yt, z[:, t:t + 1])
            outputs.append(self.out_proj(yt))

        # Persist the final SSM state for the next decoding step
        if cache is not None:
            cache[1] = next_state

        # Concatenate all per-token outputs
        return mx.concatenate(outputs, axis=1)

View File

@@ -1,316 +0,0 @@
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import Mamba2Cache
@dataclass
class ModelArgs(BaseModelArgs):
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
use_cache: bool = True
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
def silu(x):
return x * mx.sigmoid(x)
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
# Fuse operations where possible
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
# Compute variance in fp32 for better numerical stability
hidden_states_fp32 = hidden_states.astype(mx.float32)
variance = mx.mean(hidden_states_fp32 * hidden_states_fp32, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
def ssd_optimized(x, A, B, C, chunk_size):
batch, seqlen, nheads, dim = x.shape
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)
output = mx.zeros((batch, seqlen, nheads, dim))
state = mx.zeros((batch, nheads, dim, B.shape[-1]))
for i in range(0, seqlen, chunk_size):
chunk = slice(i, min(i + chunk_size, seqlen))
chunk_size_actual = min(chunk_size, seqlen - i)
dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
x_chunk = mx.transpose(x[:, chunk], [0, 2, 3, 1])
dBx = mx.matmul(x_chunk, B[:, chunk])
state = state * mx.expand_dims(dA, axis=-1) + dBx
y = mx.matmul(state, mx.transpose(C[:, chunk], [0, 2, 1]))
output[:, i:i+chunk_size_actual] = mx.transpose(y, [0, 3, 1, 2])
return output, state
def update_conv_cache(x: mx.array, cache, kernel_size: int) -> Tuple[mx.array, mx.array]:
"""Update convolution cache for sequential processing."""
B, L, C = x.shape
if cache is None:
# Initialize cache with zeros
cache = mx.zeros((B, kernel_size - 1, C))
# Concatenate cache with current input
x_with_cache = mx.concatenate([cache, x], axis=1)
# Update cache with the last (kernel_size - 1) elements
new_cache = x_with_cache[:, -kernel_size+1:] if x_with_cache.shape[1] >= kernel_size else x_with_cache
return x_with_cache, new_cache
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.intermediate_size = int(args.expand * args.hidden_size)
self.state_size = args.state_size
self.num_heads = args.num_heads
self.head_dim = args.head_dim
self.ssm_state_size = args.state_size
self.n_groups = args.n_groups
self.conv_kernel = args.conv_kernel
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
self.in_proj = nn.Linear(args.hidden_size, projection_size, bias=args.use_bias)
# Using built-in Conv1d instead of custom DepthwiseConv1d
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
kernel_size=args.conv_kernel,
groups=self.conv_dim, # For depthwise convolution
padding=0, # We'll handle padding manually with the cache
bias=args.use_conv_bias
)
self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
if args.rescale_prenorm_residual:
layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
self.out_proj.weight = self.out_proj.weight * layer_scale
def __call__(self, u: mx.array, cache):
batch_size, seq_len, _ = u.shape
projected = self.in_proj(u)
d_conv = self.conv_dim
z = projected[..., :self.intermediate_size]
xBC = projected[..., self.intermediate_size:self.intermediate_size + d_conv]
dt = projected[..., -self.num_heads:]
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.args.time_step_min,
self.args.time_step_max
)
dt = mx.maximum(dt, self.args.time_step_floor)
# Handle convolution with separate cache update
if cache is not None:
# Update cache and get padded input
xBC_padded, new_cache = update_conv_cache(xBC, cache.conv_states, self.conv_kernel)
cache.conv_states = new_cache
# Prepare input for conv1d: [B, L, C] -> [B, C, L]
xBC_conv = mx.transpose(xBC_padded, [0, 2, 1])
# Apply convolution
xBC = self.conv1d(xBC_conv)
# Transform back: [B, C, L] -> [B, L, C]
xBC = mx.transpose(xBC, [0, 2, 1])
# Take only the relevant part corresponding to input length
xBC = xBC[:, :seq_len]
else:
# For training, use regular convolution with padding
xBC = mx.transpose(xBC, [0, 2, 1])
xBC = self.conv1d(xBC)
xBC = mx.transpose(xBC, [0, 2, 1])
xBC = silu(xBC)
x = xBC[..., :self.intermediate_size]
BC = xBC[..., self.intermediate_size:]
B = BC[..., :self.state_size]
C = BC[..., self.state_size:]
x = mx.reshape(x, (-1, seq_len, self.num_heads, self.intermediate_size // self.num_heads))
A = -mx.exp(self.A_log)
D_expanded = mx.expand_dims(self.D, -1)
if cache is not None and cache.ssm_state is None:
cache.ssm_state = mx.zeros((
batch_size,
self.num_heads,
self.intermediate_size // self.num_heads,
self.state_size
))
if cache is not None:
output = mx.zeros((batch_size, seq_len, self.args.hidden_size))
for pos in range(seq_len):
x_t = x[:, pos:pos+1]
dA = mx.exp(dt[:, pos:pos+1] * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
x_expanded = mx.expand_dims(x_t, axis=3)
dBx = mx.matmul(x_expanded, mx.expand_dims(B[:, pos:pos+1], axis=2))
cache.ssm_state = cache.ssm_state * dA + dBx
y = mx.matmul(cache.ssm_state, mx.expand_dims(C[:, pos:pos+1], axis=3))
y = mx.squeeze(y, axis=-1)
y = y + x_t * D_expanded
y = mx.reshape(y, (batch_size, 1, -1))
y = self.norm(y + z[:, pos:pos+1])
y = self.out_proj(y)
if self.args.residual_in_fp32:
y = y.astype(mx.float32)
output = output.at[:, pos:pos+1].set(y)
else:
y, ssm_state = ssd_optimized(
x * mx.expand_dims(dt, -1),
-mx.exp(self.A_log) * dt,
B, C,
self.args.chunk_size
)
y = mx.reshape(
y + x * mx.expand_dims(self.D, -1),
(batch_size, seq_len, -1)
)
y = self.norm(y + z)
output = self.out_proj(y)
if self.args.residual_in_fp32:
output = output.astype(mx.float32)
return output
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
return self.mixer(self.norm(x), cache) + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def make_cache(self, batch_size=1):
return [Mamba2Cache() for _ in range(len(self.layers))]
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
weights[k] = v.moveaxis(2, 1)
return weights
@property
def layers(self):
return self.backbone.layers