# coding=utf-8
# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MAMBA2 model."""

import logging
from typing import Optional

import torch
from torch import nn

logger = logging.getLogger(__name__)


def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
    """
    Pad `input_tensor` with `pad_size` zeros on the seq_len dimension (dim=1).

    Assumes that we only have tensors of rank 3 or 4.
    """
    pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)

    return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)


def reshape_into_chunks(input_tensor, pad_size, chunk_size):
    """
    Pad `input_tensor` with `pad_size` on the seq_len dimension (dim=1) and simultaneously split it into chunk
    sequences.

    Assumes that we only have tensors of rank 3 or 4.
    """
    # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
    input_tensor = pad_tensor_by_size(input_tensor, pad_size)

    if len(input_tensor.shape) == 3:
        # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
        return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
    else:
        # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] ->
        # [bsz, -1, chunk_size, num_heads, head_dim or state_size]
        return input_tensor.reshape(
            input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
        )


def segment_sum(input_tensor):
    """
    More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
    """
    chunk_size = input_tensor.size(-1)
    # 1. expand input tensor to have an additional dimension and repeat along that dimension
    # [..., chunk_size] -> [..., chunk_size, chunk_size]
    input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
    # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
    mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
    input_tensor = input_tensor.masked_fill(~mask, 0)
    # 3. compute actual cumsum
    tensor_segsum = torch.cumsum(input_tensor, dim=-2)
    # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
    mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
    tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
    return tensor_segsum
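
# Illustration (not executed): for a 1-D input, `segment_sum` returns the matrix of partial sums
# out[i, j] = x[j + 1] + ... + x[i] on and below the diagonal, and -inf above it. For example,
#
#     segment_sum(torch.tensor([1.0, 2.0, 3.0]))
#     # tensor([[0., -inf, -inf],
#     #         [2.,   0., -inf],
#     #         [5.,   3.,   0.]])
#
# Exponentiating this matrix (as done in the chunked scan below) yields a lower-triangular decay
# matrix that plays the role of a causal attention mask.
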
class Mamba2Cache:
    """
    Arguments:
        config: ModelArgs
        batch_size: int
        dtype: torch.dtype
        device: torch.device

    Attributes:
        seqlen_offset: int
        dtype: torch.dtype
        conv_states: Dict[int, torch.Tensor]  # layer_idx -> [batch_size, conv_dim, conv_kernel]
        ssm_states: Dict[int, torch.Tensor]  # layer_idx -> [batch_size, num_heads, head_dim, state_size]
    """

    def __init__(
        self, config: "ModelArgs", batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
    ):
        self.seqlen_offset = 0
        self.dtype = dtype
        self.conv_kernel = config.conv_kernel
        self.intermediate_size = int(config.expand * config.hidden_size)

        self.conv_states = {
            i: torch.zeros(
                batch_size,
                self.intermediate_size + 2 * config.n_groups * config.state_size,
                self.conv_kernel,
                device=device,
                dtype=dtype,
            )
            for i in range(config.num_hidden_layers)
        }
        self.ssm_states = {
            i: torch.zeros(
                batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype
            )
            for i in range(config.num_hidden_layers)
        }

    def update_conv_state(
        self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
    ) -> torch.Tensor:
        conv_state = self.conv_states[layer_idx]
        cache_position = cache_position.clamp(0, self.conv_kernel - 1)

        conv_state = conv_state.roll(shifts=-1, dims=-1)
        conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
        self.conv_states[layer_idx].zero_()
        self.conv_states[layer_idx] += conv_state
        return self.conv_states[layer_idx]

    def reset(self):
        # conv_states and ssm_states are per-layer dicts of tensors, so zero them layer by layer
        for layer_idx in self.conv_states:
            self.conv_states[layer_idx].zero_()
            self.ssm_states[layer_idx].zero_()


class MambaRMSNormGated(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states, gate=None):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)

        if gate is not None:
            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        return self.weight * hidden_states.to(input_dtype)


class Mamba2Mixer(nn.Module):
    def __init__(self, config: "ModelArgs", layer_idx: int):
        super().__init__()
        self.num_heads = config.num_heads
        self.hidden_size = config.hidden_size
        self.state_size = config.state_size
        self.conv_kernel = config.conv_kernel
        self.intermediate_size = int(config.expand * self.hidden_size)
        self.time_step_rank = int(config.time_step_rank)
        self.use_conv_bias = config.use_conv_bias
        self.layer_norm_epsilon = config.layer_norm_epsilon
        self.layer_idx = layer_idx

        self.n_groups = config.n_groups
        self.head_dim = config.head_dim
        self.chunk_size = config.chunk_size

        self.time_step_limit = config.time_step_limit
        self.time_step_min = config.time_step_min
        self.time_step_max = config.time_step_max

        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=config.use_conv_bias,
            kernel_size=config.conv_kernel,
            groups=self.conv_dim,
            padding=config.conv_kernel - 1,
        )

        # projection of the input hidden states
        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
        self.in_proj = nn.Linear(
            self.hidden_size,
            projection_size,
            bias=config.use_bias,
        )

        # dt bias, state decay (A_log) and skip connection (D) are learnable parameters
        self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
        A = torch.arange(1, self.num_heads + 1, dtype=torch.float32)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.num_heads))

        self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
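
    # Layout note (derived from the sizes above, for readability): along the feature dimension the
    # in_proj output packs [gate | hidden_states, B, C (conv_dim columns) | dt (num_heads columns)],
    # where conv_dim = intermediate_size + 2 * n_groups * state_size. With this projection_size the
    # `d_mlp` computed in `forward` works out to 0, so the first two chunks of the split are empty.
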
    def forward(
        self,
        input_states,
        cache: Optional[Mamba2Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        batch_size, seq_len, _ = input_states.shape

        # Gated MLP's linear projection
        projected_states = self.in_proj(input_states.squeeze(1))
        d_mlp = (
            projected_states.shape[-1]
            - 2 * self.intermediate_size
            - 2 * self.n_groups * self.state_size
            - self.num_heads
        ) // 2
        _, _, gate, hidden_states, dt = projected_states.split(
            [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
        )

        # Convolution sequence transformation
        # NOTE: this simplified implementation assumes a Mamba2Cache is always provided.
        ssm_state = cache.ssm_states[self.layer_idx].clone()
        ssm_state = ssm_state.to(hidden_states.device)
        if cache.seqlen_offset > 0:
            conv_state = cache.conv_states[self.layer_idx]  # [batch, conv_dim, conv_kernel]
            conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
            # handle batched generation - states are copied through
            conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states
            cache.conv_states[self.layer_idx].copy_(conv_state)
            hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
            if self.use_conv_bias:
                hidden_states += self.conv1d.bias
            # [batch, 1, conv_dim] : decoding
            hidden_states = nn.functional.silu(hidden_states)[:, None, ...]
        else:
            hidden_states = hidden_states.transpose(1, 2)
            conv_state = nn.functional.pad(hidden_states, (self.conv_kernel - hidden_states.shape[-1], 0))
            cache.conv_states[self.layer_idx].copy_(conv_state)
            # [batch, seq_len, conv_dim]
            hidden_states = nn.functional.silu(self.conv1d(hidden_states).transpose(1, 2))[:, :seq_len, :]

        hidden_states, B, C = torch.split(
            hidden_states,
            [self.intermediate_size, self.n_groups * self.state_size, self.n_groups * self.state_size],
            dim=-1,
        )
        A = -torch.exp(self.A_log.float())  # [num_heads]
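
        # Reading aid (not executed): both branches below realize the same discretized state-space
        # recurrence, per head and per channel,
        #
        #     dA = exp(dt * A)           # state decay
        #     dB = dt * B                # input discretization
        #     h_t = dA * h_{t-1} + dB * x_t
        #     y_t = C . h_t + D * x_t    # plus the skip connection
        #
        # The first branch applies a single step of this recurrence during decoding; the second
        # branch computes the same quantity over a full prompt with the chunked ("SSD") formulation.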
        if cache is not None and cache.seqlen_offset > 0:
            # Note: there is no need to pad parameter matrices here, as there is just one new token
            # for batched generation
            dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
            dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
            # [num_heads] -> [num_heads, head_dim]
            dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)

            dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
            dt = torch.clamp(dt, self.time_step_min)  # , self.time_step_max)
            A = A[..., None, None].expand(self.num_heads, self.head_dim, self.state_size).to(dtype=torch.float32)
            # [bsz, num_heads, head_dim, state_size]
            dA = torch.exp(dt[..., None] * A)

            # Discretize B
            # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
            # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
            B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
            B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
            B = B.reshape(batch_size, -1, B.shape[-1])
            # [bsz, num_heads, head_dim, state_size]
            dB = dt[..., None] * B[..., None, :]

            # Discretize x into dB
            # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
            hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
            dBx = dB * hidden_states[..., None]

            # State calculation
            cache.ssm_states[self.layer_idx].copy_(
                cache.ssm_states[self.layer_idx] * dA + dBx
            )

            # Subsequent output
            # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
            C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
            C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
            C = C.reshape(batch_size, -1, C.shape[-1])

            # [bsz, num_heads, head_dim]
            ssm_states = cache.ssm_states[self.layer_idx].to(C.dtype)  # Shape: [b, h, d, n]
            # Reshape ssm_states to merge the first two dimensions
            ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.state_size)  # [b*h, d, n]
            C_reshaped = C.view(batch_size * self.num_heads, self.state_size, 1)  # [b*h, n, 1]
            y = torch.bmm(ssm_states_reshaped, C_reshaped)
            y = y.view(batch_size, self.num_heads, self.head_dim)

            # D skip connection
            # [num_heads] -> [num_heads, head_dim]
            D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
            y = (y + hidden_states * D).to(y.dtype)

            # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
            y = y.reshape(batch_size, -1)[:, None, ...]
        else:
            # begin ssd naive implementation without einsums
            dt = nn.functional.softplus(dt + self.dt_bias)
            dt = torch.clamp(dt, self.time_step_min)
            hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
            B = B.reshape(batch_size, seq_len, -1, self.state_size).float()
            C = C.reshape(batch_size, seq_len, -1, self.state_size).float()
            B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
            C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
            pad_size = self.chunk_size - (seq_len % self.chunk_size)

            D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)

            # Discretize x and A
            hidden_states = hidden_states * dt[..., None]
            A = A.to(hidden_states.dtype) * dt

            # Rearrange into blocks/chunks
            hidden_states, A, B, C = [
                reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)
            ]

            # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
            A = A.permute(0, 3, 1, 2)
            A_cumsum = torch.cumsum(A, dim=-1)
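
            # Reading aid (not executed): `A_cumsum[..., t]` accumulates the log-decay within each
            # chunk, and exp(segment_sum(A))[..., i, j] is the total decay from position j to
            # position i of a chunk. The steps below first compute the intra-chunk ("diagonal
            # block") outputs with these decays, then carry a per-chunk state across chunk
            # boundaries, and finally convert those states into the inter-chunk output term.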
            # 1. Compute the output for each intra-chunk (diagonal blocks)
            # This is the analog of a causal mask
            L = torch.exp(segment_sum(A))

            # First, contraction of C and B to get G (attention-weights like)
            G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]  # shape: (b, c, l, s, h, n)
            G = G_intermediate.sum(dim=-1)  # shape: (b, c, l, s, h)

            # Step 2: Compute M, equivalent to applying attention mask to weights
            M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
            M = M_intermediate.sum(dim=-1)

            # Step 3: Compute Y_diag (apply to values)
            Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)

            # (right term of low-rank factorization of off-diagonal blocks; B terms)
            decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
            B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
            # permute back B * decay states
            states = (
                B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None]
                * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]
            ).sum(dim=3).permute(0, 1, 2, 4, 3)

            if cache is not None and cache.seqlen_offset > 0:
                previous_states = cache.ssm_states[self.layer_idx][:, None, ...]
            else:
                previous_states = torch.zeros_like(states[:, :1])
            states = torch.cat([previous_states, states], dim=1)
            decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))

            states_permuted = states.permute(0, 2, 1, 3, 4)
            result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
            new_states = result.permute(0, 2, 1, 3, 4)
            states, ssm_state = new_states[:, :-1], new_states[:, -1]

            # Compute state -> output conversion per chunk
            # (left term of low-rank factorization of off-diagonal blocks; C terms)
            state_decay_out = torch.exp(A_cumsum)
            # compute Yoff
            C_times_states = C[..., None, :] * states[:, :, None, ...]
            state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
            Y_off = C_times_states.sum(-1) * state_decay_out_permuted[..., None]

            # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
            y = Y_diag + Y_off
            # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
            y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)

            y = y + D_residual
            # Cutting off padded chunks
            if pad_size > 0:
                y = y[:, :seq_len, :, :]
            y = y.reshape(batch_size, seq_len, -1)

            if ssm_state is not None and cache is not None:
                cache.ssm_states[self.layer_idx].copy_(ssm_state)
            # end ssd naive

        scan_output = self.norm(y, gate)
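
        # `self.norm` is the gated RMSNorm defined above: it multiplies the scan output by
        # silu(gate) and then RMS-normalizes, so the gating branch of the block is folded into
        # the normalization step rather than applied separately.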
        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_output)  # [batch, seq_len, hidden_size]
        return contextualized_states


class Mamba2RMSNorm(nn.Module):
    # Ungated counterpart of MambaRMSNormGated, used for the per-block and final norms below.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class Mamba2Block(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mixer = Mamba2Mixer(config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states,
        cache: Optional[Mamba2Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        x = self.mixer(self.norm(hidden_states), cache=cache, cache_position=cache_position)
        return x + hidden_states


class Mamba2Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False
        self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        cache: Optional[Mamba2Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        for mixer_block in self.layers:
            hidden_states = mixer_block(
                hidden_states,
                cache=cache,
                cache_position=cache_position,
            )

        if cache is not None:
            cache.seqlen_offset += inputs_embeds.shape[1]

        return self.norm_f(hidden_states), cache


class Mamba2ForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.backbone = Mamba2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        cache: Optional[Mamba2Cache] = None,
        cache_position: Optional[torch.Tensor] = None,
    ):
        out, cache = self.backbone(
            input_ids,
            cache=cache,
            cache_position=cache_position,
        )
        logits = self.lm_head(out)
        return logits, cache
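

# Minimal smoke test (illustrative sketch only). `ModelArgs` is not defined in this file, so a
# `SimpleNamespace` with the fields accessed above stands in for it; all sizes are small, arbitrary
# values chosen so that expand * hidden_size == num_heads * head_dim, which the head reshapes assume.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        vocab_size=100,
        hidden_size=32,
        state_size=16,
        expand=2,
        num_heads=8,
        head_dim=8,  # expand * hidden_size == num_heads * head_dim == 64
        n_groups=1,
        conv_kernel=4,
        chunk_size=8,
        num_hidden_layers=2,
        time_step_rank=2,
        time_step_min=0.001,
        time_step_max=0.1,
        time_step_limit=(0.0, float("inf")),
        layer_norm_epsilon=1e-5,
        use_bias=False,
        use_conv_bias=True,
    )

    model = Mamba2ForCausalLM(config)
    cache = Mamba2Cache(config, batch_size=1, dtype=torch.float32)

    # "Prefill" over a prompt, then decode one token reusing the cached conv and SSM states.
    prompt = torch.randint(0, config.vocab_size, (1, 10))
    logits, cache = model(prompt, cache=cache)
    print("prefill logits:", tuple(logits.shape))  # (1, 10, 100)

    next_token = logits[:, -1:].argmax(dim=-1)
    logits, cache = model(next_token, cache=cache)
    print("decode logits:", tuple(logits.shape))  # (1, 1, 100)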