diff --git a/llms/mlx_lm/models/mamba2-prch-minimal.py b/llms/mlx_lm/models/mamba2-prch-minimal.py deleted file mode 100644 index f988a825..00000000 --- a/llms/mlx_lm/models/mamba2-prch-minimal.py +++ /dev/null @@ -1,437 +0,0 @@ -""" -mamba2-minimal -============== - -A minimal, single-file implementation of the Mamba-2 model in PyTorch. - -> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality** -> Authors: Tri Dao, Albert Gu -> Paper: https://arxiv.org/abs/2405.21060 -""" - -import json -from dataclasses import dataclass -from typing import Iterable, NamedTuple, TypeAlias, cast - -import torch -import torch.nn.functional as F -from einops import rearrange, repeat -from torch import LongTensor, Tensor, nn - -Device: TypeAlias = str | torch.device | None - - -@dataclass -class Mamba2Config: - d_model: int # model dimension (D) - n_layer: int = 24 # number of Mamba-2 layers in the language model - d_state: int = 128 # state dimension (N) - d_conv: int = 4 # convolution kernel size - expand: int = 2 # expansion factor (E) - headdim: int = 64 # head dimension (P) - chunk_size: int = 64 # matrix partition size (Q) - vocab_size: int = 50277 - pad_vocab_size_multiple: int = 16 - - def __post_init__(self): - self.d_inner = self.expand * self.d_model - assert self.d_inner % self.headdim == 0 - self.nheads = self.d_inner // self.headdim - if self.vocab_size % self.pad_vocab_size_multiple != 0: - self.vocab_size += ( - self.pad_vocab_size_multiple - - self.vocab_size % self.pad_vocab_size_multiple - ) - - -class InferenceCache(NamedTuple): - conv_state: Tensor # (batch, d_inner + 2 * d_state, d_conv) - ssm_state: Tensor # (batch, nheads, headdim, d_state) - - @staticmethod - def alloc(batch_size: int, args: Mamba2Config, device: Device = None): - return InferenceCache( - torch.zeros( - batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device - ), - torch.zeros( - batch_size, args.nheads, args.headdim, args.d_state, device=device - ), - ) - - -class Mamba2LMHeadModel(nn.Module): - def __init__(self, args: Mamba2Config, device: Device = None): - super().__init__() - self.args = args - self.device = device - - self.backbone = nn.ModuleDict( - dict( - embedding=nn.Embedding(args.vocab_size, args.d_model, device=device), - layers=nn.ModuleList( - [ - nn.ModuleDict( - dict( - mixer=Mamba2(args, device=device), - norm=RMSNorm(args.d_model, device=device), - ) - ) - for _ in range(args.n_layer) - ] - ), - norm_f=RMSNorm(args.d_model, device=device), - ) - ) - self.lm_head = nn.Linear( - args.d_model, args.vocab_size, bias=False, device=device - ) - self.lm_head.weight = self.backbone.embedding.weight - - @staticmethod - def from_pretrained(huggingface_model_id: str, device: Device = None): - from transformers.utils import CONFIG_NAME, WEIGHTS_NAME - from transformers.utils.hub import cached_file - - config_path = cached_file(huggingface_model_id, CONFIG_NAME) - assert config_path, "Failed to get huggingface config file" - state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME) - assert state_dict_path, "Failed to get huggingface state dict file" - - config = json.load(open(config_path)) - args = Mamba2Config( - d_model=config["d_model"], - n_layer=config["n_layer"], - vocab_size=config["vocab_size"], - pad_vocab_size_multiple=config["pad_vocab_size_multiple"], - ) - - map_location = "cpu" if device is None else device - state_dict = torch.load( - state_dict_path, weights_only=True, map_location=map_location, mmap=True - ) - 
model = Mamba2LMHeadModel(args, device=device) - model.load_state_dict(state_dict) - model.eval() - return model - - def forward( - self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None - ) -> tuple[LongTensor, list[InferenceCache]]: - """ - Arguments - input_ids: (batch, seqlen) tokens from `EleutherAI/gpt-neox-20b` tokenizer - h: hidden states for inference step. If present the constant-time - (wrt sequence length) inference path will be taken, input_ids - should have shape (batch, 1) containing the next batch of prompt - token. - - Return (logits, h) - logits: (batch, seqlen, vocab_size) - h: updated inference cache after processing `input_ids` - """ - seqlen = input_ids.shape[1] - - if h is None: - h = [None for _ in range(self.args.n_layer)] - - x = self.backbone.embedding(input_ids) - for i, layer in enumerate(self.backbone.layers): - y, h[i] = layer.mixer(layer.norm(x), h[i]) - x = y + x - - x = self.backbone.norm_f(x) - logits = self.lm_head(x) - return logits[:, :seqlen], cast(list[InferenceCache], h) - - def generate( - self, - input_ids: LongTensor, - max_new_length: int = 20, - temperature: float = 1.0, - top_k: int = 50, - top_p: float = 1.0, - eos_token_id: int = 0, - ) -> Iterable[tuple[int, list[InferenceCache]]]: - prefix, tokens = input_ids[:-1], input_ids[-1:].unsqueeze(0) - - # Process prompt - # The input sequence to forward (non-inference path) must have length multiple that of chunk_size. - # We split out excess tokens so that n_chunked tokens can be processed by one forward call and - # process the rest in multiple inference steps. - n_chunked = (prefix.shape[0] // self.args.chunk_size) * self.args.chunk_size - if n_chunked > 0: - _, h = self(prefix[:n_chunked].unsqueeze(0), None) - else: - h = [ - InferenceCache.alloc(1, self.args, device=self.device) - for _ in range(self.args.n_layer) - ] - for i in range(n_chunked, prefix.shape[0]): - _, h = self(prefix[i : i + 1].unsqueeze(0), h) - - # Generate - for _ in range(max_new_length): - with torch.no_grad(): - out, h = self(tokens, h) - logits = out[0, -1] - if temperature != 1.0: - logits = logits / temperature - if top_k > 0: - indices_to_remove = logits < torch.topk(logits, k=top_k)[0][-1] - logits[indices_to_remove] = -torch.inf - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - sorted_indices_to_remove = cum_probs > 0.5 - sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone() - sorted_indices_to_remove[0] = False - indices_to_remove = sorted_indices[sorted_indices_to_remove] - logits[indices_to_remove] = -torch.inf - probs = F.softmax(logits, dim=-1) - next_token = torch.multinomial(probs, num_samples=1) - if next_token.item() == eos_token_id: - return - tokens = next_token.unsqueeze(0) - yield cast(int, next_token.item()), h - - -class Mamba2(nn.Module): - def __init__(self, args: Mamba2Config, device: Device = None): - super().__init__() - self.args = args - self.device = device - - # Order: (z, x, B, C, dt) - d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads - self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device) - - conv_dim = args.d_inner + 2 * args.d_state - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - kernel_size=args.d_conv, - groups=conv_dim, - padding=args.d_conv - 1, - device=device, - ) - - self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device)) - self.A_log = 
nn.Parameter(torch.empty(args.nheads, device=device)) - self.D = nn.Parameter(torch.empty(args.nheads, device=device)) - self.norm = RMSNorm(args.d_inner, device=device) - self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device) - - def forward(self, u: Tensor, h: InferenceCache | None = None): - """ - Arguments - u: (batch, seqlen, d_model) input. seqlen should be a multiple of chunk_size. - h: hidden states for inference step. Initialized to 0s if not present. - - Return (y, h) - y: (batch, seqlen, d_model) output - h: updated inference cache after processing `u` - """ - if h: - return self.step(u, h) - - A = -torch.exp(self.A_log) # (nheads,) - zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj) - z, xBC, dt = torch.split( - zxbcdt, - [ - self.args.d_inner, - self.args.d_inner + 2 * self.args.d_state, - self.args.nheads, - ], - dim=-1, - ) - dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads) - - # Pad or truncate xBC seqlen to d_conv - conv_state = F.pad( - rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0) - ) - - xBC = silu( - self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :] - ) # (batch, seqlen, d_inner + 2 * d_state)) - x, B, C = torch.split( - xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1 - ) - x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim) - y, ssm_state = ssd( - x * dt.unsqueeze(-1), - A * dt, - rearrange(B, "b l n -> b l 1 n"), - rearrange(C, "b l n -> b l 1 n"), - self.args.chunk_size, - device=self.device, - ) - y = y + x * self.D.unsqueeze(-1) - y = rearrange(y, "b l h p -> b l (h p)") - y = self.norm(y, z) - y = self.out_proj(y) - - h = InferenceCache(conv_state, ssm_state) - return y, h - - def step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]: - """Take a single inference step for the current input and hidden state - - Unlike attention-based models, RNN-based models (eg Mamba) does not need - to look back at all the past tokens to generate a new token. Instead a - hidden state (initialized to 0s initially) is updated for each input and - passed to the next inference step. This means that the total inference - time is linear with respect to the sequence length instead of quadratic - in attention's case. 
- - Arguments - u: (batch, 1, d_model) - h: initial/running hidden state - - Return (y, h) - y: (batch, 1, d_model) - h: updated hidden state - """ - assert u.shape[1] == 1, "Only one token can be decoded per inference step" - - zxbcdt = self.in_proj(u.squeeze(1)) # (batch, d_in_proj) - z, xBC, dt = torch.split( - zxbcdt, - [ - self.args.d_inner, - self.args.d_inner + 2 * self.args.d_state, - self.args.nheads, - ], - dim=-1, - ) - - # Advance convolution input - h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1)) - h.conv_state[:, :, -1] = xBC - # Convolution step - xBC = torch.sum( - h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1 - ) - xBC += self.conv1d.bias - xBC = silu(xBC) - - x, B, C = torch.split( - xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1 - ) - A = -torch.exp(self.A_log) # (nheads,) - - # SSM step - dt = F.softplus(dt + self.dt_bias) # (batch, nheads) - dA = torch.exp(dt * A) # (batch, nheads) - x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim) - dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x) - h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) - y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C) - y = y + rearrange(self.D, "h -> h 1") * x - y = rearrange(y, "b h p -> b (h p)") - y = self.norm(y, z) - y = self.out_proj(y) - - return y.unsqueeze(1), h - - -def segsum(x: Tensor, device: Device = None) -> Tensor: - """Stable segment sum calculation. - - `exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM. - - Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32 - """ - T = x.size(-1) - x = repeat(x, "... d -> ... d e", e=T) - mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1) - x = x.masked_fill(~mask, 0) - x_segsum = torch.cumsum(x, dim=-2) - mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0) - x_segsum = x_segsum.masked_fill(~mask, -torch.inf) - return x_segsum - - -def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None): - """Structed State Space Duality (SSD) - the core of Mamba-2 - - This is almost the exact same minimal SSD code from the blog post. - - Arguments - x: (batch, seqlen, n_heads, d_head) - A: (batch, seqlen, n_heads) - B: (batch, seqlen, n_heads, d_state) - C: (batch, seqlen, n_heads, d_state) - - Return - y: (batch, seqlen, n_heads, d_head) - - Source - 1. https://tridao.me/blog/2024/mamba2-part3-algorithm/ - 2. https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78 - """ - assert x.shape[1] % chunk_size == 0 - - # Rearrange into chunks - # Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel) - # This is not implemented and left as an exercise for the reader 😜 - x, A, B, C = [ - rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C) - ] - - A = rearrange(A, "b c l h -> b h c l") - A_cumsum = torch.cumsum(A, dim=-1) - - # 1. Compute the output for each intra-chunk (diagonal blocks) - L = torch.exp(segsum(A, device=device)) - Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x) - - # 2. 
Compute the state for each intra-chunk - # (right term of low-rank factorization of off-diagonal blocks; B terms) - decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) - states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x) - - # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries - # (middle term of factorization of off-diag blocks; A terms) - if initial_states is None: - initial_states = torch.zeros_like(states[:, :1]) - states = torch.cat([initial_states, states], dim=1) - decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device)) - new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states) - states, final_state = new_states[:, :-1], new_states[:, -1] - - # 4. Compute state -> output conversion per chunk - # (left term of low-rank factorization of off-diagonal blocks; C terms) - state_decay_out = torch.exp(A_cumsum) - Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out) - - # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) - Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") - - return Y, final_state - - -class RMSNorm(nn.Module): - def __init__(self, d: int, eps: float = 1e-5, device: Device = None): - """Gated Root Mean Square Layer Normalization - - Paper: https://arxiv.org/abs/1910.07467 - """ - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(d, device=device)) - - def forward(self, x, z=None): - if z is not None: - x = x * silu(z) - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight - - -def silu(x): - """Applies the Sigmoid Linear Unit (SiLU), element-wise. - - Define this manually since torch's version doesn't seem to work on MPS. - """ - return x * F.sigmoid(x) \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2-prch.py b/llms/mlx_lm/models/mamba2-prch.py deleted file mode 100644 index 69390ea9..00000000 --- a/llms/mlx_lm/models/mamba2-prch.py +++ /dev/null @@ -1,1081 +0,0 @@ -# coding=utf-8 -# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""PyTorch MAMBA2 model.""" - -import math -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss - -from ...activations import ACT2FN -from ...modeling_utils import PreTrainedModel -from ...utils import ( - ModelOutput, - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available -from .configuration_mamba2 import Mamba2Config - - -logger = logging.get_logger(__name__) - - -if is_mamba_2_ssm_available(): - from mamba_ssm.ops.triton.selective_state_update import selective_state_update - from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined -else: - selective_state_update = None - -if is_causal_conv1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -else: - causal_conv1d_update, causal_conv1d_fn = None, None - -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) - -_CHECKPOINT_FOR_DOC = "mistralai/mamba-codestral-7B-v0.1" -_CONFIG_FOR_DOC = "Mamba2Config" - - -# Helper methods for segment sum computation - - -def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): - """ - Padding x tensor with `pad_size` on the seq_len dim (dim=1) - - Assumes that we only have tensors of either size 4 or 3 - """ - pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) - - return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) - - -def reshape_into_chunks(input_tensor, pad_size, chunk_size): - """ - Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and - simultaneously splitting it into chunk sequences. - - Assumes that we only have tensors of either size 4 or 3 - """ - # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] - input_tensor = pad_tensor_by_size(input_tensor, pad_size) - - if len(input_tensor.shape) == 3: - # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] - return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) - else: - # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] - return input_tensor.reshape( - input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] - ) - - -def segment_sum(input_tensor): - """ - More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. - """ - chunk_size = input_tensor.size(-1) - # 1. expand input tensor to have an additional dimension and repeat along that dimension - # [..., chunk_size] -> [..., chunk_size, chunk_size] - input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) - # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag - mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) - input_tensor = input_tensor.masked_fill(~mask, 0) - # 3. compute actual cumsum - tensor_segsum = torch.cumsum(input_tensor, dim=-2) - - # 4. 
apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) - mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) - tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) - return tensor_segsum - - -class Mamba2Cache: - """ - Arguments: - config: Mamba2Config - batch_size: int - dtype: torch.dtype - device: torch.device - - Attributes: - seqlen_offset: int - dtype: torch.dtype - conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel_size] - ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size] - """ - - def __init__( - self, config: Mamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None - ): - self.seqlen_offset = 0 - self.dtype = dtype - self.conv_kernel_size = config.conv_kernel - self.intermediate_size = int(config.expand * config.hidden_size) - - self.conv_states = { - i: torch.zeros( - batch_size, - self.intermediate_size + 2 * config.n_groups * config.state_size, - self.conv_kernel_size, - device=device, - dtype=dtype, - ) - for i in range(config.num_hidden_layers) - } - self.ssm_states = { - i: torch.zeros( - batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype - ) - for i in range(config.num_hidden_layers) - } - self.activation = config.hidden_act - self.act = ACT2FN[config.hidden_act] - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor - ) -> torch.Tensor: - conv_state = self.conv_states[layer_idx] - cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) - - conv_state = conv_state.roll(shifts=-1, dims=-1) - conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) - self.conv_states[layer_idx].zero_() - self.conv_states[layer_idx] += conv_state - return self.conv_states[layer_idx] - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - -class MambaRMSNormGated(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states, gate=None): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - - if gate is not None: - hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - - return self.weight * hidden_states.to(input_dtype) - - -class Mamba2Mixer(nn.Module): - """ - Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
- A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) - ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, - and is why Mamba is called **selective** state spaces) - """ - - def __init__(self, config: Mamba2Config, layer_idx: int): - super().__init__() - self.num_heads = config.num_heads - self.hidden_size = config.hidden_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_kernel - self.intermediate_size = int(config.expand * self.hidden_size) - self.time_step_rank = int(config.time_step_rank) - self.layer_idx = layer_idx - self.use_conv_bias = config.use_conv_bias - self.activation = config.hidden_act - self.act = ACT2FN[config.hidden_act] - - self.layer_norm_epsilon = config.layer_norm_epsilon - self.rms_norm = config.rms_norm - - self.n_groups = config.n_groups - self.head_dim = config.head_dim - self.chunk_size = config.chunk_size - - self.time_step_limit = config.time_step_limit - self.time_step_min = config.time_step_min - self.time_step_max = config.time_step_max - - self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size - self.conv1d = nn.Conv1d( - in_channels=self.conv_dim, - out_channels=self.conv_dim, - bias=config.use_conv_bias, - kernel_size=config.conv_kernel, - groups=self.conv_dim, - padding=config.conv_kernel - 1, - ) - - # projection of the input hidden states - projection_size = self.intermediate_size + self.conv_dim + self.num_heads - self.in_proj = nn.Linear( - self.hidden_size, - projection_size, - bias=config.use_bias, - ) - # selective projection used to make dt, B and C input dependant - - # time step projection (discretization) - # instantiate once and copy inv_dt in init_weights of PretrainedModel - self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) - - # S4D real initialization. These are not discretized! - # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded - A = torch.arange(1, self.num_heads + 1) - self.A_log = nn.Parameter(torch.log(A)) - self.A_log._no_weight_decay = True - self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) - self.D = nn.Parameter(torch.ones(self.num_heads)) - self.D._no_weight_decay = True - - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) - self.use_bias = config.use_bias - - if not is_fast_path_available: - logger.warning_once( - "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" - " is None. Falling back to the naive implementation. 
To install follow https://github.com/state-spaces/mamba/#installation and" - " https://github.com/Dao-AILab/causal-conv1d" - ) - - def cuda_kernels_forward( - self, - hidden_states: torch.Tensor, - cache_params: Optional[Mamba2Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ): - # set up dimensions for reshapes later - - batch_size, seq_len, _ = hidden_states.shape - groups_time_state_size = self.n_groups * self.ssm_state_size - d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads - - # getting projected states from cache if it exists - if cache_params is not None and cache_params.seqlen_offset > 0: - in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D) - d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 - split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] - _, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1) - - hidden_states_B_C = causal_conv1d_update( - hidden_states_B_C, - cache_params.conv_states[self.layer_idx], - self.conv1d.weight.squeeze(1), - self.conv1d.bias, - self.activation, - ) - - hidden_states, B, C = torch.split( - hidden_states_B_C, - [self.intermediate_size, groups_time_state_size, groups_time_state_size], - dim=-1, - ) - A = -torch.exp(self.A_log.float()) # (nheads,) - - A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) - dt = dt[:, :, None].expand(-1, -1, self.head_dim) - dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) - D = self.D[:, None, ...].expand(-1, self.head_dim) - B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) - C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) - hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) - hidden_states = selective_state_update( - cache_params.ssm_states[self.layer_idx], - hidden_states_reshaped, - dt, - A, - B, - C, - D, - z=None, - dt_bias=dt_bias, - dt_softplus=True, - ) - hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) - hidden_states = self.norm(hidden_states, gate) - out = self.out_proj(hidden_states)[:, None, ...] - # if no cache is found, calling the kernel - else: - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - dtype = hidden_states.dtype - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - # 1. 
Gated MLP's linear projection - projected_states = self.in_proj(hidden_states) - A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) - dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} - - if self.training and cache_params is None: - out, ssm_state = mamba_split_conv1d_scan_combined( - projected_states, - self.conv1d.weight.squeeze(1), - self.conv1d.bias, - self.dt_bias, - A, - D=self.D, - chunk_size=self.chunk_size, - seq_idx=None, # was seq_idx - activation=self.activation, - rmsnorm_weight=self.norm.weight, - rmsnorm_eps=self.norm.variance_epsilon, - outproj_weight=self.out_proj.weight, - outproj_bias=self.out_proj.bias, - headdim=self.head_dim, - ngroups=self.n_groups, - norm_before_gate=False, - return_final_states=True, - **dt_limit_kwargs, - ) - - else: - gate, hidden_states_B_C, time_step = torch.split( - projected_states, - [self.intermediate_size, self.conv_dim, self.num_heads], - dim=-1, - ) - - time_step = nn.functional.softplus(time_step + self.dt_bias) - # 1D Convolution - if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: - hidden_states_B_C = self.act( - self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] - ) # (B, L, self.d_inner + 2 * ngroups * d_state) - else: - hidden_states_B_C = causal_conv1d_fn( - x=hidden_states_B_C.transpose(1, 2), - weight=self.conv1d.weight.squeeze(1), - bias=self.conv1d.bias, - activation=self.activation, - ).transpose(1, 2)[:, :seq_len] - hidden_states, B, C = torch.split( - hidden_states_B_C, - [self.intermediate_size, groups_time_state_size, groups_time_state_size], - dim=-1, - ) - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - dtype = hidden_states.dtype - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - scan_output, ssm_state = mamba_chunk_scan_combined( - hidden_states.view(batch_size, seq_len, -1, self.head_dim), - time_step, - A, - B.view(batch_size, seq_len, self.n_groups, -1), - C.view(batch_size, seq_len, self.n_groups, -1), - chunk_size=self.chunk_size, - D=self.D, - z=None, - seq_idx=None, - return_final_states=True, - **dt_limit_kwargs, - ) - if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) - scan_output = scan_output.view(batch_size, seq_len, -1) - # Multiply "gate" branch and apply extra normalization layer - scan_output = self.norm(scan_output, gate) - out = self.out_proj(scan_output) - return out - - # fmt: off - def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): - batch_size, seq_len, _ = input_states.shape - dtype = input_states.dtype - # Gated MLP's linear projection - projected_states = self.in_proj(input_states.squeeze(1)) - d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2 - _, _, gate, hidden_states, dt = projected_states.split( - [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 - ) - - # Convolution sequence transformation - if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx].clone() - ssm_state = ssm_state.to(hidden_states.device) - if cache_params.seqlen_offset > 0: - conv_state = 
cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - # handle batched generation - states are copied through - conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) - if self.use_conv_bias: - hidden_states += self.conv1d.bias - hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding - else: - hidden_states = hidden_states.transpose(1,2) - conv_state = nn.functional.pad( - hidden_states, - (self.conv_kernel_size - hidden_states.shape[-1], 0) - ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: - dtype = hidden_states.dtype - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - else: - ssm_state = torch.zeros( - (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), - device=hidden_states.device, dtype=dtype - ) - hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) - hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) - A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and cache_params.seqlen_offset > 0: - # Note: there is no need to pad parameter matrices here, as there is just one new token - # for batched generation - dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] 
- dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) - # [num_heads] -> [num_heads, head_dim] - dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) - - dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) - dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max) - A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) - # [bsz, num_heads, head_dim, state_size] - dA = torch.exp(dt[..., None] * A) - - # Discretize B - # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> - # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] - B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] - B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() - B = B.reshape(batch_size, -1, B.shape[-1]) - # [bsz, num_heads, head_dim, state_size] - dB = dt[..., None] * B[..., None, :] - - # Discretize x into dB - # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] - hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) - dBx = dB * hidden_states[..., None] - - # State calculation - cache_params.ssm_states[self.layer_idx].copy_( - cache_params.ssm_states[self.layer_idx] * dA + dBx - ) - - # Subsequent output - # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] - C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] - C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() - C = C.reshape(batch_size, -1, C.shape[-1]) - # [bsz, num_heads, head_dim] - - ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] - # Reshape ssm_states to merge the first two dimensions - ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] - C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] - y = torch.bmm(ssm_states_reshaped, C_reshaped) - y = y.view(batch_size, self.num_heads, self.head_dim) - - # D skip connection - # [num_heads] -> [num_heads, head_dim] - D = self.D[..., None].expand(self.D.shape[0], self.head_dim) - y = (y + hidden_states * D).to(y.dtype) - - # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] - y = y.reshape(batch_size, -1)[:, None, ...] - else: - # begin ssd naive implementation without einsums - dt = nn.functional.softplus(dt + self.dt_bias) - dt = torch.clamp(dt, self.time_step_min) - hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() - B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() - C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() - B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) - C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) - pad_size = self.chunk_size - (seq_len % self.chunk_size) - - D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) - - # Discretize x and A - hidden_states = hidden_states * dt[..., None] - A = A.to(hidden_states.dtype) * dt - - # Rearrange into blocks/chunks - hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] - - - # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] - A = A.permute(0, 3, 1, 2) - A_cumsum = torch.cumsum(A, dim=-1) - - # 1. 
Compute the output for each intra-chunk (diagonal blocks) - # This is the analog of a causal mask - L = torch.exp(segment_sum(A)) - - # First, contraction of C and B to get G (attention-weights like) - G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n) - G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) - - - # Step 2: Compute M, equivalent to applying attention mask to weights - M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] - M = M_intermediate.sum(dim=-1) - - # Step 3: Compute Y_diag (apply to values) - Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3) - - # (right term of low-rank factorization of off-diagonal blocks; B terms) - - decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) - B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] - # permute back B * decay states - states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) - if cache_params is not None and cache_params.seqlen_offset > 0: - previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] - else: - previous_states = torch.zeros_like(states[:, :1]) - states = torch.cat([previous_states, states], dim=1) - decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) - - states_permuted = states.permute(0, 2, 1, 3, 4) - result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2) - new_states = result.permute(0, 2, 1, 3, 4) - states, ssm_state = new_states[:, :-1], new_states[:, -1] - - # Compute state -> output conversion per chunk - # (left term of low-rank factorization of off-diagonal blocks; C terms) - state_decay_out = torch.exp(A_cumsum) - # compute Yoff - C_times_states = (C[..., None, :] * states[:, :, None, ...]) - state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) - Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) - # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) - - y = Y_diag + Y_off - # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] - y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) - - y = y + D_residual - # Cutting off padded chunks - if pad_size > 0: - y = y[:, :seq_len, :, :] - y = y.reshape(batch_size, seq_len, -1) - if ssm_state is not None and cache_params is not None: - cache_params.ssm_states[self.layer_idx].copy_(ssm_state) - - scan_output = self.norm(y, gate) - - # end ssd naive - - # 4. 
Final linear projection - contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] - return contextualized_states - # fmt: on - - def forward( - self, - hidden_states, - cache_params: Optional[Mamba2Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ): - if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: - return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) - dtype = hidden_states.dtype - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) - - return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) - - -class Mamba2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -class Mamba2Block(nn.Module): - def __init__(self, config, layer_idx): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.residual_in_fp32 = config.residual_in_fp32 - self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.mixer = Mamba2Mixer(config, layer_idx=layer_idx) - - def forward( - self, - hidden_states, - cache_params: Optional[Mamba2Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ): - residual = hidden_states - hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - - hidden_states = self.mixer( - hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask - ) - hidden_states = residual + hidden_states - return hidden_states - - -class Mamba2PreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. 
- """ - - config_class = Mamba2Config - base_model_prefix = "backbone" - _no_split_modules = ["Mamba2Block"] - supports_gradient_checkpointing = True - _is_stateful = True - - def _init_weights(self, module): - """Initialize the weights.""" - if isinstance(module, Mamba2Mixer): - module.A_log._no_weight_decay = True - module.D._no_weight_decay = True - - dt = torch.exp( - torch.rand(self.config.num_heads) - * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) - + math.log(self.config.time_step_min) - ).clamp(min=self.config.time_step_floor) - - # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - with torch.no_grad(): - module.dt_bias.copy_(inv_dt) - module.dt_bias._no_reinit = True - - if isinstance(module, nn.Linear): - if module.bias is not None: - if not getattr(module.bias, "_no_reinit", False): - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, std=self.config.initializer_range) - - if self.config.rescale_prenorm_residual: - # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: - # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale - # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. - # > -- GPT-2 :: https://openai.com/blog/better-language-models/ - # - # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - for name, p in module.named_parameters(): - if name in ["out_proj.weight"]: - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) - # We need to reinit p since this code could be called multiple times - # Having just p *= scale would repeatedly scale it down - nn.init.kaiming_uniform_(p, a=math.sqrt(5)) - with torch.no_grad(): - p /= math.sqrt(self.config.num_hidden_layers) - - -@dataclass -# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2 -class Mamba2Output(ModelOutput): - """ - Class for the MAMBA2 model outputs. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - cache_params (`Mamba2Cache`): - The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to - avoid providing the old `input_ids`. - - Includes both the State space model state matrices after the selective scan, and the Convolutional states - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
- """ - - last_hidden_state: Optional[torch.FloatTensor] = None - cache_params: Optional[Mamba2Cache] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->Mamba2 -class Mamba2CausalLMOutput(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - cache_params (`Mamba2Cache`): - The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to - avoid providing the old `input_ids`. - - Includes both the State space model state matrices after the selective scan, and the Convolutional states - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - """ - - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - cache_params: Optional[Mamba2Cache] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - - -MAMBA2_START_DOCSTRING = r""" - - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Mamba2Config`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -MAMBA2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - Indices of input sequence tokens in the vocabulary. - - If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as - `input_ids`. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. 
- cache_params (`Mamba2Cache`, *optional*): - If passed along, the model uses the previous state in all the blocks (which will give the output for the - `input_ids` provided as if the model add `state_input_ids + input_ids` as context). - use_cache (`bool`, *optional*): - If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare MAMBA2 Model transformer outputting raw hidden-states without any specific head on top.", - MAMBA2_START_DOCSTRING, -) -class Mamba2Model(Mamba2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - - self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) - - self.gradient_checkpointing = False - self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - # Initialize weights and apply final processing - self._register_load_state_dict_pre_hook(self.load_hook) - self.post_init() - - def load_hook(self, state_dict, prefix, *args): - for k in state_dict: - if "embedding." in k: - state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) - break - - def get_input_embeddings(self): - return self.embeddings - - def set_input_embeddings(self, new_embeddings): - self.embeddings = new_embeddings - - @add_start_docstrings_to_model_forward(MAMBA2_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=Mamba2Output, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - cache_params: Optional[Mamba2Cache] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[Tuple, Mamba2Output]: - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.embeddings(input_ids) - - if self.gradient_checkpointing and self.training and use_cache: - use_cache = False - - if use_cache: - if cache_params is None: - cache_params = Mamba2Cache( - self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype - ) - cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device) - elif cache_position is None: - # cases when we do manual forward instead of using `model.generate` which will initiate - # `cache_position` and makes sure it is not None, throw error here instead of doing some - # hack to conjecture the current cache position - raise ValueError( 
- "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, " - "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will " - "be initialized for you automatically" - ) - else: - cache_params = None - - hidden_states = inputs_embeds - all_hidden_states = () if output_hidden_states else None - for mixer_block in self.layers: - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask - ) - else: - hidden_states = mixer_block( - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=attention_mask, - ) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if use_cache: - cache_params.seqlen_offset += inputs_embeds.shape[1] - - hidden_states = self.norm_f(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) - - return Mamba2Output( - last_hidden_state=hidden_states, - cache_params=cache_params if use_cache else None, - hidden_states=all_hidden_states, - ) - - -@add_start_docstrings( - """ - The MAMBA2 Model transformer with a language modeling head on top (linear layer with weights not tied to the input - embeddings). - """, - MAMBA2_START_DOCSTRING, -) -class Mamba2ForCausalLM(Mamba2PreTrainedModel): - _tied_weights_keys = [] - - def __init__(self, config): - super().__init__(config) - self.backbone = Mamba2Model(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Initialize weights and apply final processing - self.post_init() - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def get_input_embeddings(self): - return self.backbone.get_input_embeddings() - - def set_input_embeddings(self, new_embeddings): - return self.backbone.set_input_embeddings(new_embeddings) - - def prepare_inputs_for_generation( - self, - input_ids, - inputs_embeds=None, - use_cache=None, - cache_params: Optional[Mamba2Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **kwargs, - ): - if inputs_embeds is not None: - past_len = inputs_embeds.shape[1] + input_ids.shape[1] - else: - past_len = input_ids.shape[1] - if use_cache: - # `cache_position` should have been initialized in `generate` - if cache_position is None: - raise ValueError( - "`cache_position` should not be None as it should have been initialized in " - "`model.generate`, you are responsible for passing in a valid `cache_position` if " - "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" - ) - # how do we detect that we are in decoding without cache? 
- if cache_position[0] > 0: - input_ids = input_ids[:, -1][..., None] - attention_mask = attention_mask[:, -1][..., None] - else: - # we initialize the `cache_position` to full size of `conv_states` at prefill stage - # considering padding will be applied when input length is shorter, and truncation - # will be applied when it is longer, so it will be equivalent to always have it match - # the length of `cache_params.conv_states`, which is `config.conv_kernel` - cache_position = torch.arange(0, past_len, device=input_ids.device) - # if the cache is not used, we also do have to extend the attention mask here - # TODO there is likely a cleverer way to do this - extended_mask = torch.ones( - attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device - ) - attention_mask = torch.cat([attention_mask, extended_mask], dim=1) - cache_params = None - - if attention_mask.shape[1] < past_len: - # we have to update manually the attention mask if - # we are in decoding without cache - # and we don't have position_ids here - # TODO but we should be able to use cache_position though at a later time - extended_mask = torch.ones( - attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device - ) - attention_mask = torch.cat([attention_mask, extended_mask], dim=1) - if inputs_embeds is not None and cache_params is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "attention_mask": attention_mask, - "cache_params": cache_params, - "use_cache": use_cache, - "cache_position": cache_position, - } - ) - return model_inputs - - @add_start_docstrings_to_model_forward(MAMBA2_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=Mamba2CausalLMOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - cache_params: Optional[Mamba2Cache] = None, - labels: Optional[torch.LongTensor] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **kwargs, # for now we need this for generation - ) -> Union[Tuple, Mamba2CausalLMOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - mamba2_outputs = self.backbone( - input_ids, - cache_params=cache_params, - inputs_embeds=inputs_embeds, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - use_cache=use_cache, - cache_position=cache_position, - attention_mask=attention_mask, - ) - hidden_states = mamba2_outputs[0] - - logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - if not return_dict: - output = (logits,) + mamba2_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return Mamba2CausalLMOutput( - loss=loss, - logits=logits, - cache_params=mamba2_outputs.cache_params, - hidden_states=mamba2_outputs.hidden_states, - ) diff --git a/llms/mlx_lm/models/mamba2-works-hella-slow.py b/llms/mlx_lm/models/mamba2-works-hella-slow.py deleted file mode 100644 index 2960d3d0..00000000 --- a/llms/mlx_lm/models/mamba2-works-hella-slow.py +++ /dev/null @@ -1,300 +0,0 @@ -import math -from dataclasses import dataclass, field -from typing import Tuple, Union -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs -from .cache import MambaCache - -@dataclass -class ModelArgs(BaseModelArgs): - num_heads: int - head_dim: int - vocab_size: int - hidden_size: int - state_size: int - num_hidden_layers: int - layer_norm_epsilon: float - expand: int - conv_kernel: int - n_groups: int - use_bias: bool - use_conv_bias: bool - initializer_range: float - residual_in_fp32: bool - time_step_min: float - time_step_max: float - time_step_floor: float - rescale_prenorm_residual: bool - rms_norm: bool - chunk_size: int - tie_word_embeddings: bool - use_cache: bool = True - time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) - time_step_rank: Union[int, str] = "auto" - model_type: str = "mamba2" - - def __post_init__(self): - if not hasattr(self, "intermediate_size"): - self.intermediate_size = int(self.expand * self.hidden_size) - if not hasattr(self, "head_dim"): - self.head_dim = self.hidden_size // self.num_heads - if self.time_step_rank == "auto": - self.time_step_rank = math.ceil(self.hidden_size / 16) - - -class MambaRMSNormGated(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = mx.ones((hidden_size,)) - self.variance_epsilon = eps - - def __call__(self, hidden_states, gate=None): - if gate is not None: - hidden_states = hidden_states * nn.silu(gate) - variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True) - hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states - - -def silu(x): - return x * mx.sigmoid(x) - -def ssd(x, A, B, C, chunk_size): - batch, seqlen, nheads, dim = x.shape - B = mx.expand_dims(B, axis=2) - C = mx.expand_dims(C, axis=2) - - state = mx.zeros((batch, nheads, dim, B.shape[-1])) 
- outputs = [] - - for i in range(0, seqlen, chunk_size): - chunk = slice(i, min(i + chunk_size, seqlen)) - dA = mx.exp(mx.expand_dims(A[chunk], axis=0)) - - x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim] - x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size] - B_chunk = B[:, chunk] # [batch, chunk_size, state_size] - dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size] - - state = state * mx.expand_dims(dA, axis=-1) + dBx - - C_chunk = C[:, chunk] # [batch, chunk_size, state_size] - y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size] - y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim] - outputs.append(y) - - return mx.concatenate(outputs, axis=1), state - - -class DepthWiseConv1d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.padding = padding - self.groups = groups if groups is not None else in_channels - - assert in_channels == out_channels, "In and out channels must be same for depthwise convolution" - assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution" - - self.weight = mx.random.normal((in_channels, 1, kernel_size)) - self.bias = mx.zeros((out_channels,)) if bias else None - - def __call__(self, x: mx.array, cache=None) -> mx.array: - B, L, C = x.shape - K = self.kernel_size - - assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}" - - if cache is not None: - # Access conv_state directly from cache[0] - if cache[0] is None: - cache[0] = mx.zeros((B, K-1, C)) - - x = mx.concatenate([cache[0], x], axis=1) - - outputs = [] - for c in range(C): - x_c = x[:, :, c] - x_c = mx.expand_dims(x_c, axis=1) - - w_c = self.weight[c] - if w_c.ndim == 2: - w_c = mx.expand_dims(w_c, axis=0) - elif w_c.ndim == 1: - w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0) - - y_c = mx.conv_general( - x_c, - w_c, - stride=1, - padding=0 - ) - if self.bias is not None: - y_c = y_c + self.bias[c] - - y_c = mx.squeeze(y_c, axis=1) - outputs.append(y_c) - - y = mx.stack(outputs, axis=-1) - - # Update cache directly using cache[0] - if cache is not None: - cache[0] = x[:, -K+1:, :] if x.shape[1] >= K else x - - return y - - -class Mamba2Block(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads - self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias) - - conv_dim = args.intermediate_size + 2 * args.state_size - self.conv1d = DepthWiseConv1d( - in_channels=conv_dim, - out_channels=conv_dim, - kernel_size=args.conv_kernel, - groups=conv_dim, - bias=args.use_conv_bias, - padding=args.conv_kernel - 1 - ) - - self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range - self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range - self.D = mx.random.normal((args.num_heads,)) * args.initializer_range - - self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon) - self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) - - if args.rescale_prenorm_residual: - layer_scale = math.sqrt(1.0 / args.num_hidden_layers) - self.out_proj.weight = self.out_proj.weight * layer_scale - - def __call__(self, u: mx.array, 
cache=None): - batch_size, seq_len, dimension = u.shape - assert seq_len == 1, "Input should be a single token" - - # Initialize cache states directly using indices - if cache[0] is None: # conv state - conv_dim = self.args.intermediate_size + 2 * self.args.state_size - cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim)) - - if cache[1] is None: # ssm state - cache[1] = mx.zeros(( - batch_size, - self.args.num_heads, - self.args.head_dim, - self.args.state_size - )) - - zxbcdt = self.in_proj(u) - - n_heads = self.args.num_heads - z = zxbcdt[:, :, :self.args.intermediate_size] - xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size] - dt = zxbcdt[:, :, -(n_heads):] - - dt = mx.reshape(dt, (batch_size, n_heads)) - dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max) - dt = mx.maximum(dt, self.args.time_step_floor) - - xBC = self.conv1d(xBC, cache=cache) - xBC = silu(xBC) - - x = xBC[:, :, :self.args.intermediate_size] - B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size] - C = xBC[:, :, -self.args.state_size:] - - x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim)) - x = mx.squeeze(x, axis=1) - B = mx.reshape(B, (batch_size, 1, self.args.state_size)) - B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size)) - B = mx.expand_dims(B, axis=2) - C = mx.reshape(C, (batch_size, 1, self.args.state_size)) - C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size)) - C = mx.expand_dims(C, axis=3) - - A = -mx.exp(self.A_log) - dA = mx.exp(dt * mx.expand_dims(A, 0)) - dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) - - x = mx.expand_dims(x, axis=3) - dBx = mx.matmul(x, B) - # Update ssm state directly using cache[1] - cache[1] = cache[1] * dA + dBx - - y = mx.matmul(cache[1], C) - y = mx.squeeze(y, axis=-1) - y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1) - y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim)) - y = self.norm(y + z) - - return self.out_proj(y) - - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.residual_in_fp32 = args.residual_in_fp32 - self.mixer = Mamba2Block(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, x: mx.array, cache): - if self.residual_in_fp32: - x = x.astype(mx.float32) - normed = self.norm(x) - output = self.mixer(normed, cache) - return output + x - -class Mamba2(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - - def __call__(self, x: mx.array, cache): - x = self.embeddings(x) - if cache is None: - cache = [None] * len(self.layers) - - hidden = x - for layer, c in zip(self.layers, cache): - hidden = layer(hidden, c) - return self.norm_f(hidden) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = Mamba2(args) - - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__(self, inputs: mx.array, cache=None): - hidden = self.backbone(inputs, cache) - - if self.args.tie_word_embeddings: - logits = self.backbone.embeddings.as_linear(hidden) - else: 
- logits = self.lm_head(hidden) - - return logits - - def make_cache(self): - return [MambaCache() for _ in range(len(self.layers))] - - @property - def layers(self): - return self.backbone.layers \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba23.py b/llms/mlx_lm/models/mamba23.py deleted file mode 100644 index efbe54a4..00000000 --- a/llms/mlx_lm/models/mamba23.py +++ /dev/null @@ -1,357 +0,0 @@ -import math -from dataclasses import dataclass, field -from typing import Optional, Tuple, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs -from .cache import Mamba2Cache - -@dataclass -class ModelArgs(BaseModelArgs): - num_heads: int - head_dim: int - vocab_size: int - hidden_size: int - state_size: int - num_hidden_layers: int - layer_norm_epsilon: float - expand: int - conv_kernel: int - n_groups: int - use_bias: bool - use_conv_bias: bool - initializer_range: float - residual_in_fp32: bool - time_step_min: float - time_step_max: float - time_step_floor: float - rescale_prenorm_residual: bool - use_cache: bool - rms_norm: bool - chunk_size: int - tie_word_embeddings: bool - intermediate_size: int = None - time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) - time_step_rank: Union[int, str] = "auto" - model_type: str = "mamba2" - - def __post_init__(self): - self.intermediate_size = int(self.expand * self.hidden_size) # E*D = ED - - if not hasattr(self, "head_dim"): - self.head_dim = self.hidden_size // self.num_heads - if self.time_step_rank == "auto": - self.time_step_rank = math.ceil(self.hidden_size / 16) - - -class MambaRMSNormGated(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = mx.ones(hidden_size) - self.variance_epsilon = eps - - def forward(self, hidden_states, gate=None): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(mx.float32) - - if gate is not None: - hidden_states = hidden_states * nn.functional.silu(gate.to(mx.float32)) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * math.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -class Mamba2Mixer(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - # Model dimensions - self.hidden_size = args.hidden_size - self.num_heads = args.num_heads - self.head_dim = args.head_dim - self.ssm_state_size = args.state_size - self.n_groups = args.n_groups - self.intermediate_size = int(args.expand * args.hidden_size) - - # Convolution parameters - self.conv_kernel = args.conv_kernel - self.use_conv_bias = args.use_conv_bias - - # Time step parameters - self.time_step_rank = int(args.time_step_rank) - self.time_step_min = args.time_step_min - self.time_step_max = args.time_step_max - - # Processing parameters - self.chunk_size = args.chunk_size - self.layer_norm_epsilon = args.layer_norm_epsilon - - # Calculate dimensions - self.conv_dim = (self.intermediate_size + - 2 * self.n_groups * self.ssm_state_size) - projection_size = (self.intermediate_size + - self.conv_dim + - self.num_heads) - - # Initialize layers - self.in_proj = nn.Linear( - self.hidden_size, - projection_size, - bias=args.use_bias - ) - - self.conv1d = nn.Conv1d( - in_channels=self.conv_dim, - out_channels=self.conv_dim, - kernel_size=self.conv_kernel, - groups=self.conv_dim, - padding=self.conv_kernel - 1, - bias=self.use_conv_bias - ) - - # Initialize parameters - self.dt_bias = mx.ones(self.num_heads) - A = 
mx.arange(1, self.num_heads + 1) - self.A_log = mx.log(A) - self.D = mx.ones(self.num_heads) - - # Output layers - self.norm = MambaRMSNormGated( - self.intermediate_size, - eps=self.layer_norm_epsilon - ) - self.out_proj = nn.Linear( - self.intermediate_size, - self.hidden_size, - bias=args.use_bias - ) - - def reshape_into_chunks(self, tensor, pad_size, chunk_size): - if pad_size > 0: - pad_shape = list(tensor.shape) - pad_shape[1] = pad_size - padding = mx.zeros(pad_shape, dtype=tensor.dtype) - tensor = mx.concatenate([tensor, padding], axis=1) - - chunk_shape = list(tensor.shape) - chunk_shape[1] = -1 - chunk_shape.insert(2, chunk_size) - return tensor.reshape(chunk_shape) - - def segment_sum(self, x): - return mx.cumsum(x, axis=-1) - - def process_single_token(self, hidden_states, B, C, dt, cache): - batch_size = hidden_states.shape[0] - - # Process convolution state - if cache is not None and cache.conv_states is not None: - conv_state = cache.conv_states - # Roll the conv state and update the last position - conv_state = mx.roll(conv_state, shift=-1, axis=-1) - # Create new conv state with updated last position - new_conv_state = mx.array(conv_state) - new_conv_state = new_conv_state.at[:, :, -1].add(hidden_states) - conv_state = new_conv_state - - # Compute convolution - conv_out = mx.sum(conv_state * self.conv1d.weight[:, 0, :], axis=-1) - if self.use_conv_bias: - conv_out = conv_out + self.conv1d.bias - - # Apply SiLU activation - conv_out = mx.sigmoid(conv_out) * conv_out - - else: - # Initialize new cache and process convolution - conv_state = mx.zeros((batch_size, self.conv_dim, self.conv_kernel - 1)) - - # Reshape hidden_states for conv1d - hidden_states_reshaped = hidden_states.reshape(batch_size, -1, 1) - conv_out = self.conv1d(hidden_states_reshaped) - conv_out = mx.squeeze(conv_out, axis=-1) # Remove the last dimension - conv_out = mx.sigmoid(conv_out) * conv_out - - # Process SSM - dt = mx.clip( - nn.softplus(dt + self.dt_bias), - self.time_step_min, - self.time_step_max - ) - - A = -mx.exp(self.A_log) - dA = mx.exp(dt[:, None] * A[None, :]) - - if cache is not None and cache.ssm_states is not None: - ssm_state = cache.ssm_states - else: - ssm_state = mx.zeros( - (batch_size, self.num_heads, self.head_dim, self.ssm_state_size) - ) - - # Compute SSM updates - dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, hidden_states) - next_state = ssm_state * dA[:, :, None, None] + dBx - y = mx.einsum('bhds,bhs->bhd', next_state, C) - - # Add skip connection - y = y + hidden_states * self.D[None, :, None] - - return y, conv_state, next_state - - def process_long_sequence(self, hidden_states, B, C, dt, ssm_state): - batch_size, seq_len = hidden_states.shape[:2] - pad_size = self.chunk_size - (seq_len % self.chunk_size) - - # Reshape into chunks - x_chunks = self.reshape_into_chunks(hidden_states, pad_size, self.chunk_size) - B_chunks = self.reshape_into_chunks(B, pad_size, self.chunk_size) - C_chunks = self.reshape_into_chunks(C, pad_size, self.chunk_size) - - # Process time steps - dt = nn.softplus(dt + self.dt_bias) - dt = mx.clip(dt, self.time_step_min) - - # Prepare matrices - A = -mx.exp(self.A_log) - A = A * dt[:, None] - - # Process chunks - A_chunks = self.reshape_into_chunks( - mx.broadcast_to(A, (batch_size, seq_len + pad_size, self.num_heads)), - pad_size, - self.chunk_size - ) - - # Compute cumulative sums - A_cumsum = mx.cumsum(A_chunks, axis=-1) - L = mx.exp(self.segment_sum(A_chunks)) - - # Process diagonal blocks - G = mx.einsum('...lhn,...shn->...lsh', C_chunks, B_chunks) 
- M = G * L[..., None, :] - Y_diag = mx.einsum('...lsh,...sh->...lh', M, x_chunks) - - # Process off-diagonal blocks - decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum) - B_decay = B_chunks * decay_states[..., None] - states = mx.einsum('...shn,...sh->...hn', B_decay, x_chunks) - - # Combine results - y = Y_diag + states - - # Remove padding if necessary - if pad_size > 0: - y = y[:, :seq_len] - - return y, ssm_state - - def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array: - batch_size, seq_len, _ = x.shape - - # Project input - projected_states = self.in_proj(x) - - # Calculate d_mlp based on projection size - d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * - self.n_groups * self.ssm_state_size - self.num_heads) // 2 - - # Split projections with corrected dimensions - splits = [ - d_mlp, # z0 - d_mlp, # x0 - self.intermediate_size, # gate - self.conv_dim, # hidden_states - self.num_heads # dt - ] - - z0, x0, x1, gate, hidden_states, dt = projected_states.split(splits, axis=-1) - - # Split hidden states into components - x_conv, BC = mx.split(hidden_states, [self.intermediate_size], axis=-1) - B, C = mx.split(BC, [self.n_groups * self.ssm_state_size], axis=-1) - - # Process based on sequence length - if seq_len > 1 and cache is None: - y, next_state = self.process_long_sequence( - x_conv, B, C, dt, - mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size)) - ) - else: - # Reshape for single token processing - x_conv = x_conv.reshape(batch_size, -1, self.head_dim) - B = B.reshape(batch_size, self.num_heads, -1) - C = C.reshape(batch_size, self.num_heads, -1) - y, conv_state, next_state = self.process_single_token(x_conv, B, C, dt, cache) - - if cache is not None: - cache.update(conv_state, next_state) - - # Apply normalization and final projection - y = self.norm(y) * gate - return self.out_proj(y) - - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = Mamba2Mixer(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array: - return self.mixer(self.norm(x), cache) + x - - -class Mamba2Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - - def __call__(self, x: mx.array, cache=None) -> mx.array: - x = self.embeddings(x) - - if cache is None: - cache = [None] * len(self.layers) - - for layer, layer_cache in zip(self.layers, cache): - x = layer(x, layer_cache) - - return self.norm_f(x) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.backbone = Mamba2Model(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__(self, inputs: mx.array, cache=None) -> mx.array: - B, T = inputs.shape - - x = self.backbone(inputs, cache) - - if self.args.tie_word_embeddings: - logits = self.backbone.embeddings.as_linear(x) - else: - logits = self.lm_head(x) - - return logits - - def make_cache(self, batch_size=1): - return [Mamba2Cache() for _ in range(len(self.backbone.layers))] - - def sanitize(self, weights): - for k, v in weights.items(): - if "conv1d.weight" in k and v.ndim == 3: - weights[k] = v.moveaxis(2, 
1) - return weights - - @property - def layers(self): - return self.backbone.layers diff --git a/llms/mlx_lm/models/mamba24.py b/llms/mlx_lm/models/mamba24.py deleted file mode 100644 index b1ada1df..00000000 --- a/llms/mlx_lm/models/mamba24.py +++ /dev/null @@ -1,430 +0,0 @@ -import math -from dataclasses import dataclass, field -from typing import Tuple, Union -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs -from .cache import Mamba2Cache - -@dataclass -class ModelArgs(BaseModelArgs): - num_heads: int - head_dim: int - vocab_size: int - hidden_size: int - state_size: int - num_hidden_layers: int - layer_norm_epsilon: float - expand: int - conv_kernel: int - n_groups: int - use_bias: bool - use_conv_bias: bool - initializer_range: float - residual_in_fp32: bool - time_step_min: float - time_step_max: float - time_step_floor: float - rescale_prenorm_residual: bool - rms_norm: bool - chunk_size: int - tie_word_embeddings: bool - use_cache: bool = True - time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) - time_step_rank: Union[int, str] = "auto" - model_type: str = "mamba2" - - def __post_init__(self): - if not hasattr(self, "intermediate_size"): - self.intermediate_size = int(self.expand * self.hidden_size) - if not hasattr(self, "head_dim"): - self.head_dim = self.hidden_size // self.num_heads - if self.time_step_rank == "auto": - self.time_step_rank = math.ceil(self.hidden_size / 16) - - -class MambaRMSNormGated(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = mx.ones((hidden_size,)) - self.variance_epsilon = eps - - def __call__(self, hidden_states, gate=None): - if gate is not None: - hidden_states = hidden_states * nn.silu(gate) - variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True) - hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states - - -def pad_tensor_by_size(input_tensor: mx.array, pad_size: int): - """ - Padding x tensor with `pad_size` on the seq_len dim (dim=1) - - Assumes that we only have tensors of either size 4 or 3 - """ - pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) - - return mx.pad(input_tensor, pad_shape, mode="constant", value=0) - - -def reshape_into_chunks(input_tensor, pad_size, chunk_size): - """ - Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and - simultaneously splitting it into chunk sequences. - - Assumes that we only have tensors of either size 4 or 3 - """ - # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] - input_tensor = pad_tensor_by_size(input_tensor, pad_size) - - if len(input_tensor.shape) == 3: - # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] - return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) - else: - # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] - return input_tensor.reshape( - input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] - ) - - -def segment_sum(input_tensor): - """ - More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. - """ - chunk_size = input_tensor.size(-1) - # 1. 
expand input tensor to have an additional dimension and repeat along that dimension - # [..., chunk_size] -> [..., chunk_size, chunk_size] - input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) - # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag - mask = mx.tril(mx.ones(chunk_size, chunk_size, device=input_tensor.device), diagonal=-1) - input_tensor = input_tensor.masked_fill(~mask, 0) - # 3. compute actual cumsum - tensor_segsum = mx.cumsum(input_tensor, dim=-2) - - # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) - mask = mx.tril(mx.ones(chunk_size, chunk_size, device=input_tensor.device), diagonal=0) - tensor_segsum = tensor_segsum.masked_fill(~mask, -mx.inf) - return tensor_segsum - - -class Mamba2Block(nn.Module): - def __init__(self, args: ModelArgs, layer_idx: int): - super().__init__() - self.layer_idx = layer_idx - self.args = args - - self.hidden_size = args.hidden_size - self.num_heads = args.num_heads - self.head_dim = args.head_dim - self.state_size = args.state_size - self.n_groups = args.n_groups - self.conv_kernel = args.conv_kernel - self.intermediate_size = int(args.expand * args.hidden_size) - self.time_step_rank = int(args.time_step_rank) - self.time_step_min = args.time_step_min - self.time_step_max = args.time_step_max - self.chunk_size = args.chunk_size - - - # Convolution dimension includes both intermediate sizes - self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size - self.conv1d = nn.Conv1d( - in_channels=self.conv_dim, - out_channels=self.conv_dim, - bias=args.use_conv_bias, - kernel_size=args.conv_kernel, - groups=self.conv_dim, - padding=args.conv_kernel - 1 - ) - - # Compute input projection dimension - projection_size = self.intermediate_size + self.conv_dim + self.num_heads - self.in_proj = nn.Linear(args.hidden_size, projection_size, bias=args.use_bias) - - self.dt_bias = mx.ones(self.num_heads) - A = mx.arange(1, self.num_heads + 1) - self.A_log = mx.log(A) - self.D = mx.ones(self.num_heads) - - self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon) - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - - def __call__(self, input_states: mx.array, cache): - batch_size, seq_len, _ = input_states.shape - - # Gated MLP's linear projection - projected_states = self.in_proj(input_states) # [1, 1, projection_size] - - d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - - 2 * self.n_groups * self.state_size - self.num_heads) // 2 - - # Split projected states - *_, gate, hidden_states, dt = projected_states.split( - [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], - axis=-1 - ) - # hidden_states shape: [1, 1, conv_dim] - - # Get SSM state from cache - ssm_state = cache.ssm_states[self.layer_idx] - - if cache.seqlen_offset > 0: - # Handle cached generation case - conv_state = cache.conv_states[self.layer_idx] # [batch, conv_dim, conv_kernel] - conv_state = mx.roll(conv_state, shifts=-1, axis=-1) - - # Handle batched generation - states are copied through - # Properly reshape hidden_states for the conv_state update - conv_state = conv_state.at[:, :, -1].set(hidden_states[:, 0, :]) - cache.conv_states[self.layer_idx] = conv_state - - # Compute convolution output - hidden_states = mx.sum(conv_state * self.conv1d.weight[:, 0, :], axis=-1) - if self.args.use_conv_bias: - hidden_states += self.conv1d.bias - 
hidden_states = nn.silu(hidden_states)[:, None, ...] # [batch, 1, conv_dim] : decoding - - else: - # Handle normal forward pass - # Properly transpose while preserving the sequence dimension - hidden_states = hidden_states.transpose(0, 2, 1) # [1, conv_dim, 1] - - # Pad the convolution state - padding_size = self.conv_kernel - 1 - conv_state = mx.pad( - hidden_states, - ((0, 0), (0, 0), (padding_size, 0)) - ) - - # Store in cache - cache.conv_states[self.layer_idx] = conv_state - - # Apply convolution with proper padding - hidden_states = self.conv1d(hidden_states) # [1, conv_dim, 1] - hidden_states = hidden_states.transpose(0, 2, 1) # [1, 1, conv_dim] - hidden_states = nn.silu(hidden_states) - - # Split hidden states for SSM computation - hidden_states, B, C = mx.split( - hidden_states, - [self.intermediate_size, self.n_groups * self.state_size, self.n_groups * self.state_size], - axis=-1 - ) - - # Compute A matrix - A = -mx.exp(self.A_log) - - if cache is not None and cache.seqlen_offset > 0: - # Note: there is no need to pad parameter matrices here, as there is just one new token - # for batched generation - dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] - dt = dt.transpose(0, 2, 1).expand(batch_size, dt.shape[-1], self.head_dim) - # [num_heads] -> [num_heads, head_dim] - dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) - - dt = nn.softplus(dt + dt_bias) - dt = mx.clamp(dt, self.time_step_min) #, self.time_step_max) - A = A[..., None, None].expand(self.num_heads, self.head_dim, self.state_size) - # [bsz, num_heads, head_dim, state_size] - dA = mx.exp(dt[..., None] * A) - - # Discretize B - # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> - # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] - B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] - B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() - B = B.reshape(batch_size, -1, B.shape[-1]) - # [bsz, num_heads, head_dim, state_size] - dB = dt[..., None] * B[..., None, :] - - # Discretize x into dB - # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] - hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) - dBx = dB * hidden_states[..., None] - - # State calculation - cache.ssm_states[self.layer_idx].copy_( - cache.ssm_states[self.layer_idx] * dA + dBx - ) - # Subsequent output - # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] - C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] - C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() - C = C.reshape(batch_size, -1, C.shape[-1]) - # [bsz, num_heads, head_dim] - - ssm_states = cache.ssm_states[self.layer_idx] # Shape: [b, h, d, n] - # Reshape ssm_states to merge the first two dimensions - ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.state_size) # Shape: [b*h, d, n] - C_reshaped = C.view(batch_size * self.num_heads, self.state_size, 1) # Shape: [b*h, n, 1] - y = ssm_states_reshaped @ C_reshaped - y = y.view(batch_size, self.num_heads, self.head_dim) - - # D skip connection - # [num_heads] -> [num_heads, head_dim] - D = self.D[..., None].expand(self.D.shape[0], self.head_dim) - y = (y + hidden_states * D) - - # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] - y = y.reshape(batch_size, -1)[:, None, ...] 
- else: - # begin ssd naive implementation without einsums - dt = nn.functional.softplus(dt + self.dt_bias) - dt = mx.clamp(dt, self.time_step_min) - hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim) - B = B.reshape(batch_size, seq_len, -1, self.state_size) - C = C.reshape(batch_size, seq_len, -1, self.state_size) - B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) - C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) - pad_size = self.chunk_size - (seq_len % self.chunk_size) - - D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) - - # Discretize x and A - hidden_states = hidden_states * dt[..., None] - A = A * dt - - # Rearrange into blocks/chunks - hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] - - - # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] - A = A.permute(0, 3, 1, 2) - A_cumsum = mx.cumsum(A, dim=-1) - - # 1. Compute the output for each intra-chunk (diagonal blocks) - # This is the analog of a causal mask - L = mx.exp(segment_sum(A)) - - # First, contraction of C and B to get G (attention-weights like) - G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n) - G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) - - - # Step 2: Compute M, equivalent to applying attention mask to weights - M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] - M = M_intermediate.sum(dim=-1) - - # Step 3: Compute Y_diag (apply to values) - Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3) - - # (right term of low-rank factorization of off-diagonal blocks; B terms) - - decay_states = mx.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) - B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] - # permute back B * decay states - states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) - if cache is not None and cache.seqlen_offset > 0: - previous_states = cache.ssm_states[self.layer_idx][:, None, ...] - else: - previous_states = mx.zeros_like(states[:, :1]) - states = mx.concat([previous_states, states], dim=1) - decay_chunk = mx.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) - - states_permuted = states.permute(0, 2, 1, 3, 4) - result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2) - new_states = result.permute(0, 2, 1, 3, 4) - states, ssm_state = new_states[:, :-1], new_states[:, -1] - - # Compute state -> output conversion per chunk - # (left term of low-rank factorization of off-diagonal blocks; C terms) - state_decay_out = mx.exp(A_cumsum) - # compute Yoff - C_times_states = (C[..., None, :] * states[:, :, None, ...]) - state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) - Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) - # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) - - y = Y_diag + Y_off - # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] - y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) - - y = y + D_residual - # Cutting off padded chunks - if pad_size > 0: - y = y[:, :seq_len, :, :] - y = y.reshape(batch_size, seq_len, -1) - - if ssm_state is not None and cache is not None: - cache.ssm_states[self.layer_idx] = ssm_state - - scan_output = self.norm(y, gate) - # end ssd naive - - # 4. 
Final linear projection - return self.out_proj(scan_output) # [batch, seq_len, hidden_size] - - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs, layer_idx: int): - super().__init__() - self.residual_in_fp32 = args.residual_in_fp32 - self.mixer = Mamba2Block(args, layer_idx) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, x: mx.array, cache): - return self.mixer(self.norm(x), cache) + x - - -class Mamba2(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args, idx) for idx in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - - def __call__(self, x: mx.array, cache): - x = self.embeddings(x) - if cache is None: - cache = [None] * len(self.layers) - for layer, c in zip(self.layers, cache): - x = layer(x, c) - return self.norm_f(x) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - - self.backbone = Mamba2(args) - - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__(self, inputs: mx.array, cache=None): - B, T = inputs.shape - - x = self.backbone(inputs, cache) - - if self.args.tie_word_embeddings: - logits = self.backbone.embeddings.as_linear(x) - else: - logits = self.lm_head(x) - - return logits - - def make_cache(self, batch_size=1): - return [Mamba2Cache( - batch_size, - self.args.intermediate_size, - self.args.conv_kernel, - self.args.head_dim, - self.args.num_heads, - self.args.n_groups, - self.args.state_size - ) for _ in range(len(self.layers))] - - def sanitize(self, weights): - for k, v in weights.items(): - if "conv1d.weight" in k and v.ndim == 3: - weights[k] = v.moveaxis(2, 1) - return weights - - @property - def layers(self): - return self.backbone.layers \ No newline at end of file diff --git a/llms/mlx_lm/models/s.py b/llms/mlx_lm/models/s.py deleted file mode 100644 index e305fae8..00000000 --- a/llms/mlx_lm/models/s.py +++ /dev/null @@ -1,343 +0,0 @@ -import math -from dataclasses import dataclass, field -from typing import Tuple, Union -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs -from .cache import Mamba2Cache - -@dataclass -class ModelArgs(BaseModelArgs): - num_heads: int - head_dim: int - vocab_size: int - hidden_size: int - state_size: int - num_hidden_layers: int - layer_norm_epsilon: float - expand: int - conv_kernel: int - n_groups: int - use_bias: bool - use_conv_bias: bool - initializer_range: float - residual_in_fp32: bool - time_step_min: float - time_step_max: float - time_step_floor: float - rescale_prenorm_residual: bool - rms_norm: bool - chunk_size: int - tie_word_embeddings: bool - use_cache: bool = True - time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) - time_step_rank: Union[int, str] = "auto" - model_type: str = "mamba2" - - def __post_init__(self): - if not hasattr(self, "intermediate_size"): - self.intermediate_size = int(self.expand * self.hidden_size) - if not hasattr(self, "head_dim"): - self.head_dim = self.hidden_size // self.num_heads - if self.time_step_rank == "auto": - self.time_step_rank = math.ceil(self.hidden_size / 16) - - -class MambaRMSNormGated(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = mx.ones((hidden_size,)) - 
self.variance_epsilon = eps - - def __call__(self, hidden_states, gate=None): - if gate is not None: - hidden_states = hidden_states * nn.silu(gate) - variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True) - hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states - - -def silu(x): - return x * mx.sigmoid(x) - -def ssd(x, A, B, C, chunk_size): - # Replace einsum operations with explicit reshape and matrix multiply - batch, seqlen, nheads, dim = x.shape - B = mx.expand_dims(B, axis=2) - C = mx.expand_dims(C, axis=2) - - state = mx.zeros((batch, nheads, dim, B.shape[-1])) - outputs = [] - - for i in range(0, seqlen, chunk_size): - chunk = slice(i, min(i + chunk_size, seqlen)) - dA = mx.exp(mx.expand_dims(A[chunk], axis=0)) - - # Replace einsum with explicit operations - x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim] - x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size] - B_chunk = B[:, chunk] # [batch, chunk_size, state_size] - dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size] - - state = state * mx.expand_dims(dA, axis=-1) + dBx - - # Replace einsum with explicit operations - C_chunk = C[:, chunk] # [batch, chunk_size, state_size] - y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size] - y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim] - outputs.append(y) - - return mx.concatenate(outputs, axis=1), state - - -class DepthWiseConv1d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.padding = padding - self.groups = groups if groups is not None else in_channels - - assert in_channels == out_channels, "In and out channels must be same for depthwise convolution" - assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution" - - # Initialize weight with correct shape [C_out, 1, kernel_size] - self.weight = mx.random.normal((out_channels, 1, kernel_size)) - self.bias = mx.zeros((out_channels,)) if bias else None - - def __call__(self, x: mx.array, cache=None) -> mx.array: - B, L, C = x.shape - K = self.kernel_size - - assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}" - - # Handle caching for sequential processing - if cache is not None and cache.conv_states[0] is not None: - if isinstance(cache.conv_states[0], type(None)): - cache.conv_states[0] = mx.zeros((B, K-1, C)) - x = mx.concatenate([cache.conv_states[0], x], axis=1) - - # Process each channel independently - outputs = [] - for c in range(C): - # Extract and reshape the channel - x_c = x[:, :, c] # [B, L] - x_c = mx.expand_dims(x_c, axis=1) # [B, 1, L] - - # Get weight for this channel - already in correct shape [1, 1, K] - w_c = mx.expand_dims(self.weight[c], axis=0) # Ensure [1, 1, K] - - # Apply convolution - y_c = mx.conv_general( - x_c, - w_c, - stride=1, - padding=self.padding - ) - - if self.bias is not None: - y_c = y_c + self.bias[c] - - outputs.append(mx.squeeze(y_c, axis=1)) - - y = mx.stack(outputs, axis=-1) - - # Update cache - if cache is not None: - cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x - - return y - - -class Mamba2Block(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - self.chunk_size = args.chunk_size - - d_in_proj 
= 2 * args.intermediate_size + 2 * args.state_size + args.num_heads - self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias) - - self.conv_dim = args.intermediate_size + 2 * args.state_size - self.conv1d = DepthWiseConv1d( - in_channels=self.conv_dim, - out_channels=self.conv_dim, - kernel_size=args.conv_kernel, - groups=self.conv_dim, - bias=args.use_conv_bias, - padding=args.conv_kernel - 1 - ) - - self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range - self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range - self.D = mx.random.normal((args.num_heads,)) * args.initializer_range - - self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon) - self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) - - if args.rescale_prenorm_residual: - layer_scale = math.sqrt(1.0 / args.num_hidden_layers) - self.out_proj.weight = self.out_proj.weight * layer_scale - - def __call__(self, u: mx.array, cache=None): - # Expect input shape: [batch_size, 1, hidden_size] - batch_size, seq_len, _ = u.shape - pad_size = self.chunk_size - (seq_len % self.chunk_size) - - # Initialize states if needed - if cache.conv_states[0] is None: - cache.conv_states[0] = mx.zeros(( - batch_size, - self.args.conv_kernel - 1, - self.conv_dim - )) - - if cache.ssm_states[0] is None: - cache.ssm_states[0] = mx.zeros(( - batch_size, - self.args.num_heads, - self.args.head_dim, - self.args.state_size - )) - - # Project input - zxbcdt = self.in_proj(u) - - # Split projections - z = zxbcdt[:, :, :self.args.intermediate_size] - xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size] - dt = zxbcdt[:, :, -(self.args.num_heads):] - - # Process delta time - dt = mx.reshape(dt, (batch_size, seq_len, self.args.num_heads)) - dt = mx.squeeze(dt, axis=0) # Remove sequence dimension for single token - dt = mx.clip( - nn.softplus(dt + self.dt_bias), - self.args.time_step_min, - self.args.time_step_max - ) - dt = mx.maximum(dt, self.args.time_step_floor) - - # Convolution step - xBC = self.conv1d(xBC, cache=cache) - xBC = silu(xBC) - - # Split conv output - x = xBC[:, :, :self.args.intermediate_size] - B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size] - C = xBC[:, :, -self.args.state_size:] - - # Reshape for SSM - x = mx.reshape(x, (batch_size, 1, self.args.num_heads, self.args.head_dim)) - x = mx.squeeze(x, axis=1) - - B = mx.reshape(B, (batch_size, 1, self.args.state_size)) - B = mx.broadcast_to(B, (batch_size, self.args.num_heads, self.args.state_size)) - B = mx.expand_dims(B, axis=2) - - C = mx.reshape(C, (batch_size, 1, self.args.state_size)) - C = mx.broadcast_to(C, (batch_size, self.args.num_heads, self.args.state_size)) - C = mx.expand_dims(C, axis=3) - - # SSM state update - A = -mx.exp(self.A_log) - dA = mx.exp(dt * mx.expand_dims(A, 0)) - dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) - - x = mx.expand_dims(x, axis=3) - dBx = mx.matmul(x, B) - - cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx - - # Output computation - y = mx.matmul(cache.ssm_states[0], C) - y = mx.squeeze(y, axis=-1) - - # y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1) - if pad_size > 0: - y = y[:, :seq_len, :, :] - - # Final reshape and projections - y = mx.reshape(y, (batch_size, 1, self.args.num_heads * self.args.head_dim)) - y = self.norm(y + z) - - return self.out_proj(y) - - -class ResidualBlock(nn.Module): - def 
__init__(self, args: ModelArgs): - super().__init__() - self.residual_in_fp32 = args.residual_in_fp32 - - self.mixer = Mamba2Block(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, x: mx.array, cache): - if self.residual_in_fp32: - x = x.astype(mx.float32) - return self.mixer(self.norm(x), cache) + x - - -class Mamba2(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - - def __call__(self, x: mx.array, cache): - x = self.embeddings(x) - if cache is None: - cache = [None] * len(self.layers) - for layer, c in zip(self.layers, cache): - x = layer(x, c) - return self.norm_f(x) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - - self.backbone = Mamba2(args) - - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__(self, inputs: mx.array, cache=None): - B, T = inputs.shape - - x = self.backbone(inputs, cache) - - if self.args.tie_word_embeddings: - logits = self.backbone.embeddings.as_linear(x) - else: - logits = self.lm_head(x) - - return logits - - def make_cache(self, batch_size=1): - return [Mamba2Cache(batch_size, self.args.conv_kernel) for _ in range(len(self.layers))] - - def sanitize(self, weights): - sanitized = {} - for k, v in weights.items(): - if "conv1d.weight" in k: - # Ensure weights are in correct shape (channels, 1, kernel_size) - if v.ndim == 2: - v = mx.expand_dims(v, axis=1) - elif v.ndim == 1: - v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0) - sanitized[k] = v - else: - sanitized[k] = v - return sanitized - - @property - def layers(self): - return self.backbone.layers
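
Every prototype removed above (the chunked "hella-slow" scan, mamba23, mamba24, s.py) ultimately reduces to the same single-token Mamba-2 recurrence: state <- state * exp(dt * A) + dt * B (x) x, then y = state @ C + D * x. As a compact reference for what these files were iterating toward, here is a minimal, self-contained sketch of that update in MLX. It is not part of any of the deleted files: the function name mamba2_ssm_step and the toy sizes in the smoke test are illustrative only, and it assumes the same conventions the deleted code uses (A stored as A_log, dt already passed through softplus, a per-head D skip connection).

import mlx.core as mx
import mlx.nn as nn


def mamba2_ssm_step(x, dt, A_log, B, C, D, ssm_state):
    """One recurrent Mamba-2 step for a single token.

    Shapes (batch b, heads h, head_dim d, state n):
        x:         (b, h, d)   per-head input after depthwise conv + SiLU
        dt:        (b, h)      softplus-activated time step
        A_log:     (h,)        log of the (negated) state decay
        B, C:      (b, h, n)   input / output projections of the state space
        D:         (h,)        skip connection
        ssm_state: (b, h, d, n)

    Returns (y, new_state) with y of shape (b, h, d).
    """
    A = -mx.exp(A_log)                                  # (h,)
    dA = mx.exp(dt * A)[..., None, None]                # (b, h, 1, 1) state decay
    # dt * B outer-product x -> per-head rank-1 state update, (b, h, d, n)
    dBx = mx.expand_dims(x, -1) * mx.expand_dims(dt[..., None] * B, -2)
    new_state = ssm_state * dA + dBx                    # (b, h, d, n)
    # Read out the state through C: (b, h, d, n) @ (b, h, n, 1) -> (b, h, d)
    y = mx.squeeze(mx.matmul(new_state, mx.expand_dims(C, -1)), -1)
    y = y + x * D[None, :, None]                        # D skip connection
    return y, new_state


# Tiny smoke test with made-up sizes (b=1, h=2, d=4, n=8).
b, h, d, n = 1, 2, 4, 8
x = mx.random.normal((b, h, d))
dt = nn.softplus(mx.random.normal((b, h)))
B = mx.random.normal((b, h, n))
C = mx.random.normal((b, h, n))
A_log = mx.log(mx.arange(1, h + 1).astype(mx.float32))
D = mx.ones((h,))
state = mx.zeros((b, h, d, n))
y, state = mamba2_ssm_step(x, dt, A_log, B, C, D, state)
print(y.shape, state.shape)  # (1, 2, 4) (1, 2, 4, 8)

The per-layer caches in the deleted files (conv_states / ssm_states, or MambaCache slots) hold exactly the ssm_state tensor threaded through this function plus the rolling window for the depthwise convolution; the chunked ssd / process_long_sequence paths are block-parallel rewrites of the same recurrence for prompt processing.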