mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
updates
This commit is contained in:
@@ -1,437 +1,490 @@
|
||||
"""
|
||||
mamba2-minimal
|
||||
==============
|
||||
# coding=utf-8
|
||||
# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch MAMBA2 model."""
|
||||
|
||||
A minimal, single-file implementation of the Mamba-2 model in PyTorch.
|
||||
|
||||
> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality**
|
||||
> Authors: Tri Dao, Albert Gu
|
||||
> Paper: https://arxiv.org/abs/2405.21060
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, NamedTuple, TypeAlias, cast
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch import LongTensor, Tensor, nn
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
Device: TypeAlias = str | torch.device | None
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mamba2Config:
|
||||
d_model: int # model dimension (D)
|
||||
n_layer: int = 24 # number of Mamba-2 layers in the language model
|
||||
d_state: int = 128 # state dimension (N)
|
||||
d_conv: int = 4 # convolution kernel size
|
||||
expand: int = 2 # expansion factor (E)
|
||||
headdim: int = 64 # head dimension (P)
|
||||
chunk_size: int = 64 # matrix partition size (Q)
|
||||
vocab_size: int = 50277
|
||||
pad_vocab_size_multiple: int = 16
|
||||
def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
|
||||
"""
|
||||
Padding x tensor with `pad_size` on the seq_len dim (dim=1)
|
||||
|
||||
def __post_init__(self):
|
||||
self.d_inner = self.expand * self.d_model
|
||||
assert self.d_inner % self.headdim == 0
|
||||
self.nheads = self.d_inner // self.headdim
|
||||
if self.vocab_size % self.pad_vocab_size_multiple != 0:
|
||||
self.vocab_size += (
|
||||
self.pad_vocab_size_multiple
|
||||
- self.vocab_size % self.pad_vocab_size_multiple
|
||||
)
|
||||
Assumes that we only have tensors of either size 4 or 3
|
||||
"""
|
||||
pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
|
||||
|
||||
return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
|
||||
|
||||
|
||||
class InferenceCache(NamedTuple):
|
||||
conv_state: Tensor # (batch, d_inner + 2 * d_state, d_conv)
|
||||
ssm_state: Tensor # (batch, nheads, headdim, d_state)
|
||||
def reshape_into_chunks(input_tensor, pad_size, chunk_size):
|
||||
"""
|
||||
Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
|
||||
simultaneously splitting it into chunk sequences.
|
||||
|
||||
@staticmethod
|
||||
def alloc(batch_size: int, args: Mamba2Config, device: Device = None):
|
||||
return InferenceCache(
|
||||
torch.zeros(
|
||||
batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device
|
||||
),
|
||||
torch.zeros(
|
||||
batch_size, args.nheads, args.headdim, args.d_state, device=device
|
||||
),
|
||||
Assumes that we only have tensors of either size 4 or 3
|
||||
"""
|
||||
# [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
|
||||
input_tensor = pad_tensor_by_size(input_tensor, pad_size)
|
||||
|
||||
if len(input_tensor.shape) == 3:
|
||||
# [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
|
||||
return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
|
||||
else:
|
||||
# [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
|
||||
return input_tensor.reshape(
|
||||
input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
|
||||
)
|
||||
|
||||
|
||||
class Mamba2LMHeadModel(nn.Module):
|
||||
def __init__(self, args: Mamba2Config, device: Device = None):
|
||||
def segment_sum(input_tensor):
|
||||
"""
|
||||
More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
|
||||
"""
|
||||
chunk_size = input_tensor.size(-1)
|
||||
# 1. expand input tensor to have an additional dimension and repeat along that dimension
|
||||
# [..., chunk_size] -> [..., chunk_size, chunk_size]
|
||||
input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
|
||||
# 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
|
||||
mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
|
||||
input_tensor = input_tensor.masked_fill(~mask, 0)
|
||||
# 3. compute actual cumsum
|
||||
tensor_segsum = torch.cumsum(input_tensor, dim=-2)
|
||||
|
||||
# 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
|
||||
mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
|
||||
tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
|
||||
return tensor_segsum
|
||||
|
||||
|
||||
class Mamba2Cache:
|
||||
"""
|
||||
Arguments:
|
||||
config: ModelArgs
|
||||
batch_size: int
|
||||
dtype: torch.dtype
|
||||
device: torch.device
|
||||
|
||||
Attributes:
|
||||
seqlen_offset: int
|
||||
dtype: torch.dtype
|
||||
conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel_size]
|
||||
ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, config: ModelArgs, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
|
||||
):
|
||||
self.seqlen_offset = 0
|
||||
self.dtype = dtype
|
||||
self.conv_kernel_size = config.conv_kernel
|
||||
self.intermediate_size = int(config.expand * config.hidden_size)
|
||||
|
||||
self.conv_states = {
|
||||
i: torch.zeros(
|
||||
batch_size,
|
||||
self.intermediate_size + 2 * config.n_groups * config.state_size,
|
||||
self.conv_kernel_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
}
|
||||
self.ssm_states = {
|
||||
i: torch.zeros(
|
||||
batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
}
|
||||
|
||||
def update_conv_state(
|
||||
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
|
||||
) -> torch.Tensor:
|
||||
conv_state = self.conv_states[layer_idx]
|
||||
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
|
||||
|
||||
conv_state = conv_state.roll(shifts=-1, dims=-1)
|
||||
conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
|
||||
self.conv_states[layer_idx].zero_()
|
||||
self.conv_states[layer_idx] += conv_state
|
||||
return self.conv_states[layer_idx]
|
||||
|
||||
def reset(self):
|
||||
self.conv_states.zero_()
|
||||
self.ssm_states.zero_()
|
||||
|
||||
|
||||
class MambaRMSNormGated(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.device = device
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
self.backbone = nn.ModuleDict(
|
||||
dict(
|
||||
embedding=nn.Embedding(args.vocab_size, args.d_model, device=device),
|
||||
layers=nn.ModuleList(
|
||||
[
|
||||
nn.ModuleDict(
|
||||
dict(
|
||||
mixer=Mamba2(args, device=device),
|
||||
norm=RMSNorm(args.d_model, device=device),
|
||||
)
|
||||
)
|
||||
for _ in range(args.n_layer)
|
||||
]
|
||||
),
|
||||
norm_f=RMSNorm(args.d_model, device=device),
|
||||
def forward(self, hidden_states, gate=None):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states
|
||||
|
||||
if gate is not None:
|
||||
hidden_states = hidden_states * nn.functional.silu(gate)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
|
||||
return self.weight * hidden_states
|
||||
|
||||
|
||||
class Mamba2Mixer(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.ssm_state_size = config.state_size
|
||||
self.conv_kernel_size = config.conv_kernel
|
||||
self.intermediate_size = int(config.expand * self.hidden_size)
|
||||
self.time_step_rank = int(config.time_step_rank)
|
||||
self.use_conv_bias = config.use_conv_bias
|
||||
self.act = nn.silu
|
||||
|
||||
self.layer_norm_epsilon = config.layer_norm_epsilon
|
||||
self.rms_norm = config.rms_norm
|
||||
|
||||
self.n_groups = config.n_groups
|
||||
self.head_dim = config.head_dim
|
||||
self.chunk_size = config.chunk_size
|
||||
|
||||
self.time_step_limit = config.time_step_limit
|
||||
self.time_step_min = config.time_step_min
|
||||
self.time_step_max = config.time_step_max
|
||||
|
||||
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
|
||||
self.conv1d = nn.Conv1d(
|
||||
in_channels=self.conv_dim,
|
||||
out_channels=self.conv_dim,
|
||||
bias=config.use_conv_bias,
|
||||
kernel_size=config.conv_kernel,
|
||||
groups=self.conv_dim,
|
||||
padding=config.conv_kernel - 1,
|
||||
)
|
||||
|
||||
# projection of the input hidden states
|
||||
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
|
||||
self.in_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
projection_size,
|
||||
bias=config.use_bias,
|
||||
)
|
||||
|
||||
self.dt_bias = torch.ones(self.num_heads)
|
||||
A = torch.arange(1, self.num_heads + 1)
|
||||
self.A_log = torch.log(A)
|
||||
self.D = torch.ones(self.num_heads)
|
||||
|
||||
self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
|
||||
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
|
||||
|
||||
def forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None):
|
||||
batch_size, seq_len, _ = input_states.shape
|
||||
dtype = input_states.dtype
|
||||
|
||||
# Gated MLP's linear projection
|
||||
projected_states = self.in_proj(input_states.squeeze(1))
|
||||
d_mlp = (
|
||||
projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2
|
||||
_, _, gate, hidden_states, dt = projected_states.split(
|
||||
[d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
|
||||
)
|
||||
|
||||
# Convolution sequence transformation
|
||||
if cache_params is not None:
|
||||
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
|
||||
ssm_state = ssm_state.to(hidden_states.device)
|
||||
if cache_params.seqlen_offset > 0:
|
||||
conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
|
||||
conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
|
||||
# handle batched generation - states are copied through
|
||||
conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states
|
||||
cache_params.conv_states[self.layer_idx].copy_(conv_state)
|
||||
hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
|
||||
if self.use_conv_bias:
|
||||
hidden_states += self.conv1d.bias
|
||||
hidden_states = self.act(hidden_states)[:, None, ...] # [batch, 1, intermediate_size] : decoding
|
||||
else:
|
||||
hidden_states = hidden_states.transpose(1,2)
|
||||
conv_state = nn.functional.pad(
|
||||
hidden_states,
|
||||
(self.conv_kernel_size - hidden_states.shape[-1], 0)
|
||||
)
|
||||
cache_params.conv_states[self.layer_idx].copy_(conv_state)
|
||||
hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len]
|
||||
else:
|
||||
ssm_state = torch.zeros(
|
||||
(batch_size, self.num_heads, self.head_dim, self.ssm_state_size),
|
||||
device=hidden_states.device
|
||||
)
|
||||
)
|
||||
self.lm_head = nn.Linear(
|
||||
args.d_model, args.vocab_size, bias=False, device=device
|
||||
)
|
||||
self.lm_head.weight = self.backbone.embedding.weight
|
||||
hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2))
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(huggingface_model_id: str, device: Device = None):
|
||||
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers.utils.hub import cached_file
|
||||
hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1)
|
||||
A = -torch.exp(self.A_log.float()) # [num_heads]
|
||||
|
||||
config_path = cached_file(huggingface_model_id, CONFIG_NAME)
|
||||
assert config_path, "Failed to get huggingface config file"
|
||||
state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME)
|
||||
assert state_dict_path, "Failed to get huggingface state dict file"
|
||||
if cache_params is not None and cache_params.seqlen_offset > 0:
|
||||
# Note: there is no need to pad parameter matrices here, as there is just one new token
|
||||
# for batched generation
|
||||
dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
|
||||
dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
|
||||
# [num_heads] -> [num_heads, head_dim]
|
||||
dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
|
||||
|
||||
config = json.load(open(config_path))
|
||||
args = Mamba2Config(
|
||||
d_model=config["d_model"],
|
||||
n_layer=config["n_layer"],
|
||||
vocab_size=config["vocab_size"],
|
||||
pad_vocab_size_multiple=config["pad_vocab_size_multiple"],
|
||||
)
|
||||
dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
|
||||
dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max)
|
||||
A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
|
||||
# [bsz, num_heads, head_dim, state_size]
|
||||
dA = torch.exp(dt[..., None] * A)
|
||||
|
||||
map_location = "cpu" if device is None else device
|
||||
state_dict = torch.load(
|
||||
state_dict_path, weights_only=True, map_location=map_location, mmap=True
|
||||
)
|
||||
model = Mamba2LMHeadModel(args, device=device)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
return model
|
||||
# Discretize B
|
||||
# [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
|
||||
# -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
|
||||
B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
|
||||
B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
|
||||
B = B.reshape(batch_size, -1, B.shape[-1])
|
||||
# [bsz, num_heads, head_dim, state_size]
|
||||
dB = dt[..., None] * B[..., None, :]
|
||||
|
||||
# Discretize x into dB
|
||||
# [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
|
||||
hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
|
||||
dBx = dB * hidden_states[..., None]
|
||||
|
||||
# State calculation
|
||||
cache_params.ssm_states[self.layer_idx].copy_(
|
||||
cache_params.ssm_states[self.layer_idx] * dA + dBx
|
||||
)
|
||||
|
||||
# Subsequent output
|
||||
# [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
|
||||
C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
|
||||
C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
|
||||
C = C.reshape(batch_size, -1, C.shape[-1])
|
||||
# [bsz, num_heads, head_dim]
|
||||
|
||||
ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n]
|
||||
# Reshape ssm_states to merge the first two dimensions
|
||||
ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
|
||||
C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
|
||||
y = torch.bmm(ssm_states_reshaped, C_reshaped)
|
||||
y = y.view(batch_size, self.num_heads, self.head_dim)
|
||||
|
||||
# D skip connection
|
||||
# [num_heads] -> [num_heads, head_dim]
|
||||
D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
|
||||
y = (y + hidden_states * D).to(y.dtype)
|
||||
|
||||
# [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
|
||||
y = y.reshape(batch_size, -1)[:, None, ...]
|
||||
else:
|
||||
# begin ssd naive implementation without einsums
|
||||
dt = nn.functional.softplus(dt + self.dt_bias)
|
||||
dt = torch.clamp(dt, self.time_step_min)
|
||||
hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
|
||||
B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
|
||||
C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
|
||||
B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
|
||||
C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
|
||||
pad_size = self.chunk_size - (seq_len % self.chunk_size)
|
||||
|
||||
D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
|
||||
|
||||
# Discretize x and A
|
||||
hidden_states = hidden_states * dt[..., None]
|
||||
A = A.to(hidden_states.dtype) * dt
|
||||
|
||||
# Rearrange into blocks/chunks
|
||||
hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
|
||||
|
||||
|
||||
# [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
|
||||
A = A.permute(0, 3, 1, 2)
|
||||
A_cumsum = torch.cumsum(A, dim=-1)
|
||||
|
||||
# 1. Compute the output for each intra-chunk (diagonal blocks)
|
||||
# This is the analog of a causal mask
|
||||
L = torch.exp(segment_sum(A))
|
||||
|
||||
# First, contraction of C and B to get G (attention-weights like)
|
||||
G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n)
|
||||
G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
|
||||
|
||||
|
||||
# Step 2: Compute M, equivalent to applying attention mask to weights
|
||||
M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
|
||||
M = M_intermediate.sum(dim=-1)
|
||||
|
||||
# Step 3: Compute Y_diag (apply to values)
|
||||
Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)
|
||||
|
||||
# (right term of low-rank factorization of off-diagonal blocks; B terms)
|
||||
|
||||
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
|
||||
B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
|
||||
# permute back B * decay states
|
||||
states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
|
||||
if cache_params is not None and cache_params.seqlen_offset > 0:
|
||||
previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...]
|
||||
else:
|
||||
previous_states = torch.zeros_like(states[:, :1])
|
||||
states = torch.cat([previous_states, states], dim=1)
|
||||
decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
|
||||
|
||||
states_permuted = states.permute(0, 2, 1, 3, 4)
|
||||
result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
|
||||
new_states = result.permute(0, 2, 1, 3, 4)
|
||||
states, ssm_state = new_states[:, :-1], new_states[:, -1]
|
||||
|
||||
# Compute state -> output conversion per chunk
|
||||
# (left term of low-rank factorization of off-diagonal blocks; C terms)
|
||||
state_decay_out = torch.exp(A_cumsum)
|
||||
# compute Yoff
|
||||
C_times_states = (C[..., None, :] * states[:, :, None, ...])
|
||||
state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
|
||||
Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
|
||||
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
|
||||
|
||||
y = Y_diag + Y_off
|
||||
# [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
|
||||
y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
|
||||
|
||||
y = y + D_residual
|
||||
# Cutting off padded chunks
|
||||
if pad_size > 0:
|
||||
y = y[:, :seq_len, :, :]
|
||||
y = y.reshape(batch_size, seq_len, -1)
|
||||
if ssm_state is not None and cache_params is not None:
|
||||
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
|
||||
|
||||
scan_output = self.norm(y, gate)
|
||||
|
||||
# end ssd naive
|
||||
|
||||
# 4. Final linear projection
|
||||
contextualized_states = self.out_proj(scan_output) # [batch, seq_len, hidden_size]
|
||||
return contextualized_states
|
||||
|
||||
|
||||
class Mamba2RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states
|
||||
|
||||
|
||||
class Mamba2Block(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mixer = Mamba2Mixer(config)
|
||||
|
||||
def forward(
|
||||
self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None
|
||||
) -> tuple[LongTensor, list[InferenceCache]]:
|
||||
"""
|
||||
Arguments
|
||||
input_ids: (batch, seqlen) tokens from `EleutherAI/gpt-neox-20b` tokenizer
|
||||
h: hidden states for inference step. If present the constant-time
|
||||
(wrt sequence length) inference path will be taken, input_ids
|
||||
should have shape (batch, 1) containing the next batch of prompt
|
||||
token.
|
||||
|
||||
Return (logits, h)
|
||||
logits: (batch, seqlen, vocab_size)
|
||||
h: updated inference cache after processing `input_ids`
|
||||
"""
|
||||
seqlen = input_ids.shape[1]
|
||||
|
||||
if h is None:
|
||||
h = [None for _ in range(self.args.n_layer)]
|
||||
|
||||
x = self.backbone.embedding(input_ids)
|
||||
for i, layer in enumerate(self.backbone.layers):
|
||||
y, h[i] = layer.mixer(layer.norm(x), h[i])
|
||||
x = y + x
|
||||
|
||||
x = self.backbone.norm_f(x)
|
||||
logits = self.lm_head(x)
|
||||
return logits[:, :seqlen], cast(list[InferenceCache], h)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_ids: LongTensor,
|
||||
max_new_length: int = 20,
|
||||
temperature: float = 1.0,
|
||||
top_k: int = 50,
|
||||
top_p: float = 1.0,
|
||||
eos_token_id: int = 0,
|
||||
) -> Iterable[tuple[int, list[InferenceCache]]]:
|
||||
prefix, tokens = input_ids[:-1], input_ids[-1:].unsqueeze(0)
|
||||
hidden_states,
|
||||
cache_params: Optional[Mamba2Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
x = self.mixer(
|
||||
self.norm(hidden_states), cache_params=cache_params, cache_position=cache_position
|
||||
)
|
||||
return x + hidden_states
|
||||
|
||||
# Process prompt
|
||||
# The input sequence to forward (non-inference path) must have length multiple that of chunk_size.
|
||||
# We split out excess tokens so that n_chunked tokens can be processed by one forward call and
|
||||
# process the rest in multiple inference steps.
|
||||
n_chunked = (prefix.shape[0] // self.args.chunk_size) * self.args.chunk_size
|
||||
if n_chunked > 0:
|
||||
_, h = self(prefix[:n_chunked].unsqueeze(0), None)
|
||||
|
||||
class Mamba2Model(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
cache_params: Optional[Mamba2Cache] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
inputs_embeds = self.embeddings(input_ids)
|
||||
|
||||
if use_cache:
|
||||
if cache_params is None:
|
||||
cache_params = Mamba2Cache(
|
||||
self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
|
||||
)
|
||||
cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
|
||||
else:
|
||||
h = [
|
||||
InferenceCache.alloc(1, self.args, device=self.device)
|
||||
for _ in range(self.args.n_layer)
|
||||
]
|
||||
for i in range(n_chunked, prefix.shape[0]):
|
||||
_, h = self(prefix[i : i + 1].unsqueeze(0), h)
|
||||
cache_params = None
|
||||
|
||||
# Generate
|
||||
for _ in range(max_new_length):
|
||||
with torch.no_grad():
|
||||
out, h = self(tokens, h)
|
||||
logits = out[0, -1]
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
if top_k > 0:
|
||||
indices_to_remove = logits < torch.topk(logits, k=top_k)[0][-1]
|
||||
logits[indices_to_remove] = -torch.inf
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cum_probs > 0.5
|
||||
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
|
||||
sorted_indices_to_remove[0] = False
|
||||
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
||||
logits[indices_to_remove] = -torch.inf
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
if next_token.item() == eos_token_id:
|
||||
return
|
||||
tokens = next_token.unsqueeze(0)
|
||||
yield cast(int, next_token.item()), h
|
||||
hidden_states = inputs_embeds
|
||||
for mixer_block in self.layers:
|
||||
hidden_states = mixer_block(
|
||||
hidden_states,
|
||||
cache_params=cache_params,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
cache_params.seqlen_offset += inputs_embeds.shape[1]
|
||||
|
||||
return self.norm_f(hidden_states), cache_params if use_cache else None
|
||||
|
||||
|
||||
class Mamba2(nn.Module):
|
||||
def __init__(self, args: Mamba2Config, device: Device = None):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.device = device
|
||||
|
||||
# Order: (z, x, B, C, dt)
|
||||
d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads
|
||||
self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device)
|
||||
class Mamba2ForCausalLM(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.backbone = Mamba2Model(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
conv_dim = args.d_inner + 2 * args.d_state
|
||||
self.conv1d = nn.Conv1d(
|
||||
in_channels=conv_dim,
|
||||
out_channels=conv_dim,
|
||||
kernel_size=args.d_conv,
|
||||
groups=conv_dim,
|
||||
padding=args.d_conv - 1,
|
||||
device=device,
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
cache_params: Optional[Mamba2Cache] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
):
|
||||
mamba2_outputs = self.backbone(
|
||||
input_ids,
|
||||
cache_params=cache_params,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = mamba2_outputs[0]
|
||||
|
||||
self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
|
||||
self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
|
||||
self.D = nn.Parameter(torch.empty(args.nheads, device=device))
|
||||
self.norm = RMSNorm(args.d_inner, device=device)
|
||||
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
|
||||
|
||||
def forward(self, u: Tensor, h: InferenceCache | None = None):
|
||||
"""
|
||||
Arguments
|
||||
u: (batch, seqlen, d_model) input. seqlen should be a multiple of chunk_size.
|
||||
h: hidden states for inference step. Initialized to 0s if not present.
|
||||
|
||||
Return (y, h)
|
||||
y: (batch, seqlen, d_model) output
|
||||
h: updated inference cache after processing `u`
|
||||
"""
|
||||
if h:
|
||||
return self.step(u, h)
|
||||
|
||||
A = -torch.exp(self.A_log) # (nheads,)
|
||||
zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj)
|
||||
z, xBC, dt = torch.split(
|
||||
zxbcdt,
|
||||
[
|
||||
self.args.d_inner,
|
||||
self.args.d_inner + 2 * self.args.d_state,
|
||||
self.args.nheads,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads)
|
||||
|
||||
# Pad or truncate xBC seqlen to d_conv
|
||||
conv_state = F.pad(
|
||||
rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0)
|
||||
)
|
||||
|
||||
xBC = silu(
|
||||
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :]
|
||||
) # (batch, seqlen, d_inner + 2 * d_state))
|
||||
x, B, C = torch.split(
|
||||
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
|
||||
)
|
||||
x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim)
|
||||
y, ssm_state = ssd(
|
||||
x * dt.unsqueeze(-1),
|
||||
A * dt,
|
||||
rearrange(B, "b l n -> b l 1 n"),
|
||||
rearrange(C, "b l n -> b l 1 n"),
|
||||
self.args.chunk_size,
|
||||
device=self.device,
|
||||
)
|
||||
y = y + x * self.D.unsqueeze(-1)
|
||||
y = rearrange(y, "b l h p -> b l (h p)")
|
||||
y = self.norm(y, z)
|
||||
y = self.out_proj(y)
|
||||
|
||||
h = InferenceCache(conv_state, ssm_state)
|
||||
return y, h
|
||||
|
||||
def step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]:
|
||||
"""Take a single inference step for the current input and hidden state
|
||||
|
||||
Unlike attention-based models, RNN-based models (eg Mamba) does not need
|
||||
to look back at all the past tokens to generate a new token. Instead a
|
||||
hidden state (initialized to 0s initially) is updated for each input and
|
||||
passed to the next inference step. This means that the total inference
|
||||
time is linear with respect to the sequence length instead of quadratic
|
||||
in attention's case.
|
||||
|
||||
Arguments
|
||||
u: (batch, 1, d_model)
|
||||
h: initial/running hidden state
|
||||
|
||||
Return (y, h)
|
||||
y: (batch, 1, d_model)
|
||||
h: updated hidden state
|
||||
"""
|
||||
assert u.shape[1] == 1, "Only one token can be decoded per inference step"
|
||||
|
||||
zxbcdt = self.in_proj(u.squeeze(1)) # (batch, d_in_proj)
|
||||
z, xBC, dt = torch.split(
|
||||
zxbcdt,
|
||||
[
|
||||
self.args.d_inner,
|
||||
self.args.d_inner + 2 * self.args.d_state,
|
||||
self.args.nheads,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# Advance convolution input
|
||||
h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1))
|
||||
h.conv_state[:, :, -1] = xBC
|
||||
# Convolution step
|
||||
xBC = torch.sum(
|
||||
h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
|
||||
)
|
||||
xBC += self.conv1d.bias
|
||||
xBC = silu(xBC)
|
||||
|
||||
x, B, C = torch.split(
|
||||
xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1
|
||||
)
|
||||
A = -torch.exp(self.A_log) # (nheads,)
|
||||
|
||||
# SSM step
|
||||
dt = F.softplus(dt + self.dt_bias) # (batch, nheads)
|
||||
dA = torch.exp(dt * A) # (batch, nheads)
|
||||
x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim)
|
||||
dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x)
|
||||
h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
|
||||
y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C)
|
||||
y = y + rearrange(self.D, "h -> h 1") * x
|
||||
y = rearrange(y, "b h p -> b (h p)")
|
||||
y = self.norm(y, z)
|
||||
y = self.out_proj(y)
|
||||
|
||||
return y.unsqueeze(1), h
|
||||
|
||||
|
||||
def segsum(x: Tensor, device: Device = None) -> Tensor:
|
||||
"""Stable segment sum calculation.
|
||||
|
||||
`exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.
|
||||
|
||||
Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32
|
||||
"""
|
||||
T = x.size(-1)
|
||||
x = repeat(x, "... d -> ... d e", e=T)
|
||||
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
|
||||
x = x.masked_fill(~mask, 0)
|
||||
x_segsum = torch.cumsum(x, dim=-2)
|
||||
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
|
||||
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
|
||||
return x_segsum
|
||||
|
||||
|
||||
def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None):
|
||||
"""Structed State Space Duality (SSD) - the core of Mamba-2
|
||||
|
||||
This is almost the exact same minimal SSD code from the blog post.
|
||||
|
||||
Arguments
|
||||
x: (batch, seqlen, n_heads, d_head)
|
||||
A: (batch, seqlen, n_heads)
|
||||
B: (batch, seqlen, n_heads, d_state)
|
||||
C: (batch, seqlen, n_heads, d_state)
|
||||
|
||||
Return
|
||||
y: (batch, seqlen, n_heads, d_head)
|
||||
|
||||
Source
|
||||
1. https://tridao.me/blog/2024/mamba2-part3-algorithm/
|
||||
2. https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78
|
||||
"""
|
||||
assert x.shape[1] % chunk_size == 0
|
||||
|
||||
# Rearrange into chunks
|
||||
# Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel)
|
||||
# This is not implemented and left as an exercise for the reader 😜
|
||||
x, A, B, C = [
|
||||
rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C)
|
||||
]
|
||||
|
||||
A = rearrange(A, "b c l h -> b h c l")
|
||||
A_cumsum = torch.cumsum(A, dim=-1)
|
||||
|
||||
# 1. Compute the output for each intra-chunk (diagonal blocks)
|
||||
L = torch.exp(segsum(A, device=device))
|
||||
Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)
|
||||
|
||||
# 2. Compute the state for each intra-chunk
|
||||
# (right term of low-rank factorization of off-diagonal blocks; B terms)
|
||||
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
|
||||
states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)
|
||||
|
||||
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
|
||||
# (middle term of factorization of off-diag blocks; A terms)
|
||||
if initial_states is None:
|
||||
initial_states = torch.zeros_like(states[:, :1])
|
||||
states = torch.cat([initial_states, states], dim=1)
|
||||
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device))
|
||||
new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
|
||||
states, final_state = new_states[:, :-1], new_states[:, -1]
|
||||
|
||||
# 4. Compute state -> output conversion per chunk
|
||||
# (left term of low-rank factorization of off-diagonal blocks; C terms)
|
||||
state_decay_out = torch.exp(A_cumsum)
|
||||
Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)
|
||||
|
||||
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
|
||||
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
|
||||
|
||||
return Y, final_state
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, d: int, eps: float = 1e-5, device: Device = None):
|
||||
"""Gated Root Mean Square Layer Normalization
|
||||
|
||||
Paper: https://arxiv.org/abs/1910.07467
|
||||
"""
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(d, device=device))
|
||||
|
||||
def forward(self, x, z=None):
|
||||
if z is not None:
|
||||
x = x * silu(z)
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
|
||||
|
||||
|
||||
def silu(x):
|
||||
"""Applies the Sigmoid Linear Unit (SiLU), element-wise.
|
||||
|
||||
Define this manually since torch's version doesn't seem to work on MPS.
|
||||
"""
|
||||
return x * F.sigmoid(x)
|
||||
logits = self.lm_head(hidden_states)
|
||||
return logits, mamba2_outputs.cache_params, mamba2_outputs.hidden_states
|
||||
Reference in New Issue
Block a user