From e22b2dbf27050f9c798c906786a1412aba9d7e29 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Thu, 21 Nov 2024 22:01:28 +0100 Subject: [PATCH] Fixed streaming generation and got rid of generating gibberish, but is still a litle slow: 0.222 tokens-per-sec --- llms/mlx_lm/models/cache.py | 26 +- llms/mlx_lm/models/mamba2-prch-minimal.py | 772 +++++++++--------- .../models/mamba2-working_but_giberish.py | 319 ++++++++ llms/mlx_lm/models/mamba2-works-hella-slow.py | 187 ++--- llms/mlx_lm/models/mamba2.py | 279 +++---- llms/mlx_lm/models/mamba22.py | 316 ------- 6 files changed, 884 insertions(+), 1015 deletions(-) create mode 100644 llms/mlx_lm/models/mamba2-working_but_giberish.py delete mode 100644 llms/mlx_lm/models/mamba22.py diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index 37b414da..32a07af7 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -419,34 +419,20 @@ class RotatingKVCache(_BaseCache): raise NotImplementedError("RotatingKVCache Quantization NYI") -class MambaCache: +class MambaCache(_BaseCache): def __init__(self): - # [conv_state, ssm_state] self.cache = [None, None] - self.offset = 0 # Sliding window caching - + def __setitem__(self, idx, value): self.cache[idx] = value - + def __getitem__(self, idx): return self.cache[idx] - + @property def state(self): return self.cache - + @state.setter def state(self, v): - self.cache = v - - @property - def conv_states(self): - return [self.cache[0]] - - @property - def ssm_states(self): - return [self.cache[1]] - - def reset(self): - self.cache = [None, None] - self.offset = 0 + self.cache = v \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2-prch-minimal.py b/llms/mlx_lm/models/mamba2-prch-minimal.py index 52d27f00..f988a825 100644 --- a/llms/mlx_lm/models/mamba2-prch-minimal.py +++ b/llms/mlx_lm/models/mamba2-prch-minimal.py @@ -1,449 +1,437 @@ -# coding=utf-8 -# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch MAMBA2 model.""" +""" +mamba2-minimal +============== -import math +A minimal, single-file implementation of the Mamba-2 model in PyTorch. + +> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality** +> Authors: Tri Dao, Albert Gu +> Paper: https://arxiv.org/abs/2405.21060 +""" + +import json from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Iterable, NamedTuple, TypeAlias, cast import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import LongTensor, Tensor, nn -logger = logging.get_logger(__name__) +Device: TypeAlias = str | torch.device | None -def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): - """ - Padding x tensor with `pad_size` on the seq_len dim (dim=1) +@dataclass +class Mamba2Config: + d_model: int # model dimension (D) + n_layer: int = 24 # number of Mamba-2 layers in the language model + d_state: int = 128 # state dimension (N) + d_conv: int = 4 # convolution kernel size + expand: int = 2 # expansion factor (E) + headdim: int = 64 # head dimension (P) + chunk_size: int = 64 # matrix partition size (Q) + vocab_size: int = 50277 + pad_vocab_size_multiple: int = 16 - Assumes that we only have tensors of either size 4 or 3 - """ - pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) - - return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + def __post_init__(self): + self.d_inner = self.expand * self.d_model + assert self.d_inner % self.headdim == 0 + self.nheads = self.d_inner // self.headdim + if self.vocab_size % self.pad_vocab_size_multiple != 0: + self.vocab_size += ( + self.pad_vocab_size_multiple + - self.vocab_size % self.pad_vocab_size_multiple + ) -def reshape_into_chunks(input_tensor, pad_size, chunk_size): - """ - Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and - simultaneously splitting it into chunk sequences. +class InferenceCache(NamedTuple): + conv_state: Tensor # (batch, d_inner + 2 * d_state, d_conv) + ssm_state: Tensor # (batch, nheads, headdim, d_state) - Assumes that we only have tensors of either size 4 or 3 - """ - # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] - input_tensor = pad_tensor_by_size(input_tensor, pad_size) - - if len(input_tensor.shape) == 3: - # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] - return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) - else: - # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] - return input_tensor.reshape( - input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + @staticmethod + def alloc(batch_size: int, args: Mamba2Config, device: Device = None): + return InferenceCache( + torch.zeros( + batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device + ), + torch.zeros( + batch_size, args.nheads, args.headdim, args.d_state, device=device + ), ) -def segment_sum(input_tensor): - """ - More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. - """ - chunk_size = input_tensor.size(-1) - # 1. expand input tensor to have an additional dimension and repeat along that dimension - # [..., chunk_size] -> [..., chunk_size, chunk_size] - input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) - # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag - mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) - input_tensor = input_tensor.masked_fill(~mask, 0) - # 3. compute actual cumsum - tensor_segsum = torch.cumsum(input_tensor, dim=-2) - - # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) - mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) - tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) - return tensor_segsum - - -class Mamba2Cache: - """ - Arguments: - config: ModelArgs - batch_size: int - dtype: torch.dtype - device: torch.device - - Attributes: - seqlen_offset: int - dtype: torch.dtype - conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel] - ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size] - """ - - def __init__( - self, config: ModelArgs, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None - ): - self.seqlen_offset = 0 - self.dtype = dtype - self.conv_kernel = config.conv_kernel - self.intermediate_size = int(config.expand * config.hidden_size) - - self.conv_states = { - i: torch.zeros( - batch_size, - self.intermediate_size + 2 * config.n_groups * config.state_size, - self.conv_kernel, - device=device, - dtype=dtype, - ) - for i in range(config.num_hidden_layers) - } - self.ssm_states = { - i: torch.zeros( - batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype - ) - for i in range(config.num_hidden_layers) - } - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor - ) -> torch.Tensor: - conv_state = self.conv_states[layer_idx] - cache_position = cache_position.clamp(0, self.conv_kernel - 1) - - conv_state = conv_state.roll(shifts=-1, dims=-1) - conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) - self.conv_states[layer_idx].zero_() - self.conv_states[layer_idx] += conv_state - return self.conv_states[layer_idx] - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - -class MambaRMSNormGated(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-6): +class Mamba2LMHeadModel(nn.Module): + def __init__(self, args: Mamba2Config, device: Device = None): super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps + self.args = args + self.device = device - def forward(self, hidden_states, gate=None): - input_dtype = hidden_states.dtype - hidden_states = hidden_states + self.backbone = nn.ModuleDict( + dict( + embedding=nn.Embedding(args.vocab_size, args.d_model, device=device), + layers=nn.ModuleList( + [ + nn.ModuleDict( + dict( + mixer=Mamba2(args, device=device), + norm=RMSNorm(args.d_model, device=device), + ) + ) + for _ in range(args.n_layer) + ] + ), + norm_f=RMSNorm(args.d_model, device=device), + ) + ) + self.lm_head = nn.Linear( + args.d_model, args.vocab_size, bias=False, device=device + ) + self.lm_head.weight = self.backbone.embedding.weight - if gate is not None: - hidden_states = hidden_states * nn.functional.silu(gate) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + @staticmethod + def from_pretrained(huggingface_model_id: str, device: Device = None): + from transformers.utils import CONFIG_NAME, WEIGHTS_NAME + from transformers.utils.hub import cached_file - return self.weight * hidden_states + config_path = cached_file(huggingface_model_id, CONFIG_NAME) + assert config_path, "Failed to get huggingface config file" + state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME) + assert state_dict_path, "Failed to get huggingface state dict file" + + config = json.load(open(config_path)) + args = Mamba2Config( + d_model=config["d_model"], + n_layer=config["n_layer"], + vocab_size=config["vocab_size"], + pad_vocab_size_multiple=config["pad_vocab_size_multiple"], + ) + + map_location = "cpu" if device is None else device + state_dict = torch.load( + state_dict_path, weights_only=True, map_location=map_location, mmap=True + ) + model = Mamba2LMHeadModel(args, device=device) + model.load_state_dict(state_dict) + model.eval() + return model + + def forward( + self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None + ) -> tuple[LongTensor, list[InferenceCache]]: + """ + Arguments + input_ids: (batch, seqlen) tokens from `EleutherAI/gpt-neox-20b` tokenizer + h: hidden states for inference step. If present the constant-time + (wrt sequence length) inference path will be taken, input_ids + should have shape (batch, 1) containing the next batch of prompt + token. + + Return (logits, h) + logits: (batch, seqlen, vocab_size) + h: updated inference cache after processing `input_ids` + """ + seqlen = input_ids.shape[1] + + if h is None: + h = [None for _ in range(self.args.n_layer)] + + x = self.backbone.embedding(input_ids) + for i, layer in enumerate(self.backbone.layers): + y, h[i] = layer.mixer(layer.norm(x), h[i]) + x = y + x + + x = self.backbone.norm_f(x) + logits = self.lm_head(x) + return logits[:, :seqlen], cast(list[InferenceCache], h) + + def generate( + self, + input_ids: LongTensor, + max_new_length: int = 20, + temperature: float = 1.0, + top_k: int = 50, + top_p: float = 1.0, + eos_token_id: int = 0, + ) -> Iterable[tuple[int, list[InferenceCache]]]: + prefix, tokens = input_ids[:-1], input_ids[-1:].unsqueeze(0) + + # Process prompt + # The input sequence to forward (non-inference path) must have length multiple that of chunk_size. + # We split out excess tokens so that n_chunked tokens can be processed by one forward call and + # process the rest in multiple inference steps. + n_chunked = (prefix.shape[0] // self.args.chunk_size) * self.args.chunk_size + if n_chunked > 0: + _, h = self(prefix[:n_chunked].unsqueeze(0), None) + else: + h = [ + InferenceCache.alloc(1, self.args, device=self.device) + for _ in range(self.args.n_layer) + ] + for i in range(n_chunked, prefix.shape[0]): + _, h = self(prefix[i : i + 1].unsqueeze(0), h) + + # Generate + for _ in range(max_new_length): + with torch.no_grad(): + out, h = self(tokens, h) + logits = out[0, -1] + if temperature != 1.0: + logits = logits / temperature + if top_k > 0: + indices_to_remove = logits < torch.topk(logits, k=top_k)[0][-1] + logits[indices_to_remove] = -torch.inf + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cum_probs > 0.5 + sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone() + sorted_indices_to_remove[0] = False + indices_to_remove = sorted_indices[sorted_indices_to_remove] + logits[indices_to_remove] = -torch.inf + probs = F.softmax(logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + if next_token.item() == eos_token_id: + return + tokens = next_token.unsqueeze(0) + yield cast(int, next_token.item()), h -class Mamba2Mixer(nn.Module): - def __init__(self, config: ModelArgs): +class Mamba2(nn.Module): + def __init__(self, args: Mamba2Config, device: Device = None): super().__init__() - self.num_heads = config.num_heads - self.hidden_size = config.hidden_size - self.state_size = config.state_size - self.conv_kernel = config.conv_kernel - self.intermediate_size = int(config.expand * self.hidden_size) - self.time_step_rank = int(config.time_step_rank) - self.use_conv_bias = config.use_conv_bias + self.args = args + self.device = device - self.layer_norm_epsilon = config.layer_norm_epsilon + # Order: (z, x, B, C, dt) + d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads + self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device) - self.n_groups = config.n_groups - self.head_dim = config.head_dim - self.chunk_size = config.chunk_size - - self.time_step_limit = config.time_step_limit - self.time_step_min = config.time_step_min - self.time_step_max = config.time_step_max - - self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size + conv_dim = args.d_inner + 2 * args.d_state self.conv1d = nn.Conv1d( - in_channels=self.conv_dim, - out_channels=self.conv_dim, - bias=config.use_conv_bias, - kernel_size=config.conv_kernel, - groups=self.conv_dim, - padding=config.conv_kernel - 1, + in_channels=conv_dim, + out_channels=conv_dim, + kernel_size=args.d_conv, + groups=conv_dim, + padding=args.d_conv - 1, + device=device, ) - # projection of the input hidden states - projection_size = self.intermediate_size + self.conv_dim + self.num_heads - self.in_proj = nn.Linear( - self.hidden_size, - projection_size, - bias=config.use_bias, + self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device)) + self.A_log = nn.Parameter(torch.empty(args.nheads, device=device)) + self.D = nn.Parameter(torch.empty(args.nheads, device=device)) + self.norm = RMSNorm(args.d_inner, device=device) + self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device) + + def forward(self, u: Tensor, h: InferenceCache | None = None): + """ + Arguments + u: (batch, seqlen, d_model) input. seqlen should be a multiple of chunk_size. + h: hidden states for inference step. Initialized to 0s if not present. + + Return (y, h) + y: (batch, seqlen, d_model) output + h: updated inference cache after processing `u` + """ + if h: + return self.step(u, h) + + A = -torch.exp(self.A_log) # (nheads,) + zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj) + z, xBC, dt = torch.split( + zxbcdt, + [ + self.args.d_inner, + self.args.d_inner + 2 * self.args.d_state, + self.args.nheads, + ], + dim=-1, + ) + dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads) + + # Pad or truncate xBC seqlen to d_conv + conv_state = F.pad( + rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0) ) - self.dt_bias = torch.ones(self.num_heads) - A = torch.arange(1, self.num_heads + 1) - self.A_log = torch.log(A) - self.D = torch.ones(self.num_heads) + xBC = silu( + self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :] + ) # (batch, seqlen, d_inner + 2 * d_state)) + x, B, C = torch.split( + xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1 + ) + x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim) + y, ssm_state = ssd( + x * dt.unsqueeze(-1), + A * dt, + rearrange(B, "b l n -> b l 1 n"), + rearrange(C, "b l n -> b l 1 n"), + self.args.chunk_size, + device=self.device, + ) + y = y + x * self.D.unsqueeze(-1) + y = rearrange(y, "b l h p -> b l (h p)") + y = self.norm(y, z) + y = self.out_proj(y) - self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + h = InferenceCache(conv_state, ssm_state) + return y, h - def forward(self, input_states, cache: Optional[Mamba2Cache]=None): - batch_size, seq_len, _ = input_states.shape - # Gated MLP's linear projection - projected_states = self.in_proj(input_states.squeeze(1)) - d_mlp = ( - projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.state_size- self.num_heads) // 2 - _, _, gate, hidden_states, dt = projected_states.split( - [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + def step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]: + """Take a single inference step for the current input and hidden state + + Unlike attention-based models, RNN-based models (eg Mamba) does not need + to look back at all the past tokens to generate a new token. Instead a + hidden state (initialized to 0s initially) is updated for each input and + passed to the next inference step. This means that the total inference + time is linear with respect to the sequence length instead of quadratic + in attention's case. + + Arguments + u: (batch, 1, d_model) + h: initial/running hidden state + + Return (y, h) + y: (batch, 1, d_model) + h: updated hidden state + """ + assert u.shape[1] == 1, "Only one token can be decoded per inference step" + + zxbcdt = self.in_proj(u.squeeze(1)) # (batch, d_in_proj) + z, xBC, dt = torch.split( + zxbcdt, + [ + self.args.d_inner, + self.args.d_inner + 2 * self.args.d_state, + self.args.nheads, + ], + dim=-1, ) - # Convolution sequence transformation - ssm_state = cache.ssm_states[self.layer_idx].clone() - ssm_state = ssm_state.to(hidden_states.device) + # Advance convolution input + h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1)) + h.conv_state[:, :, -1] = xBC + # Convolution step + xBC = torch.sum( + h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1 + ) + xBC += self.conv1d.bias + xBC = silu(xBC) - if cache.seqlen_offset > 0: - conv_state = cache.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + x, B, C = torch.split( + xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1 + ) + A = -torch.exp(self.A_log) # (nheads,) - # handle batched generation - states are copied through - conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states - cache.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + # SSM step + dt = F.softplus(dt + self.dt_bias) # (batch, nheads) + dA = torch.exp(dt * A) # (batch, nheads) + x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim) + dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x) + h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) + y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C) + y = y + rearrange(self.D, "h -> h 1") * x + y = rearrange(y, "b h p -> b (h p)") + y = self.norm(y, z) + y = self.out_proj(y) - if self.use_conv_bias: - hidden_states += self.conv1d.bias - hidden_states = nn.silu(hidden_states)[:, None, ...] # [batch, 1, intermediate_size] : decoding - else: - hidden_states = hidden_states.transpose(1,2) - conv_state = nn.functional.pad( - hidden_states, - (self.conv_kernel - hidden_states.shape[-1], 0) - ) - cache.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = nn.silu(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] - - hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.state_size, self.n_groups * self.state_size], dim=-1) - A = -torch.exp(self.A_log.float()) # [num_heads] - - if cache is not None and cache.seqlen_offset > 0: - # Note: there is no need to pad parameter matrices here, as there is just one new token - # for batched generation - dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] - dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) - # [num_heads] -> [num_heads, head_dim] - dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) - - dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) - dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max) - A = A[..., None, None].expand(self.num_heads, self.head_dim, self.state_size).to(dtype=torch.float32) - # [bsz, num_heads, head_dim, state_size] - dA = torch.exp(dt[..., None] * A) - - # Discretize B - # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> - # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] - B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] - B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() - B = B.reshape(batch_size, -1, B.shape[-1]) - # [bsz, num_heads, head_dim, state_size] - dB = dt[..., None] * B[..., None, :] - - # Discretize x into dB - # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] - hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) - dBx = dB * hidden_states[..., None] - - # State calculation - cache.ssm_states[self.layer_idx].copy_( - cache.ssm_states[self.layer_idx] * dA + dBx - ) - - # Subsequent output - # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] - C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] - C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() - C = C.reshape(batch_size, -1, C.shape[-1]) - # [bsz, num_heads, head_dim] - - ssm_states = cache.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] - # Reshape ssm_states to merge the first two dimensions - ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.state_size) # Shape: [b*h, d, n] - C_reshaped = C.view(batch_size * self.num_heads, self.state_size, 1) # Shape: [b*h, n, 1] - y = torch.bmm(ssm_states_reshaped, C_reshaped) - y = y.view(batch_size, self.num_heads, self.head_dim) - - # D skip connection - # [num_heads] -> [num_heads, head_dim] - D = self.D[..., None].expand(self.D.shape[0], self.head_dim) - y = (y + hidden_states * D).to(y.dtype) - - # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] - y = y.reshape(batch_size, -1)[:, None, ...] - else: - # begin ssd naive implementation without einsums - dt = nn.functional.softplus(dt + self.dt_bias) - dt = torch.clamp(dt, self.time_step_min) - hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() - B = B.reshape(batch_size, seq_len, -1, self.state_size).float() - C = C.reshape(batch_size, seq_len, -1, self.state_size).float() - B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) - C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) - pad_size = self.chunk_size - (seq_len % self.chunk_size) - - D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) - - # Discretize x and A - hidden_states = hidden_states * dt[..., None] - A = A.to(hidden_states.dtype) * dt - - # Rearrange into blocks/chunks - hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + return y.unsqueeze(1), h - # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] - A = A.permute(0, 3, 1, 2) - A_cumsum = torch.cumsum(A, dim=-1) +def segsum(x: Tensor, device: Device = None) -> Tensor: + """Stable segment sum calculation. - # 1. Compute the output for each intra-chunk (diagonal blocks) - # This is the analog of a causal mask - L = torch.exp(segment_sum(A)) + `exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM. - # First, contraction of C and B to get G (attention-weights like) - G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n) - G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32 + """ + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1) + x = x.masked_fill(~mask, 0) + x_segsum = torch.cumsum(x, dim=-2) + mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum - # Step 2: Compute M, equivalent to applying attention mask to weights - M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] - M = M_intermediate.sum(dim=-1) +def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None): + """Structed State Space Duality (SSD) - the core of Mamba-2 - # Step 3: Compute Y_diag (apply to values) - Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3) + This is almost the exact same minimal SSD code from the blog post. - # (right term of low-rank factorization of off-diagonal blocks; B terms) + Arguments + x: (batch, seqlen, n_heads, d_head) + A: (batch, seqlen, n_heads) + B: (batch, seqlen, n_heads, d_state) + C: (batch, seqlen, n_heads, d_state) - decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) - B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] - # permute back B * decay states - states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) - if cache is not None and cache.seqlen_offset > 0: - previous_states = cache.ssm_states[self.layer_idx][:, None, ...] - else: - previous_states = torch.zeros_like(states[:, :1]) - states = torch.cat([previous_states, states], dim=1) - decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + Return + y: (batch, seqlen, n_heads, d_head) - states_permuted = states.permute(0, 2, 1, 3, 4) - result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2) - new_states = result.permute(0, 2, 1, 3, 4) - states, ssm_state = new_states[:, :-1], new_states[:, -1] + Source + 1. https://tridao.me/blog/2024/mamba2-part3-algorithm/ + 2. https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78 + """ + assert x.shape[1] % chunk_size == 0 - # Compute state -> output conversion per chunk - # (left term of low-rank factorization of off-diagonal blocks; C terms) - state_decay_out = torch.exp(A_cumsum) - # compute Yoff - C_times_states = (C[..., None, :] * states[:, :, None, ...]) - state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) - Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) - # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + # Rearrange into chunks + # Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel) + # This is not implemented and left as an exercise for the reader 😜 + x, A, B, C = [ + rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C) + ] - y = Y_diag + Y_off - # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] - y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + A = rearrange(A, "b c l h -> b h c l") + A_cumsum = torch.cumsum(A, dim=-1) - y = y + D_residual - # Cutting off padded chunks - if pad_size > 0: - y = y[:, :seq_len, :, :] - y = y.reshape(batch_size, seq_len, -1) - if ssm_state is not None and cache is not None: - cache.ssm_states[self.layer_idx].copy_(ssm_state) + # 1. Compute the output for each intra-chunk (diagonal blocks) + L = torch.exp(segsum(A, device=device)) + Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x) - scan_output = self.norm(y, gate) - # end ssd naive + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x) - # 4. Final linear projection - contextualized_states = self.out_proj(scan_output) # [batch, seq_len, hidden_size] - return contextualized_states + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device)) + new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") + + return Y, final_state -class Mamba2Block(nn.Module): - def __init__(self, config): +class RMSNorm(nn.Module): + def __init__(self, d: int, eps: float = 1e-5, device: Device = None): + """Gated Root Mean Square Layer Normalization + + Paper: https://arxiv.org/abs/1910.07467 + """ super().__init__() - self.config = config - self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.mixer = Mamba2Mixer(config) + self.eps = eps + self.weight = nn.Parameter(torch.ones(d, device=device)) - def forward( - self, - hidden_states, - cache: Optional[Mamba2Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - ): - x = self.mixer( - self.norm(hidden_states), cache=cache, cache_position=cache_position - ) - return x + hidden_states + def forward(self, x, z=None): + if z is not None: + x = x * silu(z) + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight -class Mamba2Model(nn.Module): - def __init__(self, config): - super().__init__(config) - self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) +def silu(x): + """Applies the Sigmoid Linear Unit (SiLU), element-wise. - self.gradient_checkpointing = False - self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - cache: Optional[Mamba2Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - ): - inputs_embeds = self.embeddings(input_ids) - hidden_states = inputs_embeds - - for mixer_block in self.layers: - hidden_states = mixer_block( - hidden_states, - cache=cache, - cache_position=cache_position, - ) - - cache.seqlen_offset += inputs_embeds.shape[1] - return self.norm_f(hidden_states), cache - - - -class Mamba2ForCausalLM(nn.Module): - def __init__(self, config): - super().__init__(config) - self.backbone = Mamba2Model(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - cache: Optional[Mamba2Cache] = None, - cache_position: Optional[torch.Tensor] = None, - ): - out, cache = self.backbone( - input_ids, - cache=cache, - cache_position=cache_position, - ) - logits = self.lm_head(out) - return logits, cache \ No newline at end of file + Define this manually since torch's version doesn't seem to work on MPS. + """ + return x * F.sigmoid(x) \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2-working_but_giberish.py b/llms/mlx_lm/models/mamba2-working_but_giberish.py new file mode 100644 index 00000000..37666fbb --- /dev/null +++ b/llms/mlx_lm/models/mamba2-working_but_giberish.py @@ -0,0 +1,319 @@ +import math +from dataclasses import dataclass, field +from typing import Tuple, Union +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs +from .cache import MambaCache + +@dataclass +class ModelArgs(BaseModelArgs): + num_heads: int + head_dim: int + vocab_size: int + hidden_size: int + state_size: int + num_hidden_layers: int + layer_norm_epsilon: float + expand: int + conv_kernel: int + n_groups: int + use_bias: bool + use_conv_bias: bool + initializer_range: float + residual_in_fp32: bool + time_step_min: float + time_step_max: float + time_step_floor: float + rescale_prenorm_residual: bool + rms_norm: bool + chunk_size: int + tie_word_embeddings: bool + use_cache: bool = True + time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) + time_step_rank: Union[int, str] = "auto" + model_type: str = "mamba2" + + def __post_init__(self): + if not hasattr(self, "intermediate_size"): + self.intermediate_size = int(self.expand * self.hidden_size) + if not hasattr(self, "head_dim"): + self.head_dim = self.hidden_size // self.num_heads + if self.time_step_rank == "auto": + self.time_step_rank = math.ceil(self.hidden_size / 16) + + +class MambaRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = mx.ones((hidden_size,)) + self.variance_epsilon = eps + + def __call__(self, hidden_states, gate=None): + if gate is not None: + hidden_states = hidden_states * nn.silu(gate) + variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True) + hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states + + +def silu(x): + return x * mx.sigmoid(x) + +def ssd(x, A, B, C, chunk_size): + batch, seqlen, nheads, dim = x.shape + B = mx.expand_dims(B, axis=2) + C = mx.expand_dims(C, axis=2) + + state = mx.zeros((batch, nheads, dim, B.shape[-1])) + outputs = [] + + for i in range(0, seqlen, chunk_size): + chunk = slice(i, min(i + chunk_size, seqlen)) + dA = mx.exp(mx.expand_dims(A[chunk], axis=0)) + + x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim] + x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size] + B_chunk = B[:, chunk] # [batch, chunk_size, state_size] + dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size] + + state = state * mx.expand_dims(dA, axis=-1) + dBx + + C_chunk = C[:, chunk] # [batch, chunk_size, state_size] + y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size] + y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim] + outputs.append(y) + + return mx.concatenate(outputs, axis=1), state + + +class DepthWiseConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.padding = padding + self.groups = groups if groups is not None else in_channels + + assert in_channels == out_channels, "In and out channels must be same for depthwise convolution" + assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution" + + self.weight = mx.random.normal((in_channels, 1, kernel_size)) + self.bias = mx.zeros((out_channels,)) if bias else None + + def __call__(self, x: mx.array, cache=None) -> mx.array: + B, L, C = x.shape + K = self.kernel_size + + assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}" + + if cache is not None: + # Access conv_state directly from cache[0] + if cache[0] is None: + cache[0] = mx.zeros((B, K-1, C)) + + x = mx.concatenate([cache[0], x], axis=1) + + outputs = [] + for c in range(C): + x_c = x[:, :, c] + x_c = mx.expand_dims(x_c, axis=1) + + w_c = self.weight[c] + if w_c.ndim == 2: + w_c = mx.expand_dims(w_c, axis=0) + elif w_c.ndim == 1: + w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0) + + y_c = mx.conv_general( + x_c, + w_c, + stride=1, + padding=0 + ) + if self.bias is not None: + y_c = y_c + self.bias[c] + + y_c = mx.squeeze(y_c, axis=1) + outputs.append(y_c) + + y = mx.stack(outputs, axis=-1) + + # Update cache directly using cache[0] + if cache is not None: + cache[0] = x[:, -K+1:, :] if x.shape[1] >= K else x + + return y + + +class Mamba2Block(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads + self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias) + + conv_dim = args.intermediate_size + 2 * args.state_size + self.conv1d = DepthWiseConv1d( + in_channels=conv_dim, + out_channels=conv_dim, + kernel_size=args.conv_kernel, + groups=conv_dim, + bias=args.use_conv_bias, + padding=args.conv_kernel - 1 + ) + + self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range + self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range + self.D = mx.random.normal((args.num_heads,)) * args.initializer_range + + self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon) + self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) + + if args.rescale_prenorm_residual: + layer_scale = math.sqrt(1.0 / args.num_hidden_layers) + self.out_proj.weight = self.out_proj.weight * layer_scale + + def __call__(self, u: mx.array, cache=None): + batch_size, seq_len, dimension = u.shape + + # Initialize the negative A matrix + A = -mx.exp(self.A_log) + + # Process sequence in chunks if needed + outputs = [] + current_cache = cache + + for i in range(seq_len): + # Extract current token + current_input = u[:, i:i+1, :] + + # Initialize cache states if needed + if current_cache[0] is None: # conv state + conv_dim = self.args.intermediate_size + 2 * self.args.state_size + current_cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim)) + + if current_cache[1] is None: # ssm state + current_cache[1] = mx.zeros(( + batch_size, + self.args.num_heads, + self.args.head_dim, + self.args.state_size + )) + + # Project input + zxbcdt = self.in_proj(current_input) + + n_heads = self.args.num_heads + z = zxbcdt[:, :, :self.args.intermediate_size] + xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size] + dt = zxbcdt[:, :, -(n_heads):] + + # Process time steps + dt = mx.reshape(dt, (batch_size, n_heads)) + dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max) + dt = mx.maximum(dt, self.args.time_step_floor) + + # Apply convolution + xBC = self.conv1d(xBC, cache=current_cache) + xBC = silu(xBC) + + # Split states + x = xBC[:, :, :self.args.intermediate_size] + B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size] + C = xBC[:, :, -self.args.state_size:] + + # Reshape for SSM + x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim)) + x = mx.squeeze(x, axis=1) + B = mx.reshape(B, (batch_size, 1, self.args.state_size)) + B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size)) + B = mx.expand_dims(B, axis=2) + C = mx.reshape(C, (batch_size, 1, self.args.state_size)) + C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size)) + C = mx.expand_dims(C, axis=3) + + # SSM updates + dA = mx.exp(dt * mx.expand_dims(A, 0)) + dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) + + # Update state + x = mx.expand_dims(x, axis=3) + dBx = mx.matmul(x, B) + current_cache[1] = current_cache[1] * dA + dBx + + # Compute output + y = mx.matmul(current_cache[1], C) + y = mx.squeeze(y, axis=-1) + y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1) + y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim)) + y = self.norm(y + z) + + outputs.append(self.out_proj(y)) + + # Concatenate all outputs + return mx.concatenate(outputs, axis=1) + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.residual_in_fp32 = args.residual_in_fp32 + self.mixer = Mamba2Block(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, x: mx.array, cache): + if self.residual_in_fp32: + x = x.astype(mx.float32) + normed = self.norm(x) + output = self.mixer(normed, cache) + return output + x + +class Mamba2(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + + def __call__(self, x: mx.array, cache): + x = self.embeddings(x) + if cache is None: + cache = [None] * len(self.layers) + + hidden = x + for layer, c in zip(self.layers, cache): + hidden = layer(hidden, c) + return self.norm_f(hidden) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba2(args) + + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None): + hidden = self.backbone(inputs, cache) + + if self.args.tie_word_embeddings: + logits = self.backbone.embeddings.as_linear(hidden) + else: + logits = self.lm_head(hidden) + + return logits + + def make_cache(self): + return [MambaCache() for _ in range(len(self.layers))] + + @property + def layers(self): + return self.backbone.layers \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2-works-hella-slow.py b/llms/mlx_lm/models/mamba2-works-hella-slow.py index 4468432f..2960d3d0 100644 --- a/llms/mlx_lm/models/mamba2-works-hella-slow.py +++ b/llms/mlx_lm/models/mamba2-works-hella-slow.py @@ -5,7 +5,7 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs -from .cache import Mamba2Cache +from .cache import MambaCache @dataclass class ModelArgs(BaseModelArgs): @@ -62,7 +62,6 @@ def silu(x): return x * mx.sigmoid(x) def ssd(x, A, B, C, chunk_size): - # Replace einsum operations with explicit reshape and matrix multiply batch, seqlen, nheads, dim = x.shape B = mx.expand_dims(B, axis=2) C = mx.expand_dims(C, axis=2) @@ -74,7 +73,6 @@ def ssd(x, A, B, C, chunk_size): chunk = slice(i, min(i + chunk_size, seqlen)) dA = mx.exp(mx.expand_dims(A[chunk], axis=0)) - # Replace einsum with explicit operations x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim] x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size] B_chunk = B[:, chunk] # [batch, chunk_size, state_size] @@ -82,14 +80,13 @@ def ssd(x, A, B, C, chunk_size): state = state * mx.expand_dims(dA, axis=-1) + dBx - # Replace einsum with explicit operations C_chunk = C[:, chunk] # [batch, chunk_size, state_size] y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size] y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim] outputs.append(y) return mx.concatenate(outputs, axis=1), state - + class DepthWiseConv1d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0): @@ -112,43 +109,42 @@ class DepthWiseConv1d(nn.Module): assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}" - if cache is not None and cache.conv_states[0] is not None: - # Convert None to proper array if needed - if isinstance(cache.conv_states[0], type(None)): - cache.conv_states[0] = mx.zeros((B, K-1, C)) - x = mx.concatenate([cache.conv_states[0], x], axis=1) - - # Process each channel independently + if cache is not None: + # Access conv_state directly from cache[0] + if cache[0] is None: + cache[0] = mx.zeros((B, K-1, C)) + + x = mx.concatenate([cache[0], x], axis=1) + outputs = [] for c in range(C): x_c = x[:, :, c] x_c = mx.expand_dims(x_c, axis=1) - + w_c = self.weight[c] if w_c.ndim == 2: w_c = mx.expand_dims(w_c, axis=0) elif w_c.ndim == 1: w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0) - - # Apply convolution + y_c = mx.conv_general( x_c, w_c, stride=1, padding=0 ) - if self.bias is not None: y_c = y_c + self.bias[c] - outputs.append(mx.squeeze(y_c, axis=1)) - - y = mx.stack(outputs, axis=-1) + y_c = mx.squeeze(y_c, axis=1) + outputs.append(y_c) - # Update cache + y = mx.stack(outputs, axis=-1) + + # Update cache directly using cache[0] if cache is not None: - cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x - + cache[0] = x[:, -K+1:, :] if x.shape[1] >= K else x + return y @@ -182,98 +178,80 @@ class Mamba2Block(nn.Module): self.out_proj.weight = self.out_proj.weight * layer_scale def __call__(self, u: mx.array, cache=None): - batch_size = u.shape[0] - seq_len = u.shape[1] - outputs = [] + batch_size, seq_len, dimension = u.shape + assert seq_len == 1, "Input should be a single token" - # Initialize states if needed - if cache.conv_states[0] is None: + # Initialize cache states directly using indices + if cache[0] is None: # conv state conv_dim = self.args.intermediate_size + 2 * self.args.state_size - cache.conv_states[0] = mx.zeros(( - batch_size, - self.args.conv_kernel - 1, - conv_dim - )) - - if cache.ssm_states[0] is None: - cache.ssm_states[0] = mx.zeros(( + cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim)) + + if cache[1] is None: # ssm state + cache[1] = mx.zeros(( batch_size, self.args.num_heads, self.args.head_dim, self.args.state_size )) - for pos in range(seq_len): - u_t = u[:, pos:pos+1, :] - zxbcdt = self.in_proj(u_t) - - n_heads = self.args.num_heads - - z = zxbcdt[:, :, :self.args.intermediate_size] - xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size] - dt = zxbcdt[:, :, -(n_heads):] - - dt = mx.reshape(dt, (batch_size, n_heads)) - dt = mx.clip( - nn.softplus(dt + self.dt_bias), - self.args.time_step_min, - self.args.time_step_max - ) - dt = mx.maximum(dt, self.args.time_step_floor) + zxbcdt = self.in_proj(u) + + n_heads = self.args.num_heads + z = zxbcdt[:, :, :self.args.intermediate_size] + xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size] + dt = zxbcdt[:, :, -(n_heads):] - xBC = self.conv1d(xBC, cache=cache) - xBC = silu(xBC) + dt = mx.reshape(dt, (batch_size, n_heads)) + dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max) + dt = mx.maximum(dt, self.args.time_step_floor) - x = xBC[:, :, :self.args.intermediate_size] - B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size] - C = xBC[:, :, -self.args.state_size:] - - x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim)) - x = mx.squeeze(x, axis=1) + xBC = self.conv1d(xBC, cache=cache) + xBC = silu(xBC) - B = mx.reshape(B, (batch_size, 1, self.args.state_size)) - B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size)) - B = mx.expand_dims(B, axis=2) + x = xBC[:, :, :self.args.intermediate_size] + B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size] + C = xBC[:, :, -self.args.state_size:] - C = mx.reshape(C, (batch_size, 1, self.args.state_size)) - C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size)) - C = mx.expand_dims(C, axis=3) + x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim)) + x = mx.squeeze(x, axis=1) + B = mx.reshape(B, (batch_size, 1, self.args.state_size)) + B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size)) + B = mx.expand_dims(B, axis=2) + C = mx.reshape(C, (batch_size, 1, self.args.state_size)) + C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size)) + C = mx.expand_dims(C, axis=3) - A = -mx.exp(self.A_log) - dA = mx.exp(dt * mx.expand_dims(A, 0)) - dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) + A = -mx.exp(self.A_log) + dA = mx.exp(dt * mx.expand_dims(A, 0)) + dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) - x = mx.expand_dims(x, axis=3) - dBx = mx.matmul(x, B) - - cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx + x = mx.expand_dims(x, axis=3) + dBx = mx.matmul(x, B) + # Update ssm state directly using cache[1] + cache[1] = cache[1] * dA + dBx - y = mx.matmul(cache.ssm_states[0], C) - y = mx.squeeze(y, axis=-1) - - y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1) - - y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim)) - y = self.norm(y + z) - y = self.out_proj(y) - outputs.append(y) + y = mx.matmul(cache[1], C) + y = mx.squeeze(y, axis=-1) + y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1) + y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim)) + y = self.norm(y + z) - return mx.concatenate(outputs, axis=1) + return self.out_proj(y) class ResidualBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.residual_in_fp32 = args.residual_in_fp32 - self.mixer = Mamba2Block(args) self.norm = nn.RMSNorm(args.hidden_size) def __call__(self, x: mx.array, cache): if self.residual_in_fp32: x = x.astype(mx.float32) - return self.mixer(self.norm(x), cache) + x - + normed = self.norm(x) + output = self.mixer(normed, cache) + return output + x class Mamba2(nn.Module): def __init__(self, args: ModelArgs): @@ -287,9 +265,11 @@ class Mamba2(nn.Module): x = self.embeddings(x) if cache is None: cache = [None] * len(self.layers) + + hidden = x for layer, c in zip(self.layers, cache): - x = layer(x, c) - return self.norm_f(x) + hidden = layer(hidden, c) + return self.norm_f(hidden) class Model(nn.Module): @@ -297,40 +277,23 @@ class Model(nn.Module): super().__init__() self.args = args self.model_type = args.model_type - self.backbone = Mamba2(args) if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) def __call__(self, inputs: mx.array, cache=None): - B, T = inputs.shape - - x = self.backbone(inputs, cache) - + hidden = self.backbone(inputs, cache) + if self.args.tie_word_embeddings: - logits = self.backbone.embeddings.as_linear(x) + logits = self.backbone.embeddings.as_linear(hidden) else: - logits = self.lm_head(x) - + logits = self.lm_head(hidden) + return logits - def make_cache(self, batch_size=1): - return [Mamba2Cache(batch_size, self.args.conv_kernel) for _ in range(len(self.layers))] - - def sanitize(self, weights): - sanitized = {} - for k, v in weights.items(): - if "conv1d.weight" in k: - # Ensure weights are in correct shape (channels, 1, kernel_size) - if v.ndim == 2: - v = mx.expand_dims(v, axis=1) - elif v.ndim == 1: - v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0) - sanitized[k] = v - else: - sanitized[k] = v - return sanitized + def make_cache(self): + return [MambaCache() for _ in range(len(self.layers))] @property def layers(self): diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index 2d8f4a09..3360a615 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -148,202 +148,131 @@ class DepthWiseConv1d(nn.Module): return y -# class Mamba2Block(nn.Module): -# def __init__(self, args: ModelArgs): -# super().__init__() -# self.args = args - -# d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads -# self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias) - -# conv_dim = args.intermediate_size + 2 * args.state_size -# self.conv1d = DepthWiseConv1d( -# in_channels=conv_dim, -# out_channels=conv_dim, -# kernel_size=args.conv_kernel, -# groups=conv_dim, -# bias=args.use_conv_bias, -# padding=args.conv_kernel - 1 -# ) - -# self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range -# self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range -# self.D = mx.random.normal((args.num_heads,)) * args.initializer_range - -# self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon) -# self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) - -# if args.rescale_prenorm_residual: -# layer_scale = math.sqrt(1.0 / args.num_hidden_layers) -# self.out_proj.weight = self.out_proj.weight * layer_scale - -# def __call__(self, u: mx.array, cache=None): -# batch_size, seq_len, dimension = u.shape -# assert seq_len == 1, "Input should be a single token" - -# # Initialize cache states directly using indices -# if cache[0] is None: # conv state -# conv_dim = self.args.intermediate_size + 2 * self.args.state_size -# cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim)) - -# if cache[1] is None: # ssm state -# cache[1] = mx.zeros(( -# batch_size, -# self.args.num_heads, -# self.args.head_dim, -# self.args.state_size -# )) - -# zxbcdt = self.in_proj(u) - -# n_heads = self.args.num_heads -# z = zxbcdt[:, :, :self.args.intermediate_size] -# xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size] -# dt = zxbcdt[:, :, -(n_heads):] - -# dt = mx.reshape(dt, (batch_size, n_heads)) -# dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max) -# dt = mx.maximum(dt, self.args.time_step_floor) - -# xBC = self.conv1d(xBC, cache=cache) -# xBC = silu(xBC) - -# x = xBC[:, :, :self.args.intermediate_size] -# B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size] -# C = xBC[:, :, -self.args.state_size:] - -# x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim)) -# x = mx.squeeze(x, axis=1) -# B = mx.reshape(B, (batch_size, 1, self.args.state_size)) -# B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size)) -# B = mx.expand_dims(B, axis=2) -# C = mx.reshape(C, (batch_size, 1, self.args.state_size)) -# C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size)) -# C = mx.expand_dims(C, axis=3) - -# A = -mx.exp(self.A_log) -# dA = mx.exp(dt * mx.expand_dims(A, 0)) -# dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) - -# x = mx.expand_dims(x, axis=3) -# dBx = mx.matmul(x, B) -# # Update ssm state directly using cache[1] -# cache[1] = cache[1] * dA + dBx - -# y = mx.matmul(cache[1], C) -# y = mx.squeeze(y, axis=-1) -# y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1) -# y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim)) -# y = self.norm(y + z) - -# return self.out_proj(y) - - class Mamba2Block(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args - d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads - self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias) - - conv_dim = args.intermediate_size + 2 * args.state_size + # Calculate dimensions + self.d_model = args.hidden_size + self.d_state = args.state_size + self.d_conv = args.conv_kernel + self.expand = args.expand + self.d_inner = int(self.expand * self.d_model) + self.n_heads = args.num_heads + self.d_head = self.d_inner // self.n_heads + + # Input projection + d_in_proj = self.d_inner * 2 + self.d_state * 2 + self.n_heads + self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias) + + # Convolution + conv_dim = self.d_inner + 2 * self.d_state self.conv1d = DepthWiseConv1d( in_channels=conv_dim, out_channels=conv_dim, - kernel_size=args.conv_kernel, - groups=conv_dim, + kernel_size=self.d_conv, bias=args.use_conv_bias, - padding=args.conv_kernel - 1 + groups=conv_dim ) - - self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range - self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range - self.D = mx.random.normal((args.num_heads,)) * args.initializer_range - - self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon) - self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) - + + # SSM parameters + self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range + self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range + self.D = mx.random.normal((self.n_heads,)) * args.initializer_range + + # Output projection + self.norm = MambaRMSNormGated(self.d_inner, eps=args.layer_norm_epsilon) + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias) + if args.rescale_prenorm_residual: layer_scale = math.sqrt(1.0 / args.num_hidden_layers) self.out_proj.weight = self.out_proj.weight * layer_scale def __call__(self, u: mx.array, cache=None): - batch_size, seq_len, dimension = u.shape + batch_size, seq_len, _ = u.shape - # Process sequence in chunks if needed + # Project input + proj = self.in_proj(u) # [batch, seq_len, d_in_proj] + + # Calculate split indices and slice tensors + z = proj[..., :self.d_inner] + x_conv = proj[..., self.d_inner:self.d_inner + (self.d_inner + 2 * self.d_state)] + dt = proj[..., -self.n_heads:] + + # Process time steps + dt = nn.softplus(dt + self.dt_bias) + dt = mx.clip(dt, self.args.time_step_min, self.args.time_step_max) + dt = mx.maximum(dt, self.args.time_step_floor) + + # Convolution and activation + x_conv = self.conv1d(x_conv, cache=[cache[0] if cache else None]) + x_conv = silu(x_conv) + + # Split conv output + x = x_conv[..., :self.d_inner] + B = x_conv[..., self.d_inner:self.d_inner + self.d_state] + C = x_conv[..., -self.d_state:] + + # Reshape x for SSM + x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head)) + + # Process B and C without reshaping heads + B = mx.expand_dims(B, axis=2) # [batch, seq_len, 1, d_state] + B = mx.broadcast_to(B, (batch_size, seq_len, self.n_heads, self.d_state)) + + C = mx.expand_dims(C, axis=2) # [batch, seq_len, 1, d_state] + C = mx.broadcast_to(C, (batch_size, seq_len, self.n_heads, self.d_state)) + + # Initialize or get previous state + if cache and cache[1] is not None: + prev_state = cache[1] + else: + prev_state = mx.zeros((batch_size, self.n_heads, self.d_head, self.d_state)) + + # Compute dA + dA = -mx.exp(self.A_log) # [n_heads] + dt = mx.reshape(dt, (batch_size, seq_len, self.n_heads)) # Ensure correct shape + dA = mx.exp(mx.expand_dims(dt * mx.expand_dims(dA, 0), -1)) # [batch, seq_len, n_heads, 1] + dA = mx.expand_dims(dA, -1) # [batch, seq_len, n_heads, 1, 1] + + # Process sequence + next_state = prev_state outputs = [] - current_cache = cache - for i in range(seq_len): - # Extract current token - current_input = u[:, i:i+1, :] + for t in range(seq_len): + # Get current step tensors + xt = x[:, t] # [batch, n_heads, d_head] + Bt = B[:, t] # [batch, n_heads, d_state] + Ct = C[:, t] # [batch, n_heads, d_state] + dAt = dA[:, t] # [batch, n_heads, 1, 1] - # Initialize cache states if needed - if current_cache[0] is None: # conv state - conv_dim = self.args.intermediate_size + 2 * self.args.state_size - current_cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim)) - - if current_cache[1] is None: # ssm state - current_cache[1] = mx.zeros(( - batch_size, - self.args.num_heads, - self.args.head_dim, - self.args.state_size - )) - - # Project input - zxbcdt = self.in_proj(current_input) - - n_heads = self.args.num_heads - z = zxbcdt[:, :, :self.args.intermediate_size] - xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size] - dt = zxbcdt[:, :, -(n_heads):] - - # Process time steps - dt = mx.reshape(dt, (batch_size, n_heads)) - dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max) - dt = mx.maximum(dt, self.args.time_step_floor) - - # Apply convolution - xBC = self.conv1d(xBC, cache=current_cache) - xBC = silu(xBC) - - # Split states - x = xBC[:, :, :self.args.intermediate_size] - B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size] - C = xBC[:, :, -self.args.state_size:] - - # Reshape for SSM - x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim)) - x = mx.squeeze(x, axis=1) - B = mx.reshape(B, (batch_size, 1, self.args.state_size)) - B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size)) - B = mx.expand_dims(B, axis=2) - C = mx.reshape(C, (batch_size, 1, self.args.state_size)) - C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size)) - C = mx.expand_dims(C, axis=3) - - # SSM updates - A = -mx.exp(self.A_log) - dA = mx.exp(dt * mx.expand_dims(A, 0)) - dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) - # Update state - x = mx.expand_dims(x, axis=3) - dBx = mx.matmul(x, B) - current_cache[1] = current_cache[1] * dA + dBx - - # Compute output - y = mx.matmul(current_cache[1], C) - y = mx.squeeze(y, axis=-1) - y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1) - y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim)) - y = self.norm(y + z) + next_state = ( + next_state * dAt + # Broadcasting: [batch, n_heads, d_head, d_state] * [batch, n_heads, 1, 1] + mx.matmul( + mx.expand_dims(xt, -1), # [batch, n_heads, d_head, 1] + mx.expand_dims(Bt, -2) # [batch, n_heads, 1, d_state] + ) + ) - outputs.append(self.out_proj(y)) - - # Concatenate all outputs + # Compute output + yt = mx.matmul( + next_state, # [batch, n_heads, d_head, d_state] + mx.expand_dims(Ct, -1) # [batch, n_heads, d_state, 1] + ) + yt = mx.squeeze(yt, -1) # [batch, n_heads, d_head] + yt = yt + xt * mx.expand_dims(self.D, -1) + + # Reshape and normalize + yt = mx.reshape(yt, (batch_size, 1, self.d_inner)) + yt = self.norm(yt, z[:, t:t+1]) + outputs.append(self.out_proj(yt)) + + # Update cache + if cache is not None: + cache[1] = next_state + return mx.concatenate(outputs, axis=1) diff --git a/llms/mlx_lm/models/mamba22.py b/llms/mlx_lm/models/mamba22.py deleted file mode 100644 index c0cbe1d7..00000000 --- a/llms/mlx_lm/models/mamba22.py +++ /dev/null @@ -1,316 +0,0 @@ -import math -from dataclasses import dataclass, field -from typing import Tuple, Union -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs -from .cache import Mamba2Cache - - -@dataclass -class ModelArgs(BaseModelArgs): - num_heads: int - head_dim: int - vocab_size: int - hidden_size: int - state_size: int - num_hidden_layers: int - layer_norm_epsilon: float - expand: int - conv_kernel: int - n_groups: int - use_bias: bool - use_conv_bias: bool - initializer_range: float - residual_in_fp32: bool - time_step_min: float - time_step_max: float - time_step_floor: float - rescale_prenorm_residual: bool - rms_norm: bool - chunk_size: int - tie_word_embeddings: bool - use_cache: bool = True - time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) - time_step_rank: Union[int, str] = "auto" - model_type: str = "mamba2" - - def __post_init__(self): - if not hasattr(self, "intermediate_size"): - self.intermediate_size = int(self.expand * self.hidden_size) - if not hasattr(self, "head_dim"): - self.head_dim = self.hidden_size // self.num_heads - if self.time_step_rank == "auto": - self.time_step_rank = math.ceil(self.hidden_size / 16) - - -def silu(x): - return x * mx.sigmoid(x) - - -class MambaRMSNormGated(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = mx.ones((hidden_size,)) - self.variance_epsilon = eps - - def __call__(self, hidden_states, gate=None): - # Fuse operations where possible - if gate is not None: - hidden_states = hidden_states * nn.silu(gate) - # Compute variance in fp32 for better numerical stability - hidden_states_fp32 = hidden_states.astype(mx.float32) - variance = mx.mean(hidden_states_fp32 * hidden_states_fp32, axis=-1, keepdims=True) - hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states - - -def ssd_optimized(x, A, B, C, chunk_size): - batch, seqlen, nheads, dim = x.shape - B = mx.expand_dims(B, axis=2) - C = mx.expand_dims(C, axis=2) - - output = mx.zeros((batch, seqlen, nheads, dim)) - state = mx.zeros((batch, nheads, dim, B.shape[-1])) - - for i in range(0, seqlen, chunk_size): - chunk = slice(i, min(i + chunk_size, seqlen)) - chunk_size_actual = min(chunk_size, seqlen - i) - - dA = mx.exp(mx.expand_dims(A[chunk], axis=0)) - x_chunk = mx.transpose(x[:, chunk], [0, 2, 3, 1]) - dBx = mx.matmul(x_chunk, B[:, chunk]) - state = state * mx.expand_dims(dA, axis=-1) + dBx - y = mx.matmul(state, mx.transpose(C[:, chunk], [0, 2, 1])) - output[:, i:i+chunk_size_actual] = mx.transpose(y, [0, 3, 1, 2]) - - return output, state - - -def update_conv_cache(x: mx.array, cache, kernel_size: int) -> Tuple[mx.array, mx.array]: - """Update convolution cache for sequential processing.""" - B, L, C = x.shape - - if cache is None: - # Initialize cache with zeros - cache = mx.zeros((B, kernel_size - 1, C)) - - # Concatenate cache with current input - x_with_cache = mx.concatenate([cache, x], axis=1) - - # Update cache with the last (kernel_size - 1) elements - new_cache = x_with_cache[:, -kernel_size+1:] if x_with_cache.shape[1] >= kernel_size else x_with_cache - - return x_with_cache, new_cache - - -class Mamba2Block(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - self.intermediate_size = int(args.expand * args.hidden_size) - self.state_size = args.state_size - self.num_heads = args.num_heads - self.head_dim = args.head_dim - self.ssm_state_size = args.state_size - self.n_groups = args.n_groups - self.conv_kernel = args.conv_kernel - - self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size - projection_size = self.intermediate_size + self.conv_dim + self.num_heads - - self.in_proj = nn.Linear(args.hidden_size, projection_size, bias=args.use_bias) - - # Using built-in Conv1d instead of custom DepthwiseConv1d - self.conv1d = nn.Conv1d( - in_channels=self.conv_dim, - out_channels=self.conv_dim, - kernel_size=args.conv_kernel, - groups=self.conv_dim, # For depthwise convolution - padding=0, # We'll handle padding manually with the cache - bias=args.use_conv_bias - ) - - self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range - self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range - self.D = mx.random.normal((args.num_heads,)) * args.initializer_range - - self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon) - self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) - - if args.rescale_prenorm_residual: - layer_scale = math.sqrt(1.0 / args.num_hidden_layers) - self.out_proj.weight = self.out_proj.weight * layer_scale - - def __call__(self, u: mx.array, cache): - batch_size, seq_len, _ = u.shape - - projected = self.in_proj(u) - d_conv = self.conv_dim - - z = projected[..., :self.intermediate_size] - xBC = projected[..., self.intermediate_size:self.intermediate_size + d_conv] - dt = projected[..., -self.num_heads:] - - dt = mx.clip( - nn.softplus(dt + self.dt_bias), - self.args.time_step_min, - self.args.time_step_max - ) - dt = mx.maximum(dt, self.args.time_step_floor) - - # Handle convolution with separate cache update - if cache is not None: - # Update cache and get padded input - xBC_padded, new_cache = update_conv_cache(xBC, cache.conv_states, self.conv_kernel) - cache.conv_states = new_cache - - # Prepare input for conv1d: [B, L, C] -> [B, C, L] - xBC_conv = mx.transpose(xBC_padded, [0, 2, 1]) - - # Apply convolution - xBC = self.conv1d(xBC_conv) - - # Transform back: [B, C, L] -> [B, L, C] - xBC = mx.transpose(xBC, [0, 2, 1]) - - # Take only the relevant part corresponding to input length - xBC = xBC[:, :seq_len] - else: - # For training, use regular convolution with padding - xBC = mx.transpose(xBC, [0, 2, 1]) - xBC = self.conv1d(xBC) - xBC = mx.transpose(xBC, [0, 2, 1]) - - xBC = silu(xBC) - - x = xBC[..., :self.intermediate_size] - BC = xBC[..., self.intermediate_size:] - B = BC[..., :self.state_size] - C = BC[..., self.state_size:] - - x = mx.reshape(x, (-1, seq_len, self.num_heads, self.intermediate_size // self.num_heads)) - - A = -mx.exp(self.A_log) - D_expanded = mx.expand_dims(self.D, -1) - - if cache is not None and cache.ssm_state is None: - cache.ssm_state = mx.zeros(( - batch_size, - self.num_heads, - self.intermediate_size // self.num_heads, - self.state_size - )) - - if cache is not None: - output = mx.zeros((batch_size, seq_len, self.args.hidden_size)) - - for pos in range(seq_len): - x_t = x[:, pos:pos+1] - - dA = mx.exp(dt[:, pos:pos+1] * mx.expand_dims(A, 0)) - dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) - - x_expanded = mx.expand_dims(x_t, axis=3) - dBx = mx.matmul(x_expanded, mx.expand_dims(B[:, pos:pos+1], axis=2)) - - cache.ssm_state = cache.ssm_state * dA + dBx - - y = mx.matmul(cache.ssm_state, mx.expand_dims(C[:, pos:pos+1], axis=3)) - y = mx.squeeze(y, axis=-1) - y = y + x_t * D_expanded - - y = mx.reshape(y, (batch_size, 1, -1)) - y = self.norm(y + z[:, pos:pos+1]) - y = self.out_proj(y) - - if self.args.residual_in_fp32: - y = y.astype(mx.float32) - - output = output.at[:, pos:pos+1].set(y) - else: - y, ssm_state = ssd_optimized( - x * mx.expand_dims(dt, -1), - -mx.exp(self.A_log) * dt, - B, C, - self.args.chunk_size - ) - - y = mx.reshape( - y + x * mx.expand_dims(self.D, -1), - (batch_size, seq_len, -1) - ) - - y = self.norm(y + z) - output = self.out_proj(y) - - if self.args.residual_in_fp32: - output = output.astype(mx.float32) - - return output - - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = Mamba2Block(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, x: mx.array, cache): - return self.mixer(self.norm(x), cache) + x - - -class Mamba2(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - - def __call__(self, x: mx.array, cache): - x = self.embeddings(x) - if cache is None: - cache = [None] * len(self.layers) - for layer, c in zip(self.layers, cache): - x = layer(x, c) - return self.norm_f(x) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - - self.backbone = Mamba2(args) - - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__(self, inputs: mx.array, cache=None): - B, T = inputs.shape - - x = self.backbone(inputs, cache) - - if self.args.tie_word_embeddings: - logits = self.backbone.embeddings.as_linear(x) - else: - logits = self.lm_head(x) - - return logits - - def make_cache(self, batch_size=1): - return [Mamba2Cache() for _ in range(len(self.layers))] - - def sanitize(self, weights): - for k, v in weights.items(): - if "conv1d.weight" in k and v.ndim == 3: - weights[k] = v.moveaxis(2, 1) - return weights - - @property - def layers(self): - return self.backbone.layers \ No newline at end of file