Fixed streaming generation and eliminated the gibberish output, but it is still a little slow: 0.222 tokens-per-sec

This commit is contained in:
Goekdeniz-Guelmez
2024-11-21 22:01:28 +01:00
parent e4eae973e8
commit e22b2dbf27
6 changed files with 884 additions and 1015 deletions

View File

@@ -419,11 +419,9 @@ class RotatingKVCache(_BaseCache):
         raise NotImplementedError("RotatingKVCache Quantization NYI")


-class MambaCache:
+class MambaCache(_BaseCache):
     def __init__(self):
-        # [conv_state, ssm_state]
         self.cache = [None, None]
-        self.offset = 0  # Sliding window caching

     def __setitem__(self, idx, value):
         self.cache[idx] = value
@@ -438,15 +436,3 @@ class MambaCache:
     @state.setter
     def state(self, v):
         self.cache = v
-
-    @property
-    def conv_states(self):
-        return [self.cache[0]]
-
-    @property
-    def ssm_states(self):
-        return [self.cache[1]]
-
-    def reset(self):
-        self.cache = [None, None]
-        self.offset = 0
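For orientation, here is a hypothetical sketch of how a layer is expected to use the slimmed-down MambaCache above: slot 0 holds the rolling convolution input window and slot 1 holds the SSM state, both allocated lazily on the first token. The import path and all shapes below are assumptions for illustration, not taken from a real config.

import mlx.core as mx
from mlx_lm.models.cache import MambaCache  # import path assumed from this repo layout

cache = MambaCache()

# Lazily allocate both slots on the first decoded token (example sizes only).
if cache[0] is None:
    cache[0] = mx.zeros((1, 3, 8))        # (batch, conv_kernel - 1, conv_dim)
if cache[1] is None:
    cache[1] = mx.zeros((1, 4, 16, 32))   # (batch, num_heads, head_dim, state_size)

# After processing a token, a layer writes the updated states back through
# __setitem__, e.g. keeping only the last (conv_kernel - 1) inputs as the
# convolution window.
new_input = mx.zeros((1, 1, 8))           # stand-in for the current token's conv input
window = mx.concatenate([cache[0], new_input], axis=1)
cache[0] = window[:, -3:, :]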

View File

@@ -1,449 +1,437 @@
-# coding=utf-8
-# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch MAMBA2 model."""
-
-import math
+"""
+mamba2-minimal
+==============
+
+A minimal, single-file implementation of the Mamba-2 model in PyTorch.
+
+> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality**
+> Authors: Tri Dao, Albert Gu
+> Paper: https://arxiv.org/abs/2405.21060
+"""
+
+import json
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Iterable, NamedTuple, TypeAlias, cast

 import torch
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import CrossEntropyLoss
-
-logger = logging.get_logger(__name__)
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch import LongTensor, Tensor, nn
+
+Device: TypeAlias = str | torch.device | None
def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): @dataclass
""" class Mamba2Config:
Padding x tensor with `pad_size` on the seq_len dim (dim=1) d_model: int # model dimension (D)
n_layer: int = 24 # number of Mamba-2 layers in the language model
d_state: int = 128 # state dimension (N)
d_conv: int = 4 # convolution kernel size
expand: int = 2 # expansion factor (E)
headdim: int = 64 # head dimension (P)
chunk_size: int = 64 # matrix partition size (Q)
vocab_size: int = 50277
pad_vocab_size_multiple: int = 16
Assumes that we only have tensors of either size 4 or 3 def __post_init__(self):
""" self.d_inner = self.expand * self.d_model
pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) assert self.d_inner % self.headdim == 0
self.nheads = self.d_inner // self.headdim
return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) if self.vocab_size % self.pad_vocab_size_multiple != 0:
self.vocab_size += (
self.pad_vocab_size_multiple
- self.vocab_size % self.pad_vocab_size_multiple
)
def reshape_into_chunks(input_tensor, pad_size, chunk_size): class InferenceCache(NamedTuple):
""" conv_state: Tensor # (batch, d_inner + 2 * d_state, d_conv)
Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and ssm_state: Tensor # (batch, nheads, headdim, d_state)
simultaneously splitting it into chunk sequences.
Assumes that we only have tensors of either size 4 or 3 @staticmethod
""" def alloc(batch_size: int, args: Mamba2Config, device: Device = None):
# [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] return InferenceCache(
input_tensor = pad_tensor_by_size(input_tensor, pad_size) torch.zeros(
batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device
if len(input_tensor.shape) == 3: ),
# [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] torch.zeros(
return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) batch_size, args.nheads, args.headdim, args.d_state, device=device
else: ),
# [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
return input_tensor.reshape(
input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
) )
def segment_sum(input_tensor): class Mamba2LMHeadModel(nn.Module):
""" def __init__(self, args: Mamba2Config, device: Device = None):
More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
"""
chunk_size = input_tensor.size(-1)
# 1. expand input tensor to have an additional dimension and repeat along that dimension
# [..., chunk_size] -> [..., chunk_size, chunk_size]
input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
# 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
input_tensor = input_tensor.masked_fill(~mask, 0)
# 3. compute actual cumsum
tensor_segsum = torch.cumsum(input_tensor, dim=-2)
# 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
return tensor_segsum
class Mamba2Cache:
"""
Arguments:
config: ModelArgs
batch_size: int
dtype: torch.dtype
device: torch.device
Attributes:
seqlen_offset: int
dtype: torch.dtype
conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel]
ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size]
"""
def __init__(
self, config: ModelArgs, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
):
self.seqlen_offset = 0
self.dtype = dtype
self.conv_kernel = config.conv_kernel
self.intermediate_size = int(config.expand * config.hidden_size)
self.conv_states = {
i: torch.zeros(
batch_size,
self.intermediate_size + 2 * config.n_groups * config.state_size,
self.conv_kernel,
device=device,
dtype=dtype,
)
for i in range(config.num_hidden_layers)
}
self.ssm_states = {
i: torch.zeros(
batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype
)
for i in range(config.num_hidden_layers)
}
def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel - 1)
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]
def reset(self):
self.conv_states.zero_()
self.ssm_states.zero_()
class MambaRMSNormGated(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size)) self.args = args
self.variance_epsilon = eps self.device = device
def forward(self, hidden_states, gate=None): self.backbone = nn.ModuleDict(
input_dtype = hidden_states.dtype dict(
hidden_states = hidden_states embedding=nn.Embedding(args.vocab_size, args.d_model, device=device),
layers=nn.ModuleList(
[
nn.ModuleDict(
dict(
mixer=Mamba2(args, device=device),
norm=RMSNorm(args.d_model, device=device),
)
)
for _ in range(args.n_layer)
]
),
norm_f=RMSNorm(args.d_model, device=device),
)
)
self.lm_head = nn.Linear(
args.d_model, args.vocab_size, bias=False, device=device
)
self.lm_head.weight = self.backbone.embedding.weight
if gate is not None: @staticmethod
hidden_states = hidden_states * nn.functional.silu(gate) def from_pretrained(huggingface_model_id: str, device: Device = None):
variance = hidden_states.pow(2).mean(-1, keepdim=True) from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) from transformers.utils.hub import cached_file
return self.weight * hidden_states config_path = cached_file(huggingface_model_id, CONFIG_NAME)
assert config_path, "Failed to get huggingface config file"
state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME)
assert state_dict_path, "Failed to get huggingface state dict file"
config = json.load(open(config_path))
args = Mamba2Config(
d_model=config["d_model"],
n_layer=config["n_layer"],
vocab_size=config["vocab_size"],
pad_vocab_size_multiple=config["pad_vocab_size_multiple"],
)
map_location = "cpu" if device is None else device
state_dict = torch.load(
state_dict_path, weights_only=True, map_location=map_location, mmap=True
)
model = Mamba2LMHeadModel(args, device=device)
model.load_state_dict(state_dict)
model.eval()
return model
def forward(
self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None
) -> tuple[LongTensor, list[InferenceCache]]:
"""
Arguments
input_ids: (batch, seqlen) tokens from `EleutherAI/gpt-neox-20b` tokenizer
h: hidden states for inference step. If present the constant-time
(wrt sequence length) inference path will be taken, input_ids
should have shape (batch, 1) containing the next batch of prompt
token.
Return (logits, h)
logits: (batch, seqlen, vocab_size)
h: updated inference cache after processing `input_ids`
"""
seqlen = input_ids.shape[1]
if h is None:
h = [None for _ in range(self.args.n_layer)]
x = self.backbone.embedding(input_ids)
for i, layer in enumerate(self.backbone.layers):
y, h[i] = layer.mixer(layer.norm(x), h[i])
x = y + x
x = self.backbone.norm_f(x)
logits = self.lm_head(x)
return logits[:, :seqlen], cast(list[InferenceCache], h)
def generate(
self,
input_ids: LongTensor,
max_new_length: int = 20,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 1.0,
eos_token_id: int = 0,
) -> Iterable[tuple[int, list[InferenceCache]]]:
prefix, tokens = input_ids[:-1], input_ids[-1:].unsqueeze(0)
# Process prompt
# The input sequence to forward (non-inference path) must have length multiple that of chunk_size.
# We split out excess tokens so that n_chunked tokens can be processed by one forward call and
# process the rest in multiple inference steps.
n_chunked = (prefix.shape[0] // self.args.chunk_size) * self.args.chunk_size
if n_chunked > 0:
_, h = self(prefix[:n_chunked].unsqueeze(0), None)
else:
h = [
InferenceCache.alloc(1, self.args, device=self.device)
for _ in range(self.args.n_layer)
]
for i in range(n_chunked, prefix.shape[0]):
_, h = self(prefix[i : i + 1].unsqueeze(0), h)
# Generate
for _ in range(max_new_length):
with torch.no_grad():
out, h = self(tokens, h)
logits = out[0, -1]
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, k=top_k)[0][-1]
logits[indices_to_remove] = -torch.inf
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > 0.5
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
sorted_indices_to_remove[0] = False
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = -torch.inf
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
if next_token.item() == eos_token_id:
return
tokens = next_token.unsqueeze(0)
yield cast(int, next_token.item()), h
class Mamba2Mixer(nn.Module): class Mamba2(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, args: Mamba2Config, device: Device = None):
super().__init__() super().__init__()
self.num_heads = config.num_heads self.args = args
self.hidden_size = config.hidden_size self.device = device
self.state_size = config.state_size
self.conv_kernel = config.conv_kernel
self.intermediate_size = int(config.expand * self.hidden_size)
self.time_step_rank = int(config.time_step_rank)
self.use_conv_bias = config.use_conv_bias
self.layer_norm_epsilon = config.layer_norm_epsilon # Order: (z, x, B, C, dt)
d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads
self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device)
self.n_groups = config.n_groups conv_dim = args.d_inner + 2 * args.d_state
self.head_dim = config.head_dim
self.chunk_size = config.chunk_size
self.time_step_limit = config.time_step_limit
self.time_step_min = config.time_step_min
self.time_step_max = config.time_step_max
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
self.conv1d = nn.Conv1d( self.conv1d = nn.Conv1d(
in_channels=self.conv_dim, in_channels=conv_dim,
out_channels=self.conv_dim, out_channels=conv_dim,
bias=config.use_conv_bias, kernel_size=args.d_conv,
kernel_size=config.conv_kernel, groups=conv_dim,
groups=self.conv_dim, padding=args.d_conv - 1,
padding=config.conv_kernel - 1, device=device,
) )
# projection of the input hidden states self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
projection_size = self.intermediate_size + self.conv_dim + self.num_heads self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
self.in_proj = nn.Linear( self.D = nn.Parameter(torch.empty(args.nheads, device=device))
self.hidden_size, self.norm = RMSNorm(args.d_inner, device=device)
projection_size, self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
bias=config.use_bias,
def forward(self, u: Tensor, h: InferenceCache | None = None):
"""
Arguments
u: (batch, seqlen, d_model) input. seqlen should be a multiple of chunk_size.
h: hidden states for inference step. Initialized to 0s if not present.
Return (y, h)
y: (batch, seqlen, d_model) output
h: updated inference cache after processing `u`
"""
if h:
return self.step(u, h)
A = -torch.exp(self.A_log) # (nheads,)
zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
)
dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads)
# Pad or truncate xBC seqlen to d_conv
conv_state = F.pad(
rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0)
) )
self.dt_bias = torch.ones(self.num_heads) xBC = silu(
A = torch.arange(1, self.num_heads + 1) self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :]
self.A_log = torch.log(A) ) # (batch, seqlen, d_inner + 2 * d_state))
self.D = torch.ones(self.num_heads) x, B, C = torch.split(
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
)
x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)
y, ssm_state = ssd(
x * dt.unsqueeze(-1),
A * dt,
rearrange(B, "b l n -> b l 1 n"),
rearrange(C, "b l n -> b l 1 n"),
self.args.chunk_size,
device=self.device,
)
y = y + x * self.D.unsqueeze(-1)
y = rearrange(y, "b l h p -> b l (h p)")
y = self.norm(y, z)
y = self.out_proj(y)
self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) h = InferenceCache(conv_state, ssm_state)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) return y, h
def forward(self, input_states, cache: Optional[Mamba2Cache]=None): def step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]:
batch_size, seq_len, _ = input_states.shape """Take a single inference step for the current input and hidden state
# Gated MLP's linear projection
projected_states = self.in_proj(input_states.squeeze(1)) Unlike attention-based models, RNN-based models (eg Mamba) does not need
d_mlp = ( to look back at all the past tokens to generate a new token. Instead a
projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.state_size- self.num_heads) // 2 hidden state (initialized to 0s initially) is updated for each input and
_, _, gate, hidden_states, dt = projected_states.split( passed to the next inference step. This means that the total inference
[d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 time is linear with respect to the sequence length instead of quadratic
in attention's case.
Arguments
u: (batch, 1, d_model)
h: initial/running hidden state
Return (y, h)
y: (batch, 1, d_model)
h: updated hidden state
"""
assert u.shape[1] == 1, "Only one token can be decoded per inference step"
zxbcdt = self.in_proj(u.squeeze(1)) # (batch, d_in_proj)
z, xBC, dt = torch.split(
zxbcdt,
[
self.args.d_inner,
self.args.d_inner + 2 * self.args.d_state,
self.args.nheads,
],
dim=-1,
) )
# Convolution sequence transformation # Advance convolution input
ssm_state = cache.ssm_states[self.layer_idx].clone() h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1))
ssm_state = ssm_state.to(hidden_states.device) h.conv_state[:, :, -1] = xBC
# Convolution step
xBC = torch.sum(
h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
)
xBC += self.conv1d.bias
xBC = silu(xBC)
if cache.seqlen_offset > 0: x, B, C = torch.split(
conv_state = cache.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel] xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
conv_state = torch.roll(conv_state, shifts=-1, dims=-1) )
A = -torch.exp(self.A_log) # (nheads,)
# handle batched generation - states are copied through # SSM step
conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states dt = F.softplus(dt + self.dt_bias) # (batch, nheads)
cache.conv_states[self.layer_idx].copy_(conv_state) dA = torch.exp(dt * A) # (batch, nheads)
hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)
dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)
y = y + rearrange(self.D, "h -> h 1") * x
y = rearrange(y, "b h p -> b (h p)")
y = self.norm(y, z)
y = self.out_proj(y)
if self.use_conv_bias: return y.unsqueeze(1), h
hidden_states += self.conv1d.bias
hidden_states = nn.silu(hidden_states)[:, None, ...] # [batch, 1, intermediate_size] : decoding
else:
hidden_states = hidden_states.transpose(1,2)
conv_state = nn.functional.pad(
hidden_states,
(self.conv_kernel - hidden_states.shape[-1], 0)
)
cache.conv_states[self.layer_idx].copy_(conv_state)
hidden_states = nn.silu(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len]
hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.state_size, self.n_groups * self.state_size], dim=-1)
A = -torch.exp(self.A_log.float()) # [num_heads]
if cache is not None and cache.seqlen_offset > 0:
# Note: there is no need to pad parameter matrices here, as there is just one new token
# for batched generation
dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
# [num_heads] -> [num_heads, head_dim]
dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max)
A = A[..., None, None].expand(self.num_heads, self.head_dim, self.state_size).to(dtype=torch.float32)
# [bsz, num_heads, head_dim, state_size]
dA = torch.exp(dt[..., None] * A)
# Discretize B
# [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
# -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
B = B.reshape(batch_size, -1, B.shape[-1])
# [bsz, num_heads, head_dim, state_size]
dB = dt[..., None] * B[..., None, :]
# Discretize x into dB
# [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
dBx = dB * hidden_states[..., None]
# State calculation
cache.ssm_states[self.layer_idx].copy_(
cache.ssm_states[self.layer_idx] * dA + dBx
)
# Subsequent output
# [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
C = C.reshape(batch_size, -1, C.shape[-1])
# [bsz, num_heads, head_dim]
ssm_states = cache.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n]
# Reshape ssm_states to merge the first two dimensions
ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.state_size) # Shape: [b*h, d, n]
C_reshaped = C.view(batch_size * self.num_heads, self.state_size, 1) # Shape: [b*h, n, 1]
y = torch.bmm(ssm_states_reshaped, C_reshaped)
y = y.view(batch_size, self.num_heads, self.head_dim)
# D skip connection
# [num_heads] -> [num_heads, head_dim]
D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
y = (y + hidden_states * D).to(y.dtype)
# [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
y = y.reshape(batch_size, -1)[:, None, ...]
else:
# begin ssd naive implementation without einsums
dt = nn.functional.softplus(dt + self.dt_bias)
dt = torch.clamp(dt, self.time_step_min)
hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
B = B.reshape(batch_size, seq_len, -1, self.state_size).float()
C = C.reshape(batch_size, seq_len, -1, self.state_size).float()
B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
pad_size = self.chunk_size - (seq_len % self.chunk_size)
D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
# Discretize x and A
hidden_states = hidden_states * dt[..., None]
A = A.to(hidden_states.dtype) * dt
# Rearrange into blocks/chunks
hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
# [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] def segsum(x: Tensor, device: Device = None) -> Tensor:
A = A.permute(0, 3, 1, 2) """Stable segment sum calculation.
A_cumsum = torch.cumsum(A, dim=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks) `exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.
# This is the analog of a causal mask
L = torch.exp(segment_sum(A))
# First, contraction of C and B to get G (attention-weights like) Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32
G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n) """
G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) T = x.size(-1)
x = repeat(x, "... d -> ... d e", e=T)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
x = x.masked_fill(~mask, 0)
x_segsum = torch.cumsum(x, dim=-2)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
# Step 2: Compute M, equivalent to applying attention mask to weights def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] """Structed State Space Duality (SSD) - the core of Mamba-2
M = M_intermediate.sum(dim=-1)
# Step 3: Compute Y_diag (apply to values) This is almost the exact same minimal SSD code from the blog post.
Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)
# (right term of low-rank factorization of off-diagonal blocks; B terms) Arguments
x: (batch, seqlen, n_heads, d_head)
A: (batch, seqlen, n_heads)
B: (batch, seqlen, n_heads, d_state)
C: (batch, seqlen, n_heads, d_state)
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) Return
B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] y: (batch, seqlen, n_heads, d_head)
# permute back B * decay states
states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
if cache is not None and cache.seqlen_offset > 0:
previous_states = cache.ssm_states[self.layer_idx][:, None, ...]
else:
previous_states = torch.zeros_like(states[:, :1])
states = torch.cat([previous_states, states], dim=1)
decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
states_permuted = states.permute(0, 2, 1, 3, 4) Source
result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2) 1. https://tridao.me/blog/2024/mamba2-part3-algorithm/
new_states = result.permute(0, 2, 1, 3, 4) 2. https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78
states, ssm_state = new_states[:, :-1], new_states[:, -1] """
assert x.shape[1] % chunk_size == 0
# Compute state -> output conversion per chunk # Rearrange into chunks
# (left term of low-rank factorization of off-diagonal blocks; C terms) # Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel)
state_decay_out = torch.exp(A_cumsum) # This is not implemented and left as an exercise for the reader 😜
# compute Yoff x, A, B, C = [
C_times_states = (C[..., None, :] * states[:, :, None, ...]) rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)
state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) ]
Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
y = Y_diag + Y_off A = rearrange(A, "b c l h -> b h c l")
# [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] A_cumsum = torch.cumsum(A, dim=-1)
y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
y = y + D_residual # 1. Compute the output for each intra-chunk (diagonal blocks)
# Cutting off padded chunks L = torch.exp(segsum(A, device=device))
if pad_size > 0: Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)
y = y[:, :seq_len, :, :]
y = y.reshape(batch_size, seq_len, -1)
if ssm_state is not None and cache is not None:
cache.ssm_states[self.layer_idx].copy_(ssm_state)
scan_output = self.norm(y, gate) # 2. Compute the state for each intra-chunk
# end ssd naive # (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)
# 4. Final linear projection # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
contextualized_states = self.out_proj(scan_output) # [batch, seq_len, hidden_size] # (middle term of factorization of off-diag blocks; A terms)
return contextualized_states if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))
new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
return Y, final_state
class Mamba2Block(nn.Module): class RMSNorm(nn.Module):
def __init__(self, config): def __init__(self, d: int, eps: float = 1e-5, device: Device = None):
"""Gated Root Mean Square Layer Normalization
Paper: https://arxiv.org/abs/1910.07467
"""
super().__init__() super().__init__()
self.config = config self.eps = eps
self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.weight = nn.Parameter(torch.ones(d, device=device))
self.mixer = Mamba2Mixer(config)
def forward( def forward(self, x, z=None):
self, if z is not None:
hidden_states, x = x * silu(z)
cache: Optional[Mamba2Cache] = None, return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
cache_position: Optional[torch.LongTensor] = None,
):
x = self.mixer(
self.norm(hidden_states), cache=cache, cache_position=cache_position
)
return x + hidden_states
class Mamba2Model(nn.Module): def silu(x):
def __init__(self, config): """Applies the Sigmoid Linear Unit (SiLU), element-wise.
super().__init__(config)
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
self.gradient_checkpointing = False Define this manually since torch's version doesn't seem to work on MPS.
self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) """
return x * F.sigmoid(x)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
cache: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.embeddings(input_ids)
hidden_states = inputs_embeds
for mixer_block in self.layers:
hidden_states = mixer_block(
hidden_states,
cache=cache,
cache_position=cache_position,
)
cache.seqlen_offset += inputs_embeds.shape[1]
return self.norm_f(hidden_states), cache
class Mamba2ForCausalLM(nn.Module):
def __init__(self, config):
super().__init__(config)
self.backbone = Mamba2Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
cache: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.Tensor] = None,
):
out, cache = self.backbone(
input_ids,
cache=cache,
cache_position=cache_position,
)
logits = self.lm_head(out)
return logits, cache
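To make the constant-time decode path of the reference implementation above concrete, here is a small self-contained sketch of the per-token SSM recurrence that `Mamba2.step` performs. The tensors are random stand-ins for the real projections and the sizes are illustrative only; the state shape never grows with the number of tokens seen, which is why each decode step costs the same.

import torch
import torch.nn.functional as F

batch, nheads, headdim, d_state = 1, 2, 4, 8
h = torch.zeros(batch, nheads, headdim, d_state)   # running SSM state, like InferenceCache.ssm_state
A = -torch.rand(nheads)                             # A = -exp(A_log), strictly negative

for _ in range(3):                                  # three decode steps
    x = torch.randn(batch, nheads, headdim)         # current token, split into heads
    B = torch.randn(batch, d_state)
    C = torch.randn(batch, d_state)
    dt = F.softplus(torch.randn(batch, nheads))     # per-head step size

    dA = torch.exp(dt * A)                          # (batch, nheads) decay factor
    h = h * dA[..., None, None] + torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
    y = torch.einsum("bhpn, bn -> bhp", h, C)       # output for this single token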

View File

@@ -0,0 +1,319 @@
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
use_cache: bool = True
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
def silu(x):
return x * mx.sigmoid(x)
def ssd(x, A, B, C, chunk_size):
batch, seqlen, nheads, dim = x.shape
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)
state = mx.zeros((batch, nheads, dim, B.shape[-1]))
outputs = []
for i in range(0, seqlen, chunk_size):
chunk = slice(i, min(i + chunk_size, seqlen))
dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim]
x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size]
B_chunk = B[:, chunk] # [batch, chunk_size, state_size]
dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size]
state = state * mx.expand_dims(dA, axis=-1) + dBx
C_chunk = C[:, chunk] # [batch, chunk_size, state_size]
y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size]
y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim]
outputs.append(y)
return mx.concatenate(outputs, axis=1), state
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.groups = groups if groups is not None else in_channels
assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
self.weight = mx.random.normal((in_channels, 1, kernel_size))
self.bias = mx.zeros((out_channels,)) if bias else None
def __call__(self, x: mx.array, cache=None) -> mx.array:
B, L, C = x.shape
K = self.kernel_size
assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
if cache is not None:
# Access conv_state directly from cache[0]
if cache[0] is None:
cache[0] = mx.zeros((B, K-1, C))
x = mx.concatenate([cache[0], x], axis=1)
outputs = []
for c in range(C):
x_c = x[:, :, c]
x_c = mx.expand_dims(x_c, axis=1)
w_c = self.weight[c]
if w_c.ndim == 2:
w_c = mx.expand_dims(w_c, axis=0)
elif w_c.ndim == 1:
w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
y_c = mx.conv_general(
x_c,
w_c,
stride=1,
padding=0
)
if self.bias is not None:
y_c = y_c + self.bias[c]
y_c = mx.squeeze(y_c, axis=1)
outputs.append(y_c)
y = mx.stack(outputs, axis=-1)
# Update cache directly using cache[0]
if cache is not None:
cache[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
return y
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
conv_dim = args.intermediate_size + 2 * args.state_size
self.conv1d = DepthWiseConv1d(
in_channels=conv_dim,
out_channels=conv_dim,
kernel_size=args.conv_kernel,
groups=conv_dim,
bias=args.use_conv_bias,
padding=args.conv_kernel - 1
)
self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
if args.rescale_prenorm_residual:
layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
self.out_proj.weight = self.out_proj.weight * layer_scale
def __call__(self, u: mx.array, cache=None):
batch_size, seq_len, dimension = u.shape
# Initialize the negative A matrix
A = -mx.exp(self.A_log)
# Process sequence in chunks if needed
outputs = []
current_cache = cache
for i in range(seq_len):
# Extract current token
current_input = u[:, i:i+1, :]
# Initialize cache states if needed
if current_cache[0] is None: # conv state
conv_dim = self.args.intermediate_size + 2 * self.args.state_size
current_cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
if current_cache[1] is None: # ssm state
current_cache[1] = mx.zeros((
batch_size,
self.args.num_heads,
self.args.head_dim,
self.args.state_size
))
# Project input
zxbcdt = self.in_proj(current_input)
n_heads = self.args.num_heads
z = zxbcdt[:, :, :self.args.intermediate_size]
xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
dt = zxbcdt[:, :, -(n_heads):]
# Process time steps
dt = mx.reshape(dt, (batch_size, n_heads))
dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
dt = mx.maximum(dt, self.args.time_step_floor)
# Apply convolution
xBC = self.conv1d(xBC, cache=current_cache)
xBC = silu(xBC)
# Split states
x = xBC[:, :, :self.args.intermediate_size]
B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
C = xBC[:, :, -self.args.state_size:]
# Reshape for SSM
x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
x = mx.squeeze(x, axis=1)
B = mx.reshape(B, (batch_size, 1, self.args.state_size))
B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
B = mx.expand_dims(B, axis=2)
C = mx.reshape(C, (batch_size, 1, self.args.state_size))
C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
C = mx.expand_dims(C, axis=3)
# SSM updates
dA = mx.exp(dt * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
# Update state
x = mx.expand_dims(x, axis=3)
dBx = mx.matmul(x, B)
current_cache[1] = current_cache[1] * dA + dBx
# Compute output
y = mx.matmul(current_cache[1], C)
y = mx.squeeze(y, axis=-1)
y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
y = self.norm(y + z)
outputs.append(self.out_proj(y))
# Concatenate all outputs
return mx.concatenate(outputs, axis=1)
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.residual_in_fp32 = args.residual_in_fp32
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
if self.residual_in_fp32:
x = x.astype(mx.float32)
normed = self.norm(x)
output = self.mixer(normed, cache)
return output + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
hidden = x
for layer, c in zip(self.layers, cache):
hidden = layer(hidden, c)
return self.norm_f(hidden)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
hidden = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(hidden)
else:
logits = self.lm_head(hidden)
return logits
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
def layers(self):
return self.backbone.layers
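A hypothetical end-to-end decode loop for the MLX Model defined above, using greedy sampling and a made-up start token. Note that Mamba2Block.__call__ walks the sequence one position at a time in Python and DepthWiseConv1d loops over every channel, which plausibly accounts for the 0.222 tokens-per-sec noted in the commit message.

import mlx.core as mx

# `args` is assumed to be a ModelArgs populated from the model's config.json.
model = Model(args)
cache = model.make_cache()            # one MambaCache per layer

tokens = mx.array([[0]])              # (batch=1, seq_len=1), placeholder start token id
generated = []
for _ in range(16):
    logits = model(tokens, cache=cache)                # (1, 1, vocab_size)
    next_token = mx.argmax(logits[:, -1, :], axis=-1)  # greedy pick, shape (1,)
    generated.append(int(next_token.item()))
    tokens = mx.expand_dims(next_token, axis=0)        # feed back as (1, 1)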

View File

@@ -5,7 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn

 from .base import BaseModelArgs
-from .cache import Mamba2Cache
+from .cache import MambaCache


 @dataclass
 class ModelArgs(BaseModelArgs):
@@ -62,7 +62,6 @@ def silu(x):
     return x * mx.sigmoid(x)

 def ssd(x, A, B, C, chunk_size):
-    # Replace einsum operations with explicit reshape and matrix multiply
     batch, seqlen, nheads, dim = x.shape
     B = mx.expand_dims(B, axis=2)
     C = mx.expand_dims(C, axis=2)
@@ -74,7 +73,6 @@ def ssd(x, A, B, C, chunk_size):
         chunk = slice(i, min(i + chunk_size, seqlen))
         dA = mx.exp(mx.expand_dims(A[chunk], axis=0))

-        # Replace einsum with explicit operations
         x_chunk = x[:, chunk]  # [batch, chunk_size, nheads, dim]
         x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1])  # [batch, nheads, dim, chunk_size]
         B_chunk = B[:, chunk]  # [batch, chunk_size, state_size]
@@ -82,7 +80,6 @@ def ssd(x, A, B, C, chunk_size):
         state = state * mx.expand_dims(dA, axis=-1) + dBx

-        # Replace einsum with explicit operations
         C_chunk = C[:, chunk]  # [batch, chunk_size, state_size]
         y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1]))  # [batch, nheads, dim, chunk_size]
         y = mx.transpose(y, [0, 3, 1, 2])  # [batch, chunk_size, nheads, dim]
@@ -112,13 +109,13 @@ class DepthWiseConv1d(nn.Module):
         assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"

-        if cache is not None and cache.conv_states[0] is not None:
-            # Convert None to proper array if needed
-            if isinstance(cache.conv_states[0], type(None)):
-                cache.conv_states[0] = mx.zeros((B, K-1, C))
-            x = mx.concatenate([cache.conv_states[0], x], axis=1)
+        if cache is not None:
+            # Access conv_state directly from cache[0]
+            if cache[0] is None:
+                cache[0] = mx.zeros((B, K-1, C))
+            x = mx.concatenate([cache[0], x], axis=1)

-        # Process each channel independently
         outputs = []
         for c in range(C):
             x_c = x[:, :, c]
@@ -130,24 +127,23 @@ class DepthWiseConv1d(nn.Module):
             elif w_c.ndim == 1:
                 w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)

-            # Apply convolution
             y_c = mx.conv_general(
                 x_c,
                 w_c,
                 stride=1,
                 padding=0
             )

             if self.bias is not None:
                 y_c = y_c + self.bias[c]

-            outputs.append(mx.squeeze(y_c, axis=1))
+            y_c = mx.squeeze(y_c, axis=1)
+            outputs.append(y_c)

         y = mx.stack(outputs, axis=-1)

-        # Update cache
+        # Update cache directly using cache[0]
         if cache is not None:
-            cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
+            cache[0] = x[:, -K+1:, :] if x.shape[1] >= K else x

         return y
@@ -182,98 +178,80 @@ class Mamba2Block(nn.Module):
             self.out_proj.weight = self.out_proj.weight * layer_scale

     def __call__(self, u: mx.array, cache=None):
-        batch_size = u.shape[0]
-        seq_len = u.shape[1]
-        outputs = []
-
-        # Initialize states if needed
-        if cache.conv_states[0] is None:
-            conv_dim = self.args.intermediate_size + 2 * self.args.state_size
-            cache.conv_states[0] = mx.zeros((
-                batch_size,
-                self.args.conv_kernel - 1,
-                conv_dim
-            ))
-
-        if cache.ssm_states[0] is None:
-            cache.ssm_states[0] = mx.zeros((
-                batch_size,
-                self.args.num_heads,
-                self.args.head_dim,
-                self.args.state_size
-            ))
-
-        for pos in range(seq_len):
-            u_t = u[:, pos:pos+1, :]
-            zxbcdt = self.in_proj(u_t)
-            n_heads = self.args.num_heads
-
-            z = zxbcdt[:, :, :self.args.intermediate_size]
-            xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
-            dt = zxbcdt[:, :, -(n_heads):]
-
-            dt = mx.reshape(dt, (batch_size, n_heads))
-            dt = mx.clip(
-                nn.softplus(dt + self.dt_bias),
-                self.args.time_step_min,
-                self.args.time_step_max
-            )
-            dt = mx.maximum(dt, self.args.time_step_floor)
-
-            xBC = self.conv1d(xBC, cache=cache)
-            xBC = silu(xBC)
-
-            x = xBC[:, :, :self.args.intermediate_size]
-            B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
-            C = xBC[:, :, -self.args.state_size:]
-
-            x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
-            x = mx.squeeze(x, axis=1)
-
-            B = mx.reshape(B, (batch_size, 1, self.args.state_size))
-            B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
-            B = mx.expand_dims(B, axis=2)
-
-            C = mx.reshape(C, (batch_size, 1, self.args.state_size))
-            C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
-            C = mx.expand_dims(C, axis=3)
-
-            A = -mx.exp(self.A_log)
-            dA = mx.exp(dt * mx.expand_dims(A, 0))
-            dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
-
-            x = mx.expand_dims(x, axis=3)
-            dBx = mx.matmul(x, B)
-            cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx
-
-            y = mx.matmul(cache.ssm_states[0], C)
-            y = mx.squeeze(y, axis=-1)
-            y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
-            y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
-            y = self.norm(y + z)
-            y = self.out_proj(y)
-            outputs.append(y)
-
-        return mx.concatenate(outputs, axis=1)
+        batch_size, seq_len, dimension = u.shape
+        assert seq_len == 1, "Input should be a single token"
+
+        # Initialize cache states directly using indices
+        if cache[0] is None:  # conv state
+            conv_dim = self.args.intermediate_size + 2 * self.args.state_size
+            cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
+
+        if cache[1] is None:  # ssm state
+            cache[1] = mx.zeros((
+                batch_size,
+                self.args.num_heads,
+                self.args.head_dim,
+                self.args.state_size
+            ))
+
+        zxbcdt = self.in_proj(u)
+        n_heads = self.args.num_heads
+
+        z = zxbcdt[:, :, :self.args.intermediate_size]
+        xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
+        dt = zxbcdt[:, :, -(n_heads):]
+
+        dt = mx.reshape(dt, (batch_size, n_heads))
+        dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
+        dt = mx.maximum(dt, self.args.time_step_floor)
+
+        xBC = self.conv1d(xBC, cache=cache)
+        xBC = silu(xBC)
+
+        x = xBC[:, :, :self.args.intermediate_size]
+        B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
+        C = xBC[:, :, -self.args.state_size:]
+
+        x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
+        x = mx.squeeze(x, axis=1)
+        B = mx.reshape(B, (batch_size, 1, self.args.state_size))
+        B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
+        B = mx.expand_dims(B, axis=2)
+        C = mx.reshape(C, (batch_size, 1, self.args.state_size))
+        C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
+        C = mx.expand_dims(C, axis=3)
+
+        A = -mx.exp(self.A_log)
+        dA = mx.exp(dt * mx.expand_dims(A, 0))
+        dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
+
+        x = mx.expand_dims(x, axis=3)
+        dBx = mx.matmul(x, B)
+        # Update ssm state directly using cache[1]
+        cache[1] = cache[1] * dA + dBx
+
+        y = mx.matmul(cache[1], C)
+        y = mx.squeeze(y, axis=-1)
+        y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
+        y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
+        y = self.norm(y + z)
+        return self.out_proj(y)
 class ResidualBlock(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.residual_in_fp32 = args.residual_in_fp32
         self.mixer = Mamba2Block(args)
         self.norm = nn.RMSNorm(args.hidden_size)

     def __call__(self, x: mx.array, cache):
         if self.residual_in_fp32:
             x = x.astype(mx.float32)
-        return self.mixer(self.norm(x), cache) + x
+        normed = self.norm(x)
+        output = self.mixer(normed, cache)
+        return output + x


 class Mamba2(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -287,9 +265,11 @@ class Mamba2(nn.Module):
         x = self.embeddings(x)
         if cache is None:
             cache = [None] * len(self.layers)
+        hidden = x
         for layer, c in zip(self.layers, cache):
-            x = layer(x, c)
-        return self.norm_f(x)
+            hidden = layer(hidden, c)
+        return self.norm_f(hidden)
 class Model(nn.Module):
@@ -297,40 +277,23 @@ class Model(nn.Module):
         super().__init__()
         self.args = args
         self.model_type = args.model_type

         self.backbone = Mamba2(args)

         if not args.tie_word_embeddings:
             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

     def __call__(self, inputs: mx.array, cache=None):
-        B, T = inputs.shape
-        x = self.backbone(inputs, cache)
+        hidden = self.backbone(inputs, cache)
         if self.args.tie_word_embeddings:
-            logits = self.backbone.embeddings.as_linear(x)
+            logits = self.backbone.embeddings.as_linear(hidden)
         else:
-            logits = self.lm_head(x)
+            logits = self.lm_head(hidden)
         return logits

-    def make_cache(self, batch_size=1):
-        return [Mamba2Cache(batch_size, self.args.conv_kernel) for _ in range(len(self.layers))]
-
-    def sanitize(self, weights):
-        sanitized = {}
-        for k, v in weights.items():
-            if "conv1d.weight" in k:
-                # Ensure weights are in correct shape (channels, 1, kernel_size)
-                if v.ndim == 2:
-                    v = mx.expand_dims(v, axis=1)
-                elif v.ndim == 1:
-                    v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0)
-                sanitized[k] = v
-            else:
-                sanitized[k] = v
-        return sanitized
+    def make_cache(self):
+        return [MambaCache() for _ in range(len(self.layers))]

     @property
     def layers(self):

View File

@@ -148,202 +148,131 @@ class DepthWiseConv1d(nn.Module):
return y return y
# class Mamba2Block(nn.Module):
# def __init__(self, args: ModelArgs):
# super().__init__()
# self.args = args
# d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
# self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
# conv_dim = args.intermediate_size + 2 * args.state_size
# self.conv1d = DepthWiseConv1d(
# in_channels=conv_dim,
# out_channels=conv_dim,
# kernel_size=args.conv_kernel,
# groups=conv_dim,
# bias=args.use_conv_bias,
# padding=args.conv_kernel - 1
# )
# self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
# self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
# self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
# self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
# self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
# if args.rescale_prenorm_residual:
# layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
# self.out_proj.weight = self.out_proj.weight * layer_scale
# def __call__(self, u: mx.array, cache=None):
# batch_size, seq_len, dimension = u.shape
# assert seq_len == 1, "Input should be a single token"
# # Initialize cache states directly using indices
# if cache[0] is None: # conv state
# conv_dim = self.args.intermediate_size + 2 * self.args.state_size
# cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
# if cache[1] is None: # ssm state
# cache[1] = mx.zeros((
# batch_size,
# self.args.num_heads,
# self.args.head_dim,
# self.args.state_size
# ))
# zxbcdt = self.in_proj(u)
# n_heads = self.args.num_heads
# z = zxbcdt[:, :, :self.args.intermediate_size]
# xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
# dt = zxbcdt[:, :, -(n_heads):]
# dt = mx.reshape(dt, (batch_size, n_heads))
# dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
# dt = mx.maximum(dt, self.args.time_step_floor)
# xBC = self.conv1d(xBC, cache=cache)
# xBC = silu(xBC)
# x = xBC[:, :, :self.args.intermediate_size]
# B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
# C = xBC[:, :, -self.args.state_size:]
# x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
# x = mx.squeeze(x, axis=1)
# B = mx.reshape(B, (batch_size, 1, self.args.state_size))
# B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
# B = mx.expand_dims(B, axis=2)
# C = mx.reshape(C, (batch_size, 1, self.args.state_size))
# C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
# C = mx.expand_dims(C, axis=3)
# A = -mx.exp(self.A_log)
# dA = mx.exp(dt * mx.expand_dims(A, 0))
# dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
# x = mx.expand_dims(x, axis=3)
# dBx = mx.matmul(x, B)
# # Update ssm state directly using cache[1]
# cache[1] = cache[1] * dA + dBx
# y = mx.matmul(cache[1], C)
# y = mx.squeeze(y, axis=-1)
# y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
# y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
# y = self.norm(y + z)
# return self.out_proj(y)
class Mamba2Block(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
-        d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
-        self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
-        conv_dim = args.intermediate_size + 2 * args.state_size
-        self.conv1d = DepthWiseConv1d(
-            in_channels=conv_dim,
-            out_channels=conv_dim,
-            kernel_size=args.conv_kernel,
-            groups=conv_dim,
-            bias=args.use_conv_bias,
-            padding=args.conv_kernel - 1
-        )
-        self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
-        self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
-        self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
-        self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
-        self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
+        # Calculate dimensions
+        self.d_model = args.hidden_size
+        self.d_state = args.state_size
+        self.d_conv = args.conv_kernel
+        self.expand = args.expand
+        self.d_inner = int(self.expand * self.d_model)
+        self.n_heads = args.num_heads
+        self.d_head = self.d_inner // self.n_heads
+        # Input projection
+        d_in_proj = self.d_inner * 2 + self.d_state * 2 + self.n_heads
+        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
+        # Convolution
+        conv_dim = self.d_inner + 2 * self.d_state
+        self.conv1d = DepthWiseConv1d(
+            in_channels=conv_dim,
+            out_channels=conv_dim,
+            kernel_size=self.d_conv,
+            bias=args.use_conv_bias,
+            groups=conv_dim
+        )
+        # SSM parameters
+        self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
+        self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
+        self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
+        # Output projection
+        self.norm = MambaRMSNormGated(self.d_inner, eps=args.layer_norm_epsilon)
+        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)

        if args.rescale_prenorm_residual:
            layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
            self.out_proj.weight = self.out_proj.weight * layer_scale

    def __call__(self, u: mx.array, cache=None):
-        batch_size, seq_len, dimension = u.shape
-        # Process sequence in chunks if needed
-        outputs = []
-        current_cache = cache
-        for i in range(seq_len):
-            # Extract current token
-            current_input = u[:, i:i+1, :]
-            # Initialize cache states if needed
-            if current_cache[0] is None:  # conv state
-                conv_dim = self.args.intermediate_size + 2 * self.args.state_size
-                current_cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
-            if current_cache[1] is None:  # ssm state
-                current_cache[1] = mx.zeros((
-                    batch_size,
-                    self.args.num_heads,
-                    self.args.head_dim,
-                    self.args.state_size
-                ))
-            # Project input
-            zxbcdt = self.in_proj(current_input)
-            n_heads = self.args.num_heads
-            z = zxbcdt[:, :, :self.args.intermediate_size]
-            xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
-            dt = zxbcdt[:, :, -(n_heads):]
-            # Process time steps
-            dt = mx.reshape(dt, (batch_size, n_heads))
-            dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
-            dt = mx.maximum(dt, self.args.time_step_floor)
-            # Apply convolution
-            xBC = self.conv1d(xBC, cache=current_cache)
-            xBC = silu(xBC)
-            # Split states
-            x = xBC[:, :, :self.args.intermediate_size]
-            B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
-            C = xBC[:, :, -self.args.state_size:]
-            # Reshape for SSM
-            x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
-            x = mx.squeeze(x, axis=1)
-            B = mx.reshape(B, (batch_size, 1, self.args.state_size))
-            B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
-            B = mx.expand_dims(B, axis=2)
-            C = mx.reshape(C, (batch_size, 1, self.args.state_size))
-            C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
-            C = mx.expand_dims(C, axis=3)
-            # SSM updates
-            A = -mx.exp(self.A_log)
-            dA = mx.exp(dt * mx.expand_dims(A, 0))
-            dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
-            # Update state
-            x = mx.expand_dims(x, axis=3)
-            dBx = mx.matmul(x, B)
-            current_cache[1] = current_cache[1] * dA + dBx
-            # Compute output
-            y = mx.matmul(current_cache[1], C)
-            y = mx.squeeze(y, axis=-1)
-            y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
-            y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
-            y = self.norm(y + z)
-            outputs.append(self.out_proj(y))
-        return mx.concatenate(outputs, axis=1)
+        batch_size, seq_len, _ = u.shape
+        # Project input
+        proj = self.in_proj(u)  # [batch, seq_len, d_in_proj]
+        # Calculate split indices and slice tensors
+        z = proj[..., :self.d_inner]
+        x_conv = proj[..., self.d_inner:self.d_inner + (self.d_inner + 2 * self.d_state)]
+        dt = proj[..., -self.n_heads:]
+        # Process time steps
+        dt = nn.softplus(dt + self.dt_bias)
+        dt = mx.clip(dt, self.args.time_step_min, self.args.time_step_max)
+        dt = mx.maximum(dt, self.args.time_step_floor)
+        # Convolution and activation
+        x_conv = self.conv1d(x_conv, cache=[cache[0] if cache else None])
+        x_conv = silu(x_conv)
+        # Split conv output
+        x = x_conv[..., :self.d_inner]
+        B = x_conv[..., self.d_inner:self.d_inner + self.d_state]
+        C = x_conv[..., -self.d_state:]
+        # Reshape x for SSM
+        x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
+        # Process B and C without reshaping heads
+        B = mx.expand_dims(B, axis=2)  # [batch, seq_len, 1, d_state]
+        B = mx.broadcast_to(B, (batch_size, seq_len, self.n_heads, self.d_state))
+        C = mx.expand_dims(C, axis=2)  # [batch, seq_len, 1, d_state]
+        C = mx.broadcast_to(C, (batch_size, seq_len, self.n_heads, self.d_state))
+        # Initialize or get previous state
+        if cache and cache[1] is not None:
+            prev_state = cache[1]
+        else:
+            prev_state = mx.zeros((batch_size, self.n_heads, self.d_head, self.d_state))
+        # Compute dA
+        dA = -mx.exp(self.A_log)  # [n_heads]
+        dt = mx.reshape(dt, (batch_size, seq_len, self.n_heads))  # Ensure correct shape
+        dA = mx.exp(mx.expand_dims(dt * mx.expand_dims(dA, 0), -1))  # [batch, seq_len, n_heads, 1]
+        dA = mx.expand_dims(dA, -1)  # [batch, seq_len, n_heads, 1, 1]
+        # Process sequence
+        next_state = prev_state
+        outputs = []
+        for t in range(seq_len):
+            # Get current step tensors
+            xt = x[:, t]    # [batch, n_heads, d_head]
+            Bt = B[:, t]    # [batch, n_heads, d_state]
+            Ct = C[:, t]    # [batch, n_heads, d_state]
+            dAt = dA[:, t]  # [batch, n_heads, 1, 1]
+            # Update state
+            next_state = (
+                next_state * dAt +  # Broadcasting: [batch, n_heads, d_head, d_state] * [batch, n_heads, 1, 1]
+                mx.matmul(
+                    mx.expand_dims(xt, -1),  # [batch, n_heads, d_head, 1]
+                    mx.expand_dims(Bt, -2)   # [batch, n_heads, 1, d_state]
+                )
+            )
+            # Compute output
+            yt = mx.matmul(
+                next_state,             # [batch, n_heads, d_head, d_state]
+                mx.expand_dims(Ct, -1)  # [batch, n_heads, d_state, 1]
+            )
+            yt = mx.squeeze(yt, -1)  # [batch, n_heads, d_head]
+            yt = yt + xt * mx.expand_dims(self.D, -1)
+            # Reshape and normalize
+            yt = mx.reshape(yt, (batch_size, 1, self.d_inner))
+            yt = self.norm(yt, z[:, t:t+1])
+            outputs.append(self.out_proj(yt))
+        # Update cache
+        if cache is not None:
+            cache[1] = next_state
+        # Concatenate all outputs
+        return mx.concatenate(outputs, axis=1)
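The per-token loop in the new block is the plain selective-SSM recurrence applied one step at a time: decay the state by exp(dt * A), add the outer product of the current token with B, read the state out with C, and add the D skip connection. A minimal, self-contained sketch of that single step, with toy shapes and random values (illustrative names only, no cache handling, not part of the model code):

import mlx.core as mx

batch, n_heads, d_head, d_state = 1, 2, 4, 8

h = mx.zeros((batch, n_heads, d_head, d_state))   # SSM state carried across tokens
A = -mx.exp(mx.random.normal((n_heads,)))         # per-head decay rate (negative)
D = mx.random.normal((n_heads,))                  # per-head skip weight

def ssm_step(h, xt, Bt, Ct, dt):
    # xt: [batch, n_heads, d_head]; Bt, Ct: [batch, n_heads, d_state]; dt: [batch, n_heads]
    dA = mx.expand_dims(mx.expand_dims(mx.exp(dt * A), -1), -1)      # [batch, n_heads, 1, 1]
    dBx = mx.matmul(mx.expand_dims(xt, -1), mx.expand_dims(Bt, -2))  # outer product -> [batch, n_heads, d_head, d_state]
    h = h * dA + dBx
    yt = mx.squeeze(mx.matmul(h, mx.expand_dims(Ct, -1)), -1)        # read out -> [batch, n_heads, d_head]
    return h, yt + xt * mx.expand_dims(D, -1)

xt = mx.random.normal((batch, n_heads, d_head))
Bt = mx.random.normal((batch, n_heads, d_state))
Ct = mx.random.normal((batch, n_heads, d_state))
dt = 0.1 * mx.ones((batch, n_heads))

h, yt = ssm_step(h, xt, Bt, Ct, dt)
print(yt.shape)  # (1, 2, 4)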

View File

@@ -1,316 +0,0 @@
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import Mamba2Cache
@dataclass
class ModelArgs(BaseModelArgs):
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
use_cache: bool = True
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
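# With made-up sizes (hidden_size=768, expand=2, num_heads=24), __post_init__
# above derives:
#   intermediate_size = 2 * 768   = 1536
#   head_dim          = 768 // 24 = 32
#   time_step_rank    = ceil(768 / 16) = 48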
def silu(x):
return x * mx.sigmoid(x)
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
# Fuse operations where possible
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
# Compute variance in fp32 for better numerical stability
hidden_states_fp32 = hidden_states.astype(mx.float32)
variance = mx.mean(hidden_states_fp32 * hidden_states_fp32, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
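# In effect this computes weight * g / rms(g) with g = hidden_states * silu(gate).
# A quick shape check (toy sizes; illustrative only):
#
#     norm = MambaRMSNormGated(hidden_size=16)
#     h = mx.random.normal((2, 1, 16))
#     z = mx.random.normal((2, 1, 16))
#     out = norm(h, gate=z)   # -> shape (2, 1, 16)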
def ssd_optimized(x, A, B, C, chunk_size):
batch, seqlen, nheads, dim = x.shape
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)
output = mx.zeros((batch, seqlen, nheads, dim))
state = mx.zeros((batch, nheads, dim, B.shape[-1]))
for i in range(0, seqlen, chunk_size):
chunk = slice(i, min(i + chunk_size, seqlen))
chunk_size_actual = min(chunk_size, seqlen - i)
dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
x_chunk = mx.transpose(x[:, chunk], [0, 2, 3, 1])
dBx = mx.matmul(x_chunk, B[:, chunk])
state = state * mx.expand_dims(dA, axis=-1) + dBx
y = mx.matmul(state, mx.transpose(C[:, chunk], [0, 2, 1]))
output[:, i:i+chunk_size_actual] = mx.transpose(y, [0, 3, 1, 2])
return output, state
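# In words: for each chunk the carried state is updated once as
#     state <- state * exp(A_chunk) + x_chunk @ B_chunk
# and the chunk's outputs are read out as y_chunk = state @ C_chunk^T, so the
# sequential dependency runs across chunks rather than across every token.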
def update_conv_cache(x: mx.array, cache, kernel_size: int) -> Tuple[mx.array, mx.array]:
"""Update convolution cache for sequential processing."""
B, L, C = x.shape
if cache is None:
# Initialize cache with zeros
cache = mx.zeros((B, kernel_size - 1, C))
# Concatenate cache with current input
x_with_cache = mx.concatenate([cache, x], axis=1)
# Update cache with the last (kernel_size - 1) elements
new_cache = x_with_cache[:, -kernel_size+1:] if x_with_cache.shape[1] >= kernel_size else x_with_cache
return x_with_cache, new_cache
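# Streaming sketch (toy sizes; illustrative only): feeding one token at a time
# keeps the last kernel_size - 1 positions around so the depthwise conv always
# sees a full window.
#
#     conv_cache = None
#     for _ in range(3):
#         token = mx.random.normal((1, 1, 8))   # [batch, 1, channels]
#         padded, conv_cache = update_conv_cache(token, conv_cache, kernel_size=4)
#         # padded.shape == (1, 4, 8) at every step; conv_cache.shape == (1, 3, 8)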
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.intermediate_size = int(args.expand * args.hidden_size)
self.state_size = args.state_size
self.num_heads = args.num_heads
self.head_dim = args.head_dim
self.ssm_state_size = args.state_size
self.n_groups = args.n_groups
self.conv_kernel = args.conv_kernel
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
self.in_proj = nn.Linear(args.hidden_size, projection_size, bias=args.use_bias)
# Using built-in Conv1d instead of custom DepthwiseConv1d
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
kernel_size=args.conv_kernel,
groups=self.conv_dim, # For depthwise convolution
padding=0, # We'll handle padding manually with the cache
bias=args.use_conv_bias
)
self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
if args.rescale_prenorm_residual:
layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
self.out_proj.weight = self.out_proj.weight * layer_scale
def __call__(self, u: mx.array, cache):
batch_size, seq_len, _ = u.shape
projected = self.in_proj(u)
d_conv = self.conv_dim
z = projected[..., :self.intermediate_size]
xBC = projected[..., self.intermediate_size:self.intermediate_size + d_conv]
dt = projected[..., -self.num_heads:]
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.args.time_step_min,
self.args.time_step_max
)
dt = mx.maximum(dt, self.args.time_step_floor)
# Handle convolution with separate cache update
if cache is not None:
# Update cache and get padded input
xBC_padded, new_cache = update_conv_cache(xBC, cache.conv_states, self.conv_kernel)
cache.conv_states = new_cache
# Prepare input for conv1d: [B, L, C] -> [B, C, L]
xBC_conv = mx.transpose(xBC_padded, [0, 2, 1])
# Apply convolution
xBC = self.conv1d(xBC_conv)
# Transform back: [B, C, L] -> [B, L, C]
xBC = mx.transpose(xBC, [0, 2, 1])
# Take only the relevant part corresponding to input length
xBC = xBC[:, :seq_len]
else:
# For training, use regular convolution with padding
xBC = mx.transpose(xBC, [0, 2, 1])
xBC = self.conv1d(xBC)
xBC = mx.transpose(xBC, [0, 2, 1])
xBC = silu(xBC)
x = xBC[..., :self.intermediate_size]
BC = xBC[..., self.intermediate_size:]
B = BC[..., :self.state_size]
C = BC[..., self.state_size:]
x = mx.reshape(x, (-1, seq_len, self.num_heads, self.intermediate_size // self.num_heads))
A = -mx.exp(self.A_log)
D_expanded = mx.expand_dims(self.D, -1)
if cache is not None and cache.ssm_state is None:
cache.ssm_state = mx.zeros((
batch_size,
self.num_heads,
self.intermediate_size // self.num_heads,
self.state_size
))
if cache is not None:
output = mx.zeros((batch_size, seq_len, self.args.hidden_size))
for pos in range(seq_len):
x_t = x[:, pos:pos+1]
dA = mx.exp(dt[:, pos:pos+1] * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
x_expanded = mx.expand_dims(x_t, axis=3)
dBx = mx.matmul(x_expanded, mx.expand_dims(B[:, pos:pos+1], axis=2))
cache.ssm_state = cache.ssm_state * dA + dBx
y = mx.matmul(cache.ssm_state, mx.expand_dims(C[:, pos:pos+1], axis=3))
y = mx.squeeze(y, axis=-1)
y = y + x_t * D_expanded
y = mx.reshape(y, (batch_size, 1, -1))
y = self.norm(y + z[:, pos:pos+1])
y = self.out_proj(y)
if self.args.residual_in_fp32:
y = y.astype(mx.float32)
output = output.at[:, pos:pos+1].set(y)
else:
y, ssm_state = ssd_optimized(
x * mx.expand_dims(dt, -1),
-mx.exp(self.A_log) * dt,
B, C,
self.args.chunk_size
)
y = mx.reshape(
y + x * mx.expand_dims(self.D, -1),
(batch_size, seq_len, -1)
)
y = self.norm(y + z)
output = self.out_proj(y)
if self.args.residual_in_fp32:
output = output.astype(mx.float32)
return output
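# With made-up sizes (hidden_size=768, expand=2, state_size=128, n_groups=1,
# num_heads=24), in_proj above emits 1536 + 1792 + 24 = 3352 features per token,
# split into z (1536), xBC (1536 + 2 * 128 = 1792) and dt (24).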
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
return self.mixer(self.norm(x), cache) + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def make_cache(self, batch_size=1):
return [Mamba2Cache() for _ in range(len(self.layers))]
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
weights[k] = v.moveaxis(2, 1)
return weights
@property
def layers(self):
return self.backbone.layers
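A rough sketch of how a caller streams tokens through this kind of interface with greedy decoding (illustrative only: `model` is assumed to be an already-loaded instance of the Model class above, and per-layer caches come from make_cache):

import mlx.core as mx

cache = model.make_cache()
tokens = mx.array([[1]])                                # [batch=1, seq_len=1] start token
for _ in range(16):
    logits = model(tokens[:, -1:], cache=cache)         # feed only the newest token
    next_token = mx.argmax(logits[:, -1, :], axis=-1)   # greedy pick, shape (1,)
    tokens = mx.concatenate([tokens, mx.expand_dims(next_token, 0)], axis=1)
print(tokens)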