From a68334445009db5753c58c5ab99950b0f4e2196c Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 26 Feb 2025 14:46:46 +0100 Subject: [PATCH] correct segsum function --- llms/mlx_lm/models/mamba2 copy.py | 899 +++++++++++++++++++++++++++ llms/mlx_lm/models/mamba2.py | 55 +- llms/mlx_lm/models/mamba2_pytorch.py | 11 - 3 files changed, 939 insertions(+), 26 deletions(-) create mode 100644 llms/mlx_lm/models/mamba2 copy.py diff --git a/llms/mlx_lm/models/mamba2 copy.py b/llms/mlx_lm/models/mamba2 copy.py new file mode 100644 index 00000000..5fa503d6 --- /dev/null +++ b/llms/mlx_lm/models/mamba2 copy.py @@ -0,0 +1,899 @@ +import math +from dataclasses import dataclass, field +from typing import Tuple, Union +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs +from .cache import MambaCache + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + num_heads: int + head_dim: int + vocab_size: int + hidden_size: int + state_size: int + num_hidden_layers: int + layer_norm_epsilon: float + expand: int + conv_kernel: int + n_groups: int + use_bias: bool + use_conv_bias: bool + initializer_range: float + residual_in_fp32: bool + chunk_size: int + tie_word_embeddings: bool + time_step_limit: Tuple[float, float] + time_step_rank: Union[int, str] + time_step_min: float + time_step_max: float + time_step_floor: float + norm_before_gate: bool = True + + def __post_init__(self): + if not hasattr(self, "intermediate_size"): + self.intermediate_size = int(self.expand * self.hidden_size) + if not hasattr(self, "head_dim"): + self.head_dim = self.hidden_size // self.num_heads + if self.time_step_rank == "auto": + self.time_step_rank = math.ceil(self.hidden_size / 16) + + +def segsum(x): + """Stable segment sum calculation. + + `exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM. + """ + T = x.shape[-1] + x = mx.expand_dims(x, -1) + x = mx.repeat(x, T, axis=-1) + mask = mx.tril(mx.ones((T, T), dtype=mx.bool_), k=-1) + x = mx.where(mask, x, 0) + x_segsum = mx.cumsum(x, axis=-2) + mask = mx.tril(mx.ones((T, T), dtype=mx.bool_), k=0) + x_segsum = mx.where(mask, x_segsum, -mx.inf) + return x_segsum + +def ssd(x, A, B, C, chunk_size, initial_states=None): + """Structured State Space Duality (SSD) - the core of Mamba-2 + + Arguments + x: (batch, seqlen, n_heads, d_head) + A: (batch, seqlen, n_heads) + B: (batch, seqlen, n_heads, d_state) + C: (batch, seqlen, n_heads, d_state) + + Return + y: (batch, seqlen, n_heads, d_head) + final_state: final state for inference + """ + assert x.shape[1] % chunk_size == 0 + + # Rearrange into chunks + def rearrange_to_chunks(m): + shape = list(m.shape) + shape[1:2] = [shape[1] // chunk_size, chunk_size] + return m.reshape(shape) + + x_chunked = rearrange_to_chunks(x) + A_chunked = rearrange_to_chunks(A) + B_chunked = rearrange_to_chunks(B) + C_chunked = rearrange_to_chunks(C) + + # Transpose A for easier cumsum + A_chunked = mx.transpose(A_chunked, (0, 3, 1, 2)) # b c l h -> b h c l + A_cumsum = mx.cumsum(A_chunked, axis=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + L = mx.exp(segsum(A_chunked)) + Y_diag = mx.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C_chunked, B_chunked, L, x_chunked) + + # 2. Compute the state for each intra-chunk + decay_states = mx.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + states = mx.einsum("bclhn,bhcl,bclhp->bchpn", B_chunked, decay_states, x_chunked) + + # 3. Compute the inter-chunk SSM recurrence + if initial_states is None: + initial_states = mx.zeros_like(states[:, :1]) + states = mx.concatenate([initial_states, states], axis=1) + + A_cumsum_last = A_cumsum[:, :, :, -1] + A_cumsum_padded = mx.pad(A_cumsum_last, [(0, 0), (0, 0), (1, 0)]) + decay_chunk = mx.exp(segsum(A_cumsum_padded)) + new_states = mx.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + state_decay_out = mx.exp(A_cumsum) + Y_off = mx.einsum("bclhn,bchpn,bhcl->bclhp", C_chunked, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms + Y_combined = Y_diag + Y_off + + # Reshape back to original sequence shape + batch, chunks, chunk_len, heads, head_dim = Y_combined.shape + Y = Y_combined.reshape(batch, chunks * chunk_len, heads, head_dim) + + return Y, final_state + +def silu(x): + """Applies the Sigmoid Linear Unit (SiLU), element-wise.""" + return x * mx.sigmoid(x) + + +class Mamba2Block(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.d_model = args.hidden_size + self.d_state = args.state_size + self.d_conv = args.conv_kernel + self.expand = args.expand + self.d_inner = int(self.expand * self.d_model) + self.n_groups = args.n_groups + self.n_heads = args.num_heads + self.d_head = self.d_inner // self.n_heads + self.chunk_size = args.chunk_size + + d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads + self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias) + + self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range + self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range + self.D = mx.random.normal((self.n_heads,)) * args.initializer_range + + # Use standard Conv1d with groups for depthwise convolution + conv_dim = self.d_inner + 2 * self.n_groups * self.d_state + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + kernel_size=self.d_conv, + groups=conv_dim, # Makes it depthwise + padding=self.d_conv-1, + bias=args.use_conv_bias + ) + + self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon) + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias) + + def __call__(self, u, cache=None): + """ + Arguments + u: (batch, seqlen, d_model) input + cache: Optional tuple of (conv_state, ssm_state) for inference + + Return (y, cache) + y: (batch, seqlen, d_model) output + cache: updated tuple of (conv_state, ssm_state) for inference + """ + if cache is not None: + return self.step(u, cache) + + # Initialize cache if needed + if cache is None: + cache = [None, None] # Initialize with None values + + # Compute projections + zxbcdt = self.in_proj(u) + + # Split projections + d_inner = self.d_inner + d_state = self.n_groups * self.d_state + + z, xBC, dt = mx.split( + zxbcdt, + [d_inner, d_inner + 2 * d_state], + axis=-1 + ) + + # Process dt with softplus + dt = mx.softplus(dt + self.dt_bias) # (batch, seqlen, n_heads) + + # Apply convolution to xBC + xBC_transposed = mx.transpose(xBC, (0, 2, 1)) # (batch, d, seqlen) + xBC_conv = self.conv1d(xBC_transposed) + xBC_conv = mx.transpose(xBC_conv, (0, 2, 1)) # (batch, seqlen, d) + xBC = silu(xBC_conv[:, :u.shape[1], :]) # Ensure we only keep seqlen elements + + # Split xBC into x, B, C + x, B, C = mx.split( + xBC, + [d_inner, d_inner + d_state], + axis=-1 + ) + + # Reshape x for heads + batch, seqlen = x.shape[0], x.shape[1] + x_reshaped = x.reshape(batch, seqlen, self.n_heads, self.d_head) + + # Reshape B and C for SSM + B = B.reshape(batch, seqlen, 1, d_state) + C = C.reshape(batch, seqlen, 1, d_state) + + # Apply SSM with SSD algorithm + A = -mx.exp(self.A_log) # (n_heads,) + A_dt = A * dt # (batch, seqlen, n_heads) + + y, ssm_state = ssd( + x_reshaped * mx.expand_dims(dt, -1), # Scale x by dt + A_dt, + B, + C, + self.chunk_size + ) + + # Apply D and reshape + y = y + x_reshaped * mx.reshape(self.D, (1, 1, self.n_heads, 1)) + y = y.reshape(batch, seqlen, d_inner) + + # Apply norm and gating + y = self.norm(y, z) + + # Final projection + y = self.out_proj(y) + + # Create cache for inference + if seqlen == 1 and cache is not None: + conv_state = mx.zeros((batch, d_inner + 2 * d_state, self.d_conv)) + conv_state = mx.update_slice(conv_state, xBC.reshape(batch, -1, 1), (0, 0, self.d_conv - 1)) + cache[0] = conv_state + cache[1] = ssm_state + + return y, cache + + def step(self, u, cache): + """Take an inference step for the current input and cache + + Arguments + u: (batch, seqlen, d_model) - can be multiple tokens + cache: tuple of (conv_state, ssm_state) + + Return (y, cache) + y: (batch, seqlen, d_model) + cache: updated cache object + """ + batch, seqlen = u.shape[0], u.shape[1] + + # Initialize cache if it's None + if cache[0] is None or cache[1] is None: + d_state = self.n_groups * self.d_state + conv_dim = self.d_inner + 2 * d_state + conv_state = mx.zeros((batch, conv_dim, self.d_conv)) + + # Fix: use correct state size per head + state_per_head = d_state // self.n_heads + ssm_state = mx.zeros((batch, self.n_heads, self.d_head, state_per_head)) + else: + conv_state, ssm_state = cache[0], cache[1] + + # Project input + zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj) + + # Split projections + d_inner = self.d_inner + d_state = self.n_groups * self.d_state + + z, xBC, dt = mx.split( + zxbcdt, + [d_inner, d_inner + 2 * d_state], + axis=-1 + ) + + # Process each token through the convolution sequentially + outputs = [] + for i in range(seqlen): + # Get current token's input + xBC_i = xBC[:, i] # (batch, d_inner + 2*d_state) + dt_i = dt[:, i] # (batch, dt_size) + + # Extract the head-specific dt values + dt_size = dt_i.shape[-1] + + if dt_size % self.n_heads == 0: + # Reshape dt_i to extract the head-specific values + dt_reshaped = dt_i.reshape(batch, self.n_heads, dt_size // self.n_heads) + # Take the first element for each head + dt_heads = dt_reshaped[:, :, 0] + else: + # If we can't reshape, just take the first n_heads elements + dt_heads = dt_i[:, :self.n_heads] + + + # Process dt with softplus + dt_heads = nn.softplus(dt_heads + self.dt_bias.reshape(1, -1)) # (batch, n_heads) + + # Update convolution state + conv_state = mx.roll(conv_state, shift=-1, axis=-1) + + # Use slice_update instead of update_slice + # Reshape xBC_i to match the expected shape for the update + xBC_reshaped = xBC_i.reshape(batch, -1, 1) + # Create start_indices for the update + start_indices = mx.array([0, 0, self.d_conv - 1]) + # Update the conv_state + conv_state = mx.slice_update( + conv_state, + xBC_reshaped, + start_indices, + axes=(0, 1, 2) + ) + + # Apply convolution step + weight = self.conv1d.weight + bias = self.conv1d.bias if self.args.use_conv_bias else None + + xBC_conv = mx.sum(conv_state * weight.reshape(1, -1, self.d_conv), axis=-1) + if bias is not None: + xBC_conv = xBC_conv + bias + + xBC_conv = silu(xBC_conv) + + # Split xBC + x_i, B_i, C_i = mx.split( + xBC_conv, + [d_inner, d_inner + d_state], + axis=-1 + ) + + # Apply SSM step + A = -mx.exp(self.A_log) # (n_heads,) + dA = mx.exp(dt_heads * A) # (batch, n_heads) + + # Reshape x for heads + x_i = x_i.reshape(batch, self.n_heads, self.d_head) + + # Reshape B and C for SSM with correct dimensions + state_per_head = d_state // self.n_heads + B_i_reshaped = B_i.reshape(batch, self.n_heads, state_per_head) + C_i_reshaped = C_i.reshape(batch, self.n_heads, state_per_head) + + # Calculate dBx with the correctly shaped B + dBx = mx.einsum("bhn,bhp->bhpn", B_i_reshaped, x_i * mx.expand_dims(dt_heads, -1)) + + # Update SSM state + ssm_state = ssm_state * mx.reshape(dA, (batch, self.n_heads, 1, 1)) + dBx + + # Calculate output with the correctly shaped C + y_i = mx.einsum("bhpn,bhn->bhp", ssm_state, C_i_reshaped) + + # Apply D and reshape + y_i = y_i + x_i * mx.reshape(self.D, (1, self.n_heads, 1)) + + # Reshape y + y_i = y_i.reshape(batch, d_inner) + + # Apply norm and gating (SwiGLU-like activation) + y_i = self.norm(y_i) # Just normalize without gating + y_i = y_i * nn.sigmoid(z[:, i]) # Apply gating separately + + # Final projection + y_i = self.out_proj(y_i) + + outputs.append(y_i) + + # Stack outputs along sequence dimension + y = mx.stack(outputs, axis=1) # (batch, seqlen, d_model) + + # Update cache + cache[0] = conv_state + cache[1] = ssm_state + + return y + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.residual_in_fp32 = args.residual_in_fp32 + self.mixer = Mamba2Block(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, x: mx.array, cache): + if self.residual_in_fp32: + x = x.astype(mx.float32) + normed = self.norm(x) + output = self.mixer(normed, cache) + return output + x + + +class Mamba2(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + + def __call__(self, x: mx.array, cache): + x = self.embeddings(x) + if cache is None: + cache = [None] * len(self.layers) + + hidden = x + for layer, c in zip(self.layers, cache): + hidden = layer(hidden, c) + return self.norm_f(hidden) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba2(args) + + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None): + hidden = self.backbone(inputs, cache) + + if self.args.tie_word_embeddings: + logits = self.backbone.embeddings.as_linear(hidden) + else: + logits = self.lm_head(hidden) + + return logits + + def make_cache(self): + return [MambaCache() for _ in range(len(self.layers))] + + @property + def layers(self): + return self.backbone.layers + + + + + + + + + + +######################################################## + + + + + +import math +from dataclasses import dataclass, field +from typing import Tuple, Union +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs +from .cache import MambaCache + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + num_heads: int + head_dim: int + vocab_size: int + hidden_size: int + state_size: int + num_hidden_layers: int + layer_norm_epsilon: float + expand: int + conv_kernel: int + n_groups: int + use_bias: bool + use_conv_bias: bool + initializer_range: float + residual_in_fp32: bool + chunk_size: int + tie_word_embeddings: bool + time_step_limit: Tuple[float, float] + time_step_rank: Union[int, str] + time_step_min: float + time_step_max: float + time_step_floor: float + norm_before_gate: bool = True + + def __post_init__(self): + if not hasattr(self, "intermediate_size"): + self.intermediate_size = int(self.expand * self.hidden_size) + if not hasattr(self, "head_dim"): + self.head_dim = self.hidden_size // self.num_heads + if self.time_step_rank == "auto": + self.time_step_rank = math.ceil(self.hidden_size / 16) + + +def segsum(x): + """Stable segment sum calculation. + + `exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM. + """ + T = x.shape[-1] + x = mx.expand_dims(x, -1) + x = mx.repeat(x, T, axis=-1) + mask = mx.tril(mx.ones((T, T), dtype=mx.bool_), k=-1) + x = mx.where(mask, x, 0) + x_segsum = mx.cumsum(x, axis=-2) + mask = mx.tril(mx.ones((T, T), dtype=mx.bool_), k=0) + x_segsum = mx.where(mask, x_segsum, -mx.inf) + return x_segsum + +def ssd(x, A, B, C, chunk_size, initial_states=None): + """Structured State Space Duality (SSD) - the core of Mamba-2 + + Arguments + x: (batch, seqlen, n_heads, d_head) + A: (batch, seqlen, n_heads) + B: (batch, seqlen, n_heads, d_state) + C: (batch, seqlen, n_heads, d_state) + + Return + y: (batch, seqlen, n_heads, d_head) + final_state: final state for inference + """ + assert x.shape[1] % chunk_size == 0 + + # Rearrange into chunks + def rearrange_to_chunks(m): + shape = list(m.shape) + shape[1:2] = [shape[1] // chunk_size, chunk_size] + return m.reshape(shape) + + x_chunked = rearrange_to_chunks(x) + A_chunked = rearrange_to_chunks(A) + B_chunked = rearrange_to_chunks(B) + C_chunked = rearrange_to_chunks(C) + + # Transpose A for easier cumsum + A_chunked = mx.transpose(A_chunked, (0, 3, 1, 2)) # b c l h -> b h c l + A_cumsum = mx.cumsum(A_chunked, axis=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + L = mx.exp(segsum(A_chunked)) + Y_diag = mx.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C_chunked, B_chunked, L, x_chunked) + + # 2. Compute the state for each intra-chunk + decay_states = mx.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + states = mx.einsum("bclhn,bhcl,bclhp->bchpn", B_chunked, decay_states, x_chunked) + + # 3. Compute the inter-chunk SSM recurrence + if initial_states is None: + initial_states = mx.zeros_like(states[:, :1]) + states = mx.concatenate([initial_states, states], axis=1) + + A_cumsum_last = A_cumsum[:, :, :, -1] + A_cumsum_padded = mx.pad(A_cumsum_last, [(0, 0), (0, 0), (1, 0)]) + decay_chunk = mx.exp(segsum(A_cumsum_padded)) + new_states = mx.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + state_decay_out = mx.exp(A_cumsum) + Y_off = mx.einsum("bclhn,bchpn,bhcl->bclhp", C_chunked, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms + Y_combined = Y_diag + Y_off + + # Reshape back to original sequence shape + batch, chunks, chunk_len, heads, head_dim = Y_combined.shape + Y = Y_combined.reshape(batch, chunks * chunk_len, heads, head_dim) + + return Y, final_state + +def silu(x): + """Applies the Sigmoid Linear Unit (SiLU), element-wise.""" + return x * mx.sigmoid(x) + + +class Mamba2Block(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.d_model = args.hidden_size + self.d_state = args.state_size + self.d_conv = args.conv_kernel + self.expand = args.expand + self.d_inner = int(self.expand * self.d_model) + self.n_groups = args.n_groups + self.n_heads = args.num_heads + self.d_head = self.d_inner // self.n_heads + self.chunk_size = args.chunk_size + + d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads + self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias) + + self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range + self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range + self.D = mx.random.normal((self.n_heads,)) * args.initializer_range + + # Use standard Conv1d with groups for depthwise convolution + conv_dim = self.d_inner + 2 * self.n_groups * self.d_state + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + kernel_size=self.d_conv, + groups=conv_dim, + padding=self.d_conv-1, + bias=args.use_conv_bias + ) + + self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon) + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias) + + def __call__(self, u, cache=None): + """ + Arguments + u: (batch, seqlen, d_model) input + cache: Optional tuple of (conv_state, ssm_state) for inference + + Return (y, cache) + y: (batch, seqlen, d_model) output + cache: updated tuple of (conv_state, ssm_state) for inference + """ + if cache is not None: + return self.step(u, cache) + + # Initialize cache if needed + if cache is None: + cache = [None, None] # Initialize with None values + + # Compute projections + zxbcdt = self.in_proj(u) + + # Split projections + d_inner = self.d_inner + d_state = self.n_groups * self.d_state + + z, xBC, dt = mx.split( + zxbcdt, + [d_inner, d_inner + 2 * d_state], + axis=-1 + ) + + # Process dt with softplus + dt = mx.softplus(dt + self.dt_bias) # (batch, seqlen, n_heads) + + # Apply convolution to xBC + xBC_transposed = mx.transpose(xBC, (0, 2, 1)) # (batch, d, seqlen) + xBC_conv = self.conv1d(xBC_transposed) + xBC_conv = mx.transpose(xBC_conv, (0, 2, 1)) # (batch, seqlen, d) + xBC = silu(xBC_conv[:, :u.shape[1], :]) # Ensure we only keep seqlen elements + + # Split xBC into x, B, C + x, B, C = mx.split( + xBC, + [d_inner, d_inner + d_state], + axis=-1 + ) + + # Reshape x for heads + batch, seqlen = x.shape[0], x.shape[1] + x_reshaped = x.reshape(batch, seqlen, self.n_heads, self.d_head) + + # Reshape B and C for SSM + B = B.reshape(batch, seqlen, 1, d_state) + C = C.reshape(batch, seqlen, 1, d_state) + + # Apply SSM with SSD algorithm + A = -mx.exp(self.A_log) # (n_heads,) + A_dt = A * dt # (batch, seqlen, n_heads) + + y, ssm_state = ssd( + x_reshaped * mx.expand_dims(dt, -1), # Scale x by dt + A_dt, + B, + C, + self.chunk_size + ) + + # Apply D and reshape + y = y + x_reshaped * mx.reshape(self.D, (1, 1, self.n_heads, 1)) + y = y.reshape(batch, seqlen, d_inner) + + # Apply norm and gating + y = self.norm(y, z) + + # Final projection + y = self.out_proj(y) + + # Create cache for inference + if seqlen == 1 and cache is not None: + conv_state = mx.zeros((batch, d_inner + 2 * d_state, self.d_conv)) + conv_state = mx.update_slice(conv_state, xBC.reshape(batch, -1, 1), (0, 0, self.d_conv - 1)) + cache[0] = conv_state + cache[1] = ssm_state + + return y + + def step(self, u, cache): + """Take an inference step for the current input and cache + + Arguments + u: (batch, seqlen, d_model) - can be multiple tokens + cache: tuple of (conv_state, ssm_state) + + Return (y, cache) + y: (batch, seqlen, d_model) + cache: updated cache object + """ + batch, seqlen = u.shape[0], u.shape[1] + + # Initialize cache if it's None + if cache[0] is None or cache[1] is None: + d_state = self.n_groups * self.d_state + conv_dim = self.d_inner + 2 * d_state + conv_state = mx.zeros((batch, conv_dim, self.d_conv)) + + # Fix: use correct state size per head + state_per_head = d_state // self.n_heads + ssm_state = mx.zeros((batch, self.n_heads, self.d_head, state_per_head)) + else: + conv_state, ssm_state = cache[0], cache[1] + + # Project input + zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj) + + # Split projections + d_inner = self.d_inner + d_state = self.n_groups * self.d_state + + z, xBC, dt = mx.split( + zxbcdt, + [d_inner, d_inner + 2 * d_state], + axis=-1 + ) + + # Process dt with softplus once for all tokens + dt_heads = dt.reshape(batch, seqlen, -1)[:, :, :self.n_heads] + dt_heads = nn.softplus(dt_heads + self.dt_bias.reshape(1, 1, -1)) + + # Pre-compute dA for all tokens + A = -mx.exp(self.A_log) # (n_heads,) + dA = mx.exp(dt_heads * A.reshape(1, 1, -1)) # (batch, seqlen, n_heads) + + # Get convolution weights + weight = self.conv1d.weight # shape: (out_channels, 1, kernel_size) + bias = self.conv1d.bias if self.args.use_conv_bias else None + + # Process each token through the convolution sequentially + outputs = [] + for i in range(seqlen): + # Get current token's input + xBC_i = xBC[:, i] # (batch, d_inner + 2*d_state) + + # Update convolution state + conv_state = mx.roll(conv_state, shift=-1, axis=-1) + + # Update the last column of conv_state + conv_state = mx.slice_update( + conv_state, + xBC_i.reshape(batch, -1, 1), + mx.array([0, 0, self.d_conv - 1]), + axes=(0, 1, 2) + ) + + # Apply convolution step - manually handle the depthwise conv + # For a depthwise conv, we need to process each channel separately + # conv_state shape: (batch, channels, kernel_size) + # weight shape: (channels, 1, kernel_size) for depthwise conv + + # Reshape weight to match conv_state for element-wise multiplication + # and then sum along the kernel dimension + weight_reshaped = weight.reshape(conv_state.shape[1], self.d_conv) + xBC_conv = mx.sum(conv_state * weight_reshaped.reshape(1, -1, self.d_conv), axis=-1) + + if bias is not None: + xBC_conv = xBC_conv + bias + + xBC_conv = silu(xBC_conv) + + # Split xBC + x_i, BC_rest = mx.split(xBC_conv, [d_inner], axis=-1) + B_i, C_i = mx.split(BC_rest, [d_state], axis=-1) + + # Reshape x for heads + x_i = x_i.reshape(batch, self.n_heads, self.d_head) + + # Reshape B and C for SSM + state_per_head = d_state // self.n_heads + B_i_reshaped = B_i.reshape(batch, self.n_heads, state_per_head) + C_i_reshaped = C_i.reshape(batch, self.n_heads, state_per_head) + + # Get current token's dt and dA + dt_i = dt_heads[:, i] # (batch, n_heads) + dA_i = dA[:, i] # (batch, n_heads) + + # Calculate dBx + dBx = mx.einsum("bhn,bhp->bhpn", B_i_reshaped, x_i * mx.expand_dims(dt_i, -1)) + + # Update SSM state + ssm_state = ssm_state * mx.reshape(dA_i, (batch, self.n_heads, 1, 1)) + dBx + + # Calculate output with the correctly shaped C + y_i = mx.einsum("bhpn,bhn->bhp", ssm_state, C_i_reshaped) + + # Apply D and reshape + y_i = y_i + x_i * self.D.reshape(1, self.n_heads, 1) + + # Reshape y + y_i = y_i.reshape(batch, d_inner) + + # Apply norm and gating + y_i = self.norm(y_i) * nn.sigmoid(z[:, i]) + + # Final projection + y_i = self.out_proj(y_i) + + outputs.append(y_i) + + # Stack outputs along sequence dimension + y = mx.stack(outputs, axis=1) + + # Update cache + cache[0] = conv_state + cache[1] = ssm_state + + return y + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.residual_in_fp32 = args.residual_in_fp32 + self.mixer = Mamba2Block(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, x: mx.array, cache): + if self.residual_in_fp32: + x = x.astype(mx.float32) + normed = self.norm(x) + output = self.mixer(normed, cache) + return output + x + + +class Mamba2(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + + def __call__(self, x: mx.array, cache): + x = self.embeddings(x) + if cache is None: + cache = [None] * len(self.layers) + + hidden = x + for layer, c in zip(self.layers, cache): + hidden = layer(hidden, c) + return self.norm_f(hidden) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba2(args) + + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None): + hidden = self.backbone(inputs, cache) + + if self.args.tie_word_embeddings: + logits = self.backbone.embeddings.as_linear(hidden) + else: + logits = self.lm_head(hidden) + + return logits + + def make_cache(self): + return [MambaCache() for _ in range(len(self.layers))] + + @property + def layers(self): + return self.backbone.layers \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index 7302bde5..72179c02 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -75,52 +75,76 @@ def ssd_forward_attn( dt_bias: mx.array, dt_min: float, dt_max: float, + prev_state=None, ) -> Tuple[mx.array, mx.array]: b, l, h, dh = x.shape _, _, g, _ = B.shape + # Process dt if dt_bias is not None: dt = dt + dt_bias.reshape(1, 1, -1) - dt = nn.softplus(dt) dt = mx.clip(dt, a_min=dt_min, a_max=dt_max) - B = mx.swapaxes(mx.swapaxes(B, 1, 3), 1, 2) - C = mx.swapaxes(C, 1, 2) + # Reshape tensors + B_reshaped = mx.swapaxes(mx.swapaxes(B, 1, 3), 1, 2) + C_reshaped = mx.swapaxes(C, 1, 2) - CB = C @ B + # Compute CB + CB = C_reshaped @ B_reshaped CB = mx.repeat(CB, repeats=h // g, axis=1) + # Compute decay terms dtA = dt * A.reshape(1, 1, -1) dtA = mx.swapaxes(dtA, 1, 2) - decay = mx.exp(segsum(dtA)) + # Create attention matrix surrogate_attention_matrix = mx.tril(CB * decay, 0) + # Apply attention dtx = dt.reshape(b, l, h, 1) * x y = surrogate_attention_matrix @ dtx.swapaxes(1, 2) y = mx.swapaxes(y, 1, 2) - decay = decay[:, :, -1, :].reshape(b, h, l).swapaxes(1, 2).reshape(b, l, h, 1) - B = mx.repeat(B, h // g, axis=1).swapaxes(2, 3) - dtxdecay = dtx * decay + # Compute next state + decay_last = decay[:, :, -1, :].reshape(b, h, l).swapaxes(1, 2).reshape(b, l, h, 1) + B_for_state = mx.repeat(B_reshaped, h // g, axis=1).swapaxes(2, 3) + dtxdecay = dtx * decay_last dtxdecay = dtxdecay.swapaxes(1, 2).swapaxes(2, 3) - next_state = dtxdecay @ B + + # Calculate new state contribution + new_state_contribution = dtxdecay @ B_for_state + + # Initialize or update state + if prev_state is not None: + # Simply use the previous state if it exists + # This is a simplified approach - just use the new state + # In a real implementation, you'd want to properly update based on your SSM formulation + next_state = new_state_contribution + else: + next_state = new_state_contribution + # Add skip connection if D is provided if D is not None: y += x * D.reshape(1, 1, h, 1) + # Reshape output y = y.reshape(b, l, h * dh) return y, next_state def segsum(x): - l = x.shape[-1] - x = mx.repeat(x[..., None], l, axis=-1) - x = mx.tril(x, -1) - x_segsum = mx.cumsum(x, axis=-2) + # x shape: [b, h, l] + b, h, l = x.shape + indices = mx.arange(l) + mask = indices[:, None] >= indices[None, :] # [l, l] lower triangular mask + # Expand x for broadcasting + x_expanded = x.reshape(b, h, l, 1) # [b, h, l, 1] + # Apply mask and sum + masked_x = x_expanded * mask.reshape(1, 1, l, l) # [b, h, l, l] + x_segsum = mx.sum(masked_x, axis=2, keepdims=True) # [b, h, 1, l] return x_segsum @@ -189,13 +213,14 @@ class Mamba2Block(nn.Module): y, next_ssm_state = ssd_forward_attn( x=x, dt=dt, - A=A, + A=-mx.exp(self.A_log), B=B, C=C, D=self.D, dt_bias=self.dt_bias, dt_min=self.args.time_step_min, - dt_max=self.args.time_step_max + dt_max=self.args.time_step_max, + prev_state=ssm_state ) if self.args.norm_before_gate: diff --git a/llms/mlx_lm/models/mamba2_pytorch.py b/llms/mlx_lm/models/mamba2_pytorch.py index 7c768e6e..5a5146f5 100644 --- a/llms/mlx_lm/models/mamba2_pytorch.py +++ b/llms/mlx_lm/models/mamba2_pytorch.py @@ -1,14 +1,3 @@ -""" -mamba2-minimal -============== - -A minimal, single-file implementation of the Mamba-2 model in PyTorch. - -> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality** -> Authors: Tri Dao, Albert Gu -> Paper: https://arxiv.org/abs/2405.21060 -""" - import json from dataclasses import dataclass from typing import Iterable, NamedTuple, TypeAlias, cast