From 7c8849e795dea45f939db39602040f0be451195f Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Thu, 24 Oct 2024 16:16:42 +0200 Subject: [PATCH] update --- llms/mlx_lm/models/cache.py | 141 +++++++-- llms/mlx_lm/models/mamba2 copy.py | 400 ++++++++++++++++++++++++++ llms/mlx_lm/models/mamba2-prch.py | 27 +- llms/mlx_lm/models/mamba2.py | 463 +++++++++++++----------------- 4 files changed, 757 insertions(+), 274 deletions(-) diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index b38d0203..05dcfee7 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -340,21 +340,130 @@ class MambaCache(_BaseCache): self.cache = v -class Mamba2Cache(_BaseCache): - conv_states: Optional[mx.array] = None - ssm_state: Optional[mx.array] = None +class Mamba2Cache: + batch_size: int + intermediate_size: int + state_size: int + conv_kernel: int + num_heads: int + head_dim: int + + def __init__( + self, + batch_size: int, + intermediate_size: int, + state_size: int, + conv_kernel: int, + num_heads: int, + head_dim: int + ): + self.batch_size = batch_size + self.intermediate_size = intermediate_size + self.state_size = state_size + self.conv_kernel = conv_kernel + self.num_heads = num_heads + self.head_dim = head_dim + + # Initialize conv state with proper dimensions + self.conv_dim = self.intermediate_size + 2 * self.state_size + self.conv_state = mx.zeros((batch_size, self.conv_dim, conv_kernel - 1)) + + # Initialize SSM state + self.ssm_state = mx.zeros(( + batch_size, + num_heads, + head_dim, + state_size + )) - def __getitem__(self, idx: int) -> Optional[mx.array]: - if idx == 0: - return self.conv_states - elif idx == 1: - return self.ssm_states - raise IndexError("Cache index must be 0 or 1") + def update_conv_state(self, x: mx.array) -> mx.array: + """ + Update convolution state for incremental inference. 
+ Args: + x: Input tensor containing projected values (B, conv_in_dim) + Returns: + Combined state tensor of shape (batch_size, conv_dim, kernel_size) + """ + # Handle input shape + if x.ndim == 1: + x = mx.expand_dims(x, 0) # Add batch dimension if needed + + # Ensure batch size matches + assert x.shape[0] == self.batch_size, f"Batch size mismatch: {x.shape[0]} vs {self.batch_size}" + + # Reshape x to match conv_dim + # The input x contains intermediate_size + 2 * state_size dimensions + x_reshaped = mx.reshape(x, (self.batch_size, -1)) + x_padded = mx.pad( + x_reshaped, + [(0, 0), (0, self.conv_dim - x_reshaped.shape[1])], + mode='constant', + constant_values=0 + ) + + # Expand dims for concatenation + x_expanded = mx.expand_dims(x_padded, -1) # Shape: (batch_size, conv_dim, 1) + + # Roll the existing state left by 1 + rolled_state = mx.roll(self.conv_state, shift=-1, axis=-1) + + # Create update mask for the last position + update_pos = self.conv_kernel - 2 + state_idx = mx.arange(self.conv_kernel - 1) + update_mask = state_idx == update_pos + + # Broadcast mask to match state dimensions + update_mask = mx.broadcast_to( + mx.reshape(update_mask, (1, 1, -1)), + rolled_state.shape + ) + + # Update state with padded input + x_broadcast = mx.broadcast_to(x_expanded, (self.batch_size, self.conv_dim, 1)) + self.conv_state = mx.where( + update_mask, + x_broadcast, + rolled_state + ) + + # Return concatenated state for convolution + return mx.concatenate([self.conv_state, x_expanded], axis=-1) - def __setitem__(self, idx: int, value: Optional[mx.array]): - if idx == 0: - self.conv_states = value - elif idx == 1: - self.ssm_states = value - else: - raise IndexError("Cache index must be 0 or 1") \ No newline at end of file + def update_ssm_state(self, dA: mx.array, dBx: mx.array) -> mx.array: + """ + Update SSM state for incremental inference. 
+ Args: + dA: State transition tensor of shape (batch_size, num_heads) + dBx: Input projection tensor of shape (batch_size, num_heads, head_dim, state_size) + Returns: + Updated SSM state of shape (batch_size, num_heads, head_dim, state_size) + """ + # Add necessary dimensions to dA for broadcasting + # dA shape: (batch_size, num_heads) -> (batch_size, num_heads, 1, 1) + dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) + + # Ensure dBx has the correct shape + assert dBx.shape[-1] == self.state_size, f"dBx state dimension mismatch: {dBx.shape[-1]} vs {self.state_size}" + assert dBx.shape[-2] == self.head_dim, f"dBx head dimension mismatch: {dBx.shape[-2]} vs {self.head_dim}" + + # Update state: state = dA * state + dBx + self.ssm_state = dA * self.ssm_state + dBx + + return self.ssm_state + + @classmethod + def get_cache( + cls, + args, + batch_size: int, + max_seq_length: Optional[int] + ) -> "Mamba2Cache": + """Create a new cache instance with the given parameters.""" + return cls( + batch_size=batch_size, + intermediate_size=args.intermediate_size, + state_size=args.state_size, + conv_kernel=args.conv_kernel, + num_heads=args.num_heads, + head_dim=args.head_dim + ) \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2 copy.py b/llms/mlx_lm/models/mamba2 copy.py index 9c3bb22d..0d4dedb2 100644 --- a/llms/mlx_lm/models/mamba2 copy.py +++ b/llms/mlx_lm/models/mamba2 copy.py @@ -258,3 +258,403 @@ class Model(nn.Module): @property def layers(self): return self.backbone.layers + + + + + +# ------ + + + +import math +from dataclasses import dataclass, field +from typing import Tuple, Union +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs +from .cache import Mamba2Cache + +@dataclass +class ModelArgs(BaseModelArgs): + num_heads: int + head_dim: int + vocab_size: int + hidden_size: int + state_size: int + num_hidden_layers: int + layer_norm_epsilon: float + expand: int + conv_kernel: int + n_groups: int + use_bias: bool + use_conv_bias: bool + initializer_range: float + residual_in_fp32: bool + time_step_min: float + time_step_max: float + time_step_floor: float + rescale_prenorm_residual: bool + rms_norm: bool + chunk_size: int + tie_word_embeddings: bool + use_cache: bool = True + time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) + time_step_rank: Union[int, str] = "auto" + model_type: str = "mamba2" + + def __post_init__(self): + if not hasattr(self, "intermediate_size"): + self.intermediate_size = int(self.expand * self.hidden_size) + if not hasattr(self, "head_dim"): + self.head_dim = self.hidden_size // self.num_heads + if self.time_step_rank == "auto": + self.time_step_rank = math.ceil(self.hidden_size / 16) + + +class MambaRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = mx.ones((hidden_size,)) + self.variance_epsilon = eps + + def __call__(self, hidden_states, gate=None): + if gate is not None: + hidden_states = hidden_states * nn.silu(gate) + variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True) + hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states + + +def silu(x): + return x * mx.sigmoid(x) + +def ssd(x, A, B, C, chunk_size): + # Replace einsum operations with explicit reshape and matrix multiply + batch, seqlen, nheads, dim = x.shape + B = mx.expand_dims(B, axis=2) + C = mx.expand_dims(C, axis=2) + + state = mx.zeros((batch, nheads, dim, B.shape[-1])) + outputs = [] + + for i in 
range(0, seqlen, chunk_size): + chunk = slice(i, min(i + chunk_size, seqlen)) + dA = mx.exp(mx.expand_dims(A[chunk], axis=0)) + + # Replace einsum with explicit operations + x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim] + x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size] + B_chunk = B[:, chunk] # [batch, chunk_size, state_size] + dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size] + + state = state * mx.expand_dims(dA, axis=-1) + dBx + + # Replace einsum with explicit operations + C_chunk = C[:, chunk] # [batch, chunk_size, state_size] + y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size] + y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim] + outputs.append(y) + + return mx.concatenate(outputs, axis=1), state + + +class DepthWiseConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.padding = padding + self.groups = groups if groups is not None else in_channels + + assert in_channels == out_channels, "In and out channels must be same for depthwise convolution" + assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution" + + self.weight = mx.random.normal((in_channels, 1, kernel_size)) + self.bias = mx.zeros((out_channels,)) if bias else None + + def __call__(self, x: mx.array, cache=None) -> mx.array: + B, L, C = x.shape + K = self.kernel_size + + assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}" + + if cache is not None and 'conv_states' in cache: + conv_states = cache['conv_states'] + if conv_states is not None: + assert conv_states.shape[0] == B, "Cache batch size mismatch" + assert conv_states.shape[2] == C, "Cache channel count mismatch" + x = mx.concatenate([conv_states, x], axis=1) + + # Process each channel independently + outputs = [] + for c in range(C): + x_c = x[:, :, c] + x_c = mx.expand_dims(x_c, axis=1) + + w_c = self.weight[c] + if w_c.ndim == 2: + w_c = mx.expand_dims(w_c, axis=0) + elif w_c.ndim == 1: + w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0) + + # Apply convolution + y_c = mx.conv_general( + x_c, + w_c, + stride=1, + padding=0 + ) + + if self.bias is not None: + y_c = y_c + self.bias[c] + + outputs.append(mx.squeeze(y_c, axis=1)) + + y = mx.stack(outputs, axis=-1) + + # Update cache + if cache is not None: + cache['conv_states'] = x[:, -K+1:, :] if x.shape[1] >= K else x + + return y + + +class Mamba2Block(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads + self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias) + + conv_dim = args.intermediate_size + 2 * args.state_size + self.conv1d = DepthWiseConv1d( + in_channels=conv_dim, + out_channels=conv_dim, + kernel_size=args.conv_kernel, + groups=conv_dim, + bias=args.use_conv_bias, + padding=args.conv_kernel - 1 + ) + + self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range + self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range + self.D = mx.random.normal((args.num_heads,)) * args.initializer_range + + self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon) + self.out_proj = nn.Linear(args.intermediate_size, 
args.hidden_size, bias=args.use_bias) + + if args.rescale_prenorm_residual: + layer_scale = math.sqrt(1.0 / args.num_hidden_layers) + self.out_proj.weight = self.out_proj.weight * layer_scale + + def __call__(self, x: mx.array, cache=None): + if cache is not None: + return self.step(x, cache) + + # Regular forward pass code remains the same... + d_model = self.args.intermediate_size + d_state = self.args.state_size + n_heads = self.args.num_heads + + A = -mx.exp(self.A_log) + zxbcdt = self.in_proj(x) + + splits = [d_model, d_model + 2 * d_state, n_heads] + z = zxbcdt[:, :, :splits[0]] + xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]] + dt = zxbcdt[:, :, -splits[2]:] + + dt = mx.clip( + nn.softplus(dt + self.dt_bias), + self.args.time_step_min, + self.args.time_step_max + ) + dt = mx.maximum(dt, self.args.time_step_floor) + + xBC = silu(self.conv1d(xBC)) + + x = xBC[:, :, :d_model] + B = xBC[:, :, d_model:d_model + d_state] + C = xBC[:, :, -d_state:] + + b, l, hp = x.shape + h = self.args.num_heads + p = hp // h + x = mx.reshape(x, (b, l, h, p)) + + y, ssm_state = ssd(x * mx.expand_dims(dt, -1), A * dt, B, C, self.args.chunk_size) + y = y + x * mx.expand_dims(self.D, -1) + y = mx.reshape(y, (b, l, h * p)) + + y = self.norm(y + z) + y = self.out_proj(y) + + if self.args.residual_in_fp32: + y = y.astype(mx.float32) + + return y + + def step(self, u: mx.array, cache): + batch_size = u.shape[0] + seq_len = u.shape[1] + outputs = [] + + # Initialize cache if needed + if cache.conv_states is None: + conv_dim = self.args.intermediate_size + 2 * self.args.state_size + cache.conv_states = mx.zeros(( + batch_size, + self.args.conv_kernel - 1, + conv_dim + )) + + if cache.ssm_state is None: + cache.ssm_state = mx.zeros(( + batch_size, + self.args.num_heads, + self.args.head_dim, + self.args.state_size + )) + + for pos in range(seq_len): + u_t = u[:, pos:pos+1, :] + zxbcdt = self.in_proj(u_t) + + d_model = self.args.intermediate_size + d_state = self.args.state_size + n_heads = self.args.num_heads + + z = zxbcdt[:, :, :d_model] + xBC = zxbcdt[:, :, d_model:d_model + 2*d_state + d_model] + dt = zxbcdt[:, :, -(n_heads):] + + dt = mx.reshape(dt, (batch_size, n_heads)) + dt = mx.clip( + nn.softplus(dt + self.dt_bias), + self.args.time_step_min, + self.args.time_step_max + ) + dt = mx.maximum(dt, self.args.time_step_floor) + + # Create a temporary cache dictionary for the convolution + conv_cache = {'conv_states': cache.conv_states} + xBC = self.conv1d(xBC, cache=conv_cache) + cache.conv_states = conv_cache['conv_states'] + + xBC = silu(xBC) + + x = xBC[:, :, :d_model] + B = xBC[:, :, d_model:d_model + d_state] + C = xBC[:, :, -d_state:] + + x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim)) + x = mx.squeeze(x, axis=1) + + B = mx.reshape(B, (batch_size, 1, d_state)) + B = mx.broadcast_to(B, (batch_size, n_heads, d_state)) + B = mx.expand_dims(B, axis=2) + + C = mx.reshape(C, (batch_size, 1, d_state)) + C = mx.broadcast_to(C, (batch_size, n_heads, d_state)) + C = mx.expand_dims(C, axis=3) + + A = -mx.exp(self.A_log) + dA = mx.exp(dt * mx.expand_dims(A, 0)) + dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) + + x = mx.expand_dims(x, axis=3) + dBx = mx.matmul(x, B) + + cache.ssm_state = cache.ssm_state * dA + dBx + + y = mx.matmul(cache.ssm_state, C) + y = mx.squeeze(y, axis=-1) + + y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1) + + y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim)) + y = self.norm(y + z) + y = self.out_proj(y) + + if self.args.residual_in_fp32: + y = 
y.astype(mx.float32) + + outputs.append(y) + + return mx.concatenate(outputs, axis=1) + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = Mamba2Block(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, x: mx.array, cache): + return self.mixer(self.norm(x), cache) + x + + +class Mamba2(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + + def __call__(self, x: mx.array, cache): + x = self.embeddings(x) + if cache is None: + cache = [None] * len(self.layers) + for layer, c in zip(self.layers, cache): + x = layer(x, c) + return self.norm_f(x) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + + self.backbone = Mamba2(args) + + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None): + B, T = inputs.shape + + x = self.backbone(inputs, cache) + + if self.args.tie_word_embeddings: + logits = self.backbone.embeddings.as_linear(x) + else: + logits = self.lm_head(x) + + return logits + + def make_cache(self): + return [Mamba2Cache() for _ in range(len(self.layers))] + + def sanitize(self, weights): + sanitized = {} + for k, v in weights.items(): + if "conv1d.weight" in k: + # Ensure weights are in correct shape (channels, 1, kernel_size) + if v.ndim == 2: + v = mx.expand_dims(v, axis=1) + elif v.ndim == 1: + v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0) + sanitized[k] = v + else: + sanitized[k] = v + return sanitized + + @property + def layers(self): + return self.backbone.layers diff --git a/llms/mlx_lm/models/mamba2-prch.py b/llms/mlx_lm/models/mamba2-prch.py index f9bd6797..f988a825 100644 --- a/llms/mlx_lm/models/mamba2-prch.py +++ b/llms/mlx_lm/models/mamba2-prch.py @@ -88,6 +88,32 @@ class Mamba2LMHeadModel(nn.Module): ) self.lm_head.weight = self.backbone.embedding.weight + @staticmethod + def from_pretrained(huggingface_model_id: str, device: Device = None): + from transformers.utils import CONFIG_NAME, WEIGHTS_NAME + from transformers.utils.hub import cached_file + + config_path = cached_file(huggingface_model_id, CONFIG_NAME) + assert config_path, "Failed to get huggingface config file" + state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME) + assert state_dict_path, "Failed to get huggingface state dict file" + + config = json.load(open(config_path)) + args = Mamba2Config( + d_model=config["d_model"], + n_layer=config["n_layer"], + vocab_size=config["vocab_size"], + pad_vocab_size_multiple=config["pad_vocab_size_multiple"], + ) + + map_location = "cpu" if device is None else device + state_dict = torch.load( + state_dict_path, weights_only=True, map_location=map_location, mmap=True + ) + model = Mamba2LMHeadModel(args, device=device) + model.load_state_dict(state_dict) + model.eval() + return model def forward( self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None @@ -193,7 +219,6 @@ class Mamba2(nn.Module): self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device)) self.A_log = nn.Parameter(torch.empty(args.nheads, device=device)) self.D = nn.Parameter(torch.empty(args.nheads, device=device)) - 
self.norm = RMSNorm(args.d_inner, device=device) self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device) diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index 01f9485b..cb78f316 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -1,6 +1,7 @@ import math from dataclasses import dataclass, field -from typing import Tuple, Union +from typing import Optional, Tuple, Union + import mlx.core as mx import mlx.nn as nn @@ -27,10 +28,10 @@ class ModelArgs(BaseModelArgs): time_step_max: float time_step_floor: float rescale_prenorm_residual: bool + use_cache: bool rms_norm: bool chunk_size: int tie_word_embeddings: bool - use_cache: bool = True time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) time_step_rank: Union[int, str] = "auto" model_type: str = "mamba2" @@ -43,114 +44,62 @@ class ModelArgs(BaseModelArgs): if self.time_step_rank == "auto": self.time_step_rank = math.ceil(self.hidden_size / 16) + +def selective_scan(x, A, B, C, chunk_size): + """ + Selective scan implementation for training. -class MambaRMSNormGated(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = mx.ones((hidden_size,)) - self.variance_epsilon = eps + Arguments + x: (batch, seqlen, n_heads, d_head) + A: (batch, seqlen, n_heads) + B: (batch, seqlen, n_heads, d_state) + C: (batch, seqlen, n_heads, d_state) - def __call__(self, hidden_states, gate=None): - if gate is not None: - hidden_states = hidden_states * nn.silu(gate) - variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True) - hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states + Return + y: (batch, seqlen, n_heads, d_head) + """ + assert x.shape[1] % chunk_size == 0 - -def silu(x): - return x * mx.sigmoid(x) - -def ssd(x, A, B, C, chunk_size): - # Replace einsum operations with explicit reshape and matrix multiply - batch, seqlen, nheads, dim = x.shape - B = mx.expand_dims(B, axis=2) - C = mx.expand_dims(C, axis=2) + # Reshape into chunks + def chunk_reshape(m): + shape = list(m.shape) + shape[1:2] = [shape[1] // chunk_size, chunk_size] + return m.reshape(shape) - state = mx.zeros((batch, nheads, dim, B.shape[-1])) - outputs = [] + x, A, B, C = map(chunk_reshape, (x, A, B, C)) + A = mx.transpose(A, [0, 3, 1, 2]) - for i in range(0, seqlen, chunk_size): - chunk = slice(i, min(i + chunk_size, seqlen)) - dA = mx.exp(mx.expand_dims(A[chunk], axis=0)) - - # Replace einsum with explicit operations - x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim] - x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size] - B_chunk = B[:, chunk] # [batch, chunk_size, state_size] - dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size] - - state = state * mx.expand_dims(dA, axis=-1) + dBx - - # Replace einsum with explicit operations - C_chunk = C[:, chunk] # [batch, chunk_size, state_size] - y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size] - y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim] - outputs.append(y) + # Compute cumulative sums + A_cumsum = mx.cumsum(A, axis=-1) - return mx.concatenate(outputs, axis=1), state + # Process chunks + L = mx.exp(selective_cumsum(A)) + Y_diag = mx.einsum('bclhn,bcshn,bhcls,bcshp->bclhp', C, B, L, x) + decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum) + states = mx.einsum('bclhn,bhcl,bclhp->bchpn', B, decay_states, x) + + 
initial_states = mx.zeros_like(states[:, :1]) + states = mx.concatenate([initial_states, states], axis=1) + decay_chunk = mx.exp(selective_cumsum(mx.pad(A_cumsum[..., -1], ((0,0), (0,0), (1,0))))) + new_states = mx.einsum('bhzc,bchpn->bzhpn', decay_chunk, states) + states = new_states[:, :-1] + + state_decay_out = mx.exp(A_cumsum) + Y_off = mx.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out) + + Y = (Y_diag + Y_off).reshape((-1, x.shape[1] * chunk_size, *Y_diag.shape[-2:])) + return Y -class DepthWiseConv1d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.padding = padding - self.groups = groups if groups is not None else in_channels - - assert in_channels == out_channels, "In and out channels must be same for depthwise convolution" - assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution" - - self.weight = mx.random.normal((in_channels, 1, kernel_size)) - self.bias = mx.zeros((out_channels,)) if bias else None - - def __call__(self, x: mx.array, cache=None) -> mx.array: - B, L, C = x.shape - K = self.kernel_size - - assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}" - - if cache is not None and 'conv_states' in cache: - conv_states = cache['conv_states'] - if conv_states is not None: - assert conv_states.shape[0] == B, "Cache batch size mismatch" - assert conv_states.shape[2] == C, "Cache channel count mismatch" - x = mx.concatenate([conv_states, x], axis=1) - - # Process each channel independently - outputs = [] - for c in range(C): - x_c = x[:, :, c] - x_c = mx.expand_dims(x_c, axis=1) - - w_c = self.weight[c] - if w_c.ndim == 2: - w_c = mx.expand_dims(w_c, axis=0) - elif w_c.ndim == 1: - w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0) - - # Apply convolution - y_c = mx.conv_general( - x_c, - w_c, - stride=1, - padding=0 - ) - - if self.bias is not None: - y_c = y_c + self.bias[c] - - outputs.append(mx.squeeze(y_c, axis=1)) - - y = mx.stack(outputs, axis=-1) - - # Update cache - if cache is not None: - cache['conv_states'] = x[:, -K+1:, :] if x.shape[1] >= K else x - - return y +def selective_cumsum(x: mx.array) -> mx.array: + """Stable selective cumulative sum calculation.""" + T = x.shape[-1] + x = mx.repeat(x[..., None], T, axis=-1) + mask = mx.tril(mx.ones((T, T)), k=-1) + x = x * mask + x_cumsum = mx.cumsum(x, axis=-2) + mask = mx.tril(mx.ones((T, T)), k=0) + return mx.where(mask, x_cumsum, float('-inf')) class Mamba2Block(nn.Module): @@ -158,165 +107,172 @@ class Mamba2Block(nn.Module): super().__init__() self.args = args - d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads - self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias) + # Internal cache state + self.conv_state = None + self.ssm_state = None + + # Project input to get various components + d_in_proj = (2 * args.intermediate_size + 2 * self.args.n_groups * args.state_size + args.num_heads) + self.in_proj = nn.Linear( + args.hidden_size, + d_in_proj, + bias=args.use_bias + ) - conv_dim = args.intermediate_size + 2 * args.state_size - self.conv1d = DepthWiseConv1d( + # Convolution layer + conv_dim = args.intermediate_size + 2 * self.args.n_groups * args.state_size + self.conv1d = nn.Conv1d( in_channels=conv_dim, out_channels=conv_dim, kernel_size=args.conv_kernel, groups=conv_dim, - 
bias=args.use_conv_bias, - padding=args.conv_kernel - 1 + padding=args.conv_kernel - 1, + bias=args.use_conv_bias ) - self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range - self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range - self.D = mx.random.normal((args.num_heads,)) * args.initializer_range + # SSM parameters + dt_init_floor = math.log(args.time_step_floor) + self.dt_bias = mx.zeros((args.num_heads,)) * args.initializer_range + self.A_log = mx.zeros((args.num_heads,)) * args.initializer_range + self.D = mx.zeros((args.num_heads,)) * args.initializer_range - self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon) + # Output projections + self.norm = nn.RMSNorm(args.intermediate_size, eps=args.layer_norm_epsilon) self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) - if args.rescale_prenorm_residual: - layer_scale = math.sqrt(1.0 / args.num_hidden_layers) - self.out_proj.weight = self.out_proj.weight * layer_scale + def __call__(self, x: mx.array, cache=None) -> mx.array: + return self.forward_training(x) if x.shape[1] > 1 else self.forward_inference(x, cache) - def __call__(self, x: mx.array, cache=None): - if cache is not None: - return self.step(x, cache) - - # Regular forward pass code remains the same... - d_model = self.args.intermediate_size - d_state = self.args.state_size - n_heads = self.args.num_heads + def forward_training(self, u: mx.array) -> mx.array: + # Reset cache during training + self.cache = None - A = -mx.exp(self.A_log) - zxbcdt = self.in_proj(x) - - splits = [d_model, d_model + 2 * d_state, n_heads] - z = zxbcdt[:, :, :splits[0]] - xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]] - dt = zxbcdt[:, :, -splits[2]:] + # Input projection and splitting + zxbcdt = self.in_proj(u) + z, xBC, dt = mx.split( + zxbcdt, + [ + self.args.intermediate_size, + self.args.intermediate_size + 2 * self.args.state_size + ], + axis=-1 + ) + # Time step processing dt = mx.clip( nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max ) - dt = mx.maximum(dt, self.args.time_step_floor) - xBC = silu(self.conv1d(xBC)) + # Convolution processing + xBC_t = mx.transpose(xBC, [0, 2, 1]) + conv_out = self.conv1d(xBC_t) + xBC = mx.transpose(conv_out, [0, 2, 1])[:, :u.shape[1]] + xBC = mx.sigmoid(xBC) * xBC # SiLU - x = xBC[:, :, :d_model] - B = xBC[:, :, d_model:d_model + d_state] - C = xBC[:, :, -d_state:] + # Split states + x, B, C = mx.split( + xBC, + [self.args.intermediate_size, self.args.state_size], + axis=-1 + ) - b, l, hp = x.shape - h = self.args.num_heads - p = hp // h - x = mx.reshape(x, (b, l, h, p)) + # Reshape for selective scan + x = x.reshape((-1, x.shape[1], self.args.num_heads, self.args.head_dim)) + A = -mx.exp(self.A_log) - y, ssm_state = ssd(x * mx.expand_dims(dt, -1), A * dt, B, C, self.args.chunk_size) - y = y + x * mx.expand_dims(self.D, -1) - y = mx.reshape(y, (b, l, h * p)) + # Apply selective scan + y = selective_scan( + x * dt[..., None], + A * dt, + B[..., None, :], + C[..., None, :], + self.args.chunk_size + ) - y = self.norm(y + z) + # Output processing + y = y + x * self.D[None, None, :, None] + y = y.reshape((-1, y.shape[1], self.args.intermediate_size)) + y = self.norm(y, z) y = self.out_proj(y) - - if self.args.residual_in_fp32: - y = y.astype(mx.float32) - + return y - def step(self, u: mx.array, cache): + def forward_inference(self, u: mx.array, cache=None) -> mx.array: + """Single token processing during inference.""" + assert 
u.shape[1] == 1, "Inference mode expects single token" + batch_size = u.shape[0] - seq_len = u.shape[1] - outputs = [] + # Use provided cache or create new one + self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None) + + # Project input + zxbcdt = self.in_proj(mx.squeeze(u, 1)) + parts = mx.split( + zxbcdt, + [ + self.args.intermediate_size, + self.args.intermediate_size + 2 * self.args.state_size + ], + axis=-1 + ) + z, xBC = parts[0], parts[1] + dt = zxbcdt[:, -self.args.num_heads:] # Extract dt separately - # Initialize cache if needed - if cache.conv_states is None: - conv_dim = self.args.intermediate_size + 2 * self.args.state_size - cache.conv_states = mx.zeros(( - batch_size, - self.args.conv_kernel - 1, - conv_dim - )) - - if cache.ssm_state is None: - cache.ssm_state = mx.zeros(( - batch_size, - self.args.num_heads, - self.args.head_dim, - self.args.state_size - )) + # Update convolution state and apply + conv_state = self.cache.update_conv_state(xBC) + xBC = mx.sum( + conv_state * mx.transpose(self.conv1d.weight, [1, 0, 2]), + axis=-1 + ) + if self.args.use_conv_bias: + xBC = xBC + self.conv1d.bias + xBC = mx.sigmoid(xBC) * xBC # SiLU - for pos in range(seq_len): - u_t = u[:, pos:pos+1, :] - zxbcdt = self.in_proj(u_t) - - d_model = self.args.intermediate_size - d_state = self.args.state_size - n_heads = self.args.num_heads - - z = zxbcdt[:, :, :d_model] - xBC = zxbcdt[:, :, d_model:d_model + 2*d_state + d_model] - dt = zxbcdt[:, :, -(n_heads):] - - dt = mx.reshape(dt, (batch_size, n_heads)) - dt = mx.clip( - nn.softplus(dt + self.dt_bias), - self.args.time_step_min, - self.args.time_step_max - ) - dt = mx.maximum(dt, self.args.time_step_floor) + # Split states and ensure proper shapes + x_splits = mx.split( + xBC, + [self.args.intermediate_size, self.args.state_size], + axis=-1 + ) + x, B, C = x_splits[0], x_splits[1], x_splits[2] + + # Process time steps - ensure proper broadcasting + dt = mx.reshape(dt, (batch_size, self.args.num_heads)) + dt = mx.clip( + nn.softplus(dt + self.dt_bias[None, :]), + self.args.time_step_min, + self.args.time_step_max + ) + + # SSM step with explicit shapes + A = -mx.exp(self.A_log) + dA = mx.exp(dt * A[None, :]) # Shape: (batch_size, num_heads) + + # Reshape x considering intermediate size + # x shape should be (batch_size * num_heads, head_dim) + x = mx.reshape(x, (batch_size, self.args.num_heads, -1)) + assert x.shape[-1] == self.args.head_dim, f"Head dimension mismatch: {x.shape[-1]} vs {self.args.head_dim}" + + # Reshape B and C for ssm computation + B = mx.reshape(B, (batch_size, -1)) # Should be (batch_size, state_size) + C = mx.reshape(C, (batch_size, -1)) # Should be (batch_size, state_size) + + # Compute dBx with explicit shapes + dBx = mx.einsum('bh,bs,bhd->bhds', dt, B, x) + + ssm_state = self.cache.update_ssm_state(dA, dBx) + + y = mx.einsum('bhds,bs->bhd', ssm_state, C) + y = y + x * self.D[None, :, None] + y = mx.reshape(y, (batch_size, self.args.intermediate_size)) + + # Output processing + y = self.norm(y, z) + y = self.out_proj(y) - # Create a temporary cache dictionary for the convolution - conv_cache = {'conv_states': cache.conv_states} - xBC = self.conv1d(xBC, cache=conv_cache) - cache.conv_states = conv_cache['conv_states'] - - xBC = silu(xBC) - - x = xBC[:, :, :d_model] - B = xBC[:, :, d_model:d_model + d_state] - C = xBC[:, :, -d_state:] - - x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim)) - x = mx.squeeze(x, axis=1) - - B = mx.reshape(B, (batch_size, 1, d_state)) - B = 
mx.broadcast_to(B, (batch_size, n_heads, d_state)) - B = mx.expand_dims(B, axis=2) - - C = mx.reshape(C, (batch_size, 1, d_state)) - C = mx.broadcast_to(C, (batch_size, n_heads, d_state)) - C = mx.expand_dims(C, axis=3) - - A = -mx.exp(self.A_log) - dA = mx.exp(dt * mx.expand_dims(A, 0)) - dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) - - x = mx.expand_dims(x, axis=3) - dBx = mx.matmul(x, B) - - cache.ssm_state = cache.ssm_state * dA + dBx - - y = mx.matmul(cache.ssm_state, C) - y = mx.squeeze(y, axis=-1) - - y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1) - - y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim)) - y = self.norm(y + z) - y = self.out_proj(y) - - if self.args.residual_in_fp32: - y = y.astype(mx.float32) - - outputs.append(y) - - return mx.concatenate(outputs, axis=1) + return mx.expand_dims(y, 1) class ResidualBlock(nn.Module): @@ -325,11 +281,11 @@ class ResidualBlock(nn.Module): self.mixer = Mamba2Block(args) self.norm = nn.RMSNorm(args.hidden_size) - def __call__(self, x: mx.array, cache): + def __call__(self, x: mx.array, cache=None) -> mx.array: return self.mixer(self.norm(x), cache) + x -class Mamba2(nn.Module): +class Mamba2Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args @@ -337,12 +293,12 @@ class Mamba2(nn.Module): self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - def __call__(self, x: mx.array, cache): + def __call__(self, x: mx.array, cache=None) -> mx.array: x = self.embeddings(x) if cache is None: cache = [None] * len(self.layers) - for layer, c in zip(self.layers, cache): - x = layer(x, c) + for layer, layer_cache in zip(self.layers, cache): + x = layer(x, layer_cache) return self.norm_f(x) @@ -350,14 +306,12 @@ class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args - self.model_type = args.model_type + self.backbone = Mamba2Model(args) - self.backbone = Mamba2(args) - if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - def __call__(self, inputs: mx.array, cache=None): + def __call__(self, inputs: mx.array, cache=None) -> mx.array: B, T = inputs.shape x = self.backbone(inputs, cache) @@ -368,24 +322,19 @@ class Model(nn.Module): logits = self.lm_head(x) return logits - + def make_cache(self, batch_size=1): - return [Mamba2Cache() for _ in range(len(self.layers))] + return [Mamba2Cache( + batch_size=batch_size, + intermediate_size=self.args.intermediate_size, + state_size=self.args.state_size, + conv_kernel=self.args.conv_kernel, + num_heads=self.args.num_heads, + head_dim=self.args.head_dim + ) for _ in range(len(self.backbone.layers))] def sanitize(self, weights): - sanitized = {} for k, v in weights.items(): - if "conv1d.weight" in k: - # Ensure weights are in correct shape (channels, 1, kernel_size) - if v.ndim == 2: - v = mx.expand_dims(v, axis=1) - elif v.ndim == 1: - v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0) - sanitized[k] = v - else: - sanitized[k] = v - return sanitized - - @property - def layers(self): - return self.backbone.layers + if "conv1d.weight" in k and v.ndim == 3: + weights[k] = v.moveaxis(2, 1) + return weights \ No newline at end of file
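
For reviewers, a minimal sketch of the decode-time state handling introduced above: it mirrors the buffer shapes allocated by Mamba2Cache.__init__, the update_ssm_state recurrence state = dA * state + dBx as driven from forward_inference, and the conv1d.weight layout conversion performed in Model.sanitize. The dimensions below are illustrative assumptions rather than values from the patch, and the snippet is standalone (plain mlx.core) instead of importing the modified modules.

```python
import mlx.core as mx

# Illustrative sizes only (stand-ins for the ModelArgs fields the cache reads).
batch_size = 1
intermediate_size = 8          # expand * hidden_size
state_size = 4
conv_kernel = 4
num_heads = 2
head_dim = intermediate_size // num_heads

# Buffer shapes as allocated in Mamba2Cache.__init__.
conv_dim = intermediate_size + 2 * state_size
conv_state = mx.zeros((batch_size, conv_dim, conv_kernel - 1))
ssm_state = mx.zeros((batch_size, num_heads, head_dim, state_size))

# One decode step of the recurrence applied by update_ssm_state:
#   ssm_state <- dA * ssm_state + dBx
dt = mx.ones((batch_size, num_heads)) * 0.1            # stands in for softplus(dt + dt_bias)
A = -mx.exp(mx.zeros((num_heads,)))                     # A = -exp(A_log)
dA = mx.exp(dt * A[None, :])                            # (batch, heads)
x = mx.random.normal((batch_size, num_heads, head_dim))
B = mx.random.normal((batch_size, state_size))
C = mx.random.normal((batch_size, state_size))
D = mx.zeros((num_heads,))

dBx = mx.einsum("bh,bs,bhd->bhds", dt, B, x)            # (batch, heads, head_dim, state)
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)         # broadcast to (batch, heads, 1, 1)
ssm_state = dA * ssm_state + dBx

# Readout used in forward_inference, plus the skip connection through D.
y = mx.einsum("bhds,bs->bhd", ssm_state, C) + x * D[None, :, None]
print(y.shape)                                          # (1, 2, 4)

# Weight layout handled by Model.sanitize: PyTorch Conv1d checkpoints store
# (out_channels, in_channels // groups, kernel_size); MLX's Conv1d keeps the
# kernel axis second, hence the moveaxis(2, 1) on 3-D "conv1d.weight" tensors.
pt_weight = mx.zeros((conv_dim, 1, conv_kernel))
mlx_weight = pt_weight.moveaxis(2, 1)
print(mlx_weight.shape)                                 # (16, 4, 1)
```

The conv_state buffer holds conv_kernel - 1 columns because it caches the previous K-1 projected inputs; update_conv_state rolls the buffer and appends the current token's column, so each decode step only convolves against one new slice of input.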