diff --git a/llms/mlx_lm/models/mamba2 copy.py b/llms/mlx_lm/models/mamba2 copy.py deleted file mode 100644 index 5fa503d6..00000000 --- a/llms/mlx_lm/models/mamba2 copy.py +++ /dev/null @@ -1,899 +0,0 @@ -import math -from dataclasses import dataclass, field -from typing import Tuple, Union -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs -from .cache import MambaCache - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - num_heads: int - head_dim: int - vocab_size: int - hidden_size: int - state_size: int - num_hidden_layers: int - layer_norm_epsilon: float - expand: int - conv_kernel: int - n_groups: int - use_bias: bool - use_conv_bias: bool - initializer_range: float - residual_in_fp32: bool - chunk_size: int - tie_word_embeddings: bool - time_step_limit: Tuple[float, float] - time_step_rank: Union[int, str] - time_step_min: float - time_step_max: float - time_step_floor: float - norm_before_gate: bool = True - - def __post_init__(self): - if not hasattr(self, "intermediate_size"): - self.intermediate_size = int(self.expand * self.hidden_size) - if not hasattr(self, "head_dim"): - self.head_dim = self.hidden_size // self.num_heads - if self.time_step_rank == "auto": - self.time_step_rank = math.ceil(self.hidden_size / 16) - - -def segsum(x): - """Stable segment sum calculation. - - `exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM. - """ - T = x.shape[-1] - x = mx.expand_dims(x, -1) - x = mx.repeat(x, T, axis=-1) - mask = mx.tril(mx.ones((T, T), dtype=mx.bool_), k=-1) - x = mx.where(mask, x, 0) - x_segsum = mx.cumsum(x, axis=-2) - mask = mx.tril(mx.ones((T, T), dtype=mx.bool_), k=0) - x_segsum = mx.where(mask, x_segsum, -mx.inf) - return x_segsum - -def ssd(x, A, B, C, chunk_size, initial_states=None): - """Structured State Space Duality (SSD) - the core of Mamba-2 - - Arguments - x: (batch, seqlen, n_heads, d_head) - A: (batch, seqlen, n_heads) - B: (batch, seqlen, n_heads, d_state) - C: (batch, seqlen, n_heads, d_state) - - Return - y: (batch, seqlen, n_heads, d_head) - final_state: final state for inference - """ - assert x.shape[1] % chunk_size == 0 - - # Rearrange into chunks - def rearrange_to_chunks(m): - shape = list(m.shape) - shape[1:2] = [shape[1] // chunk_size, chunk_size] - return m.reshape(shape) - - x_chunked = rearrange_to_chunks(x) - A_chunked = rearrange_to_chunks(A) - B_chunked = rearrange_to_chunks(B) - C_chunked = rearrange_to_chunks(C) - - # Transpose A for easier cumsum - A_chunked = mx.transpose(A_chunked, (0, 3, 1, 2)) # b c l h -> b h c l - A_cumsum = mx.cumsum(A_chunked, axis=-1) - - # 1. Compute the output for each intra-chunk (diagonal blocks) - L = mx.exp(segsum(A_chunked)) - Y_diag = mx.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C_chunked, B_chunked, L, x_chunked) - - # 2. Compute the state for each intra-chunk - decay_states = mx.exp(A_cumsum[:, :, :, -1:] - A_cumsum) - states = mx.einsum("bclhn,bhcl,bclhp->bchpn", B_chunked, decay_states, x_chunked) - - # 3. Compute the inter-chunk SSM recurrence - if initial_states is None: - initial_states = mx.zeros_like(states[:, :1]) - states = mx.concatenate([initial_states, states], axis=1) - - A_cumsum_last = A_cumsum[:, :, :, -1] - A_cumsum_padded = mx.pad(A_cumsum_last, [(0, 0), (0, 0), (1, 0)]) - decay_chunk = mx.exp(segsum(A_cumsum_padded)) - new_states = mx.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) - states, final_state = new_states[:, :-1], new_states[:, -1] - - # 4. 
Compute state -> output conversion per chunk - state_decay_out = mx.exp(A_cumsum) - Y_off = mx.einsum("bclhn,bchpn,bhcl->bclhp", C_chunked, states, state_decay_out) - - # Add output of intra-chunk and inter-chunk terms - Y_combined = Y_diag + Y_off - - # Reshape back to original sequence shape - batch, chunks, chunk_len, heads, head_dim = Y_combined.shape - Y = Y_combined.reshape(batch, chunks * chunk_len, heads, head_dim) - - return Y, final_state - -def silu(x): - """Applies the Sigmoid Linear Unit (SiLU), element-wise.""" - return x * mx.sigmoid(x) - - -class Mamba2Block(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.d_model = args.hidden_size - self.d_state = args.state_size - self.d_conv = args.conv_kernel - self.expand = args.expand - self.d_inner = int(self.expand * self.d_model) - self.n_groups = args.n_groups - self.n_heads = args.num_heads - self.d_head = self.d_inner // self.n_heads - self.chunk_size = args.chunk_size - - d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads - self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias) - - self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range - self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range - self.D = mx.random.normal((self.n_heads,)) * args.initializer_range - - # Use standard Conv1d with groups for depthwise convolution - conv_dim = self.d_inner + 2 * self.n_groups * self.d_state - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - kernel_size=self.d_conv, - groups=conv_dim, # Makes it depthwise - padding=self.d_conv-1, - bias=args.use_conv_bias - ) - - self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon) - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias) - - def __call__(self, u, cache=None): - """ - Arguments - u: (batch, seqlen, d_model) input - cache: Optional tuple of (conv_state, ssm_state) for inference - - Return (y, cache) - y: (batch, seqlen, d_model) output - cache: updated tuple of (conv_state, ssm_state) for inference - """ - if cache is not None: - return self.step(u, cache) - - # Initialize cache if needed - if cache is None: - cache = [None, None] # Initialize with None values - - # Compute projections - zxbcdt = self.in_proj(u) - - # Split projections - d_inner = self.d_inner - d_state = self.n_groups * self.d_state - - z, xBC, dt = mx.split( - zxbcdt, - [d_inner, d_inner + 2 * d_state], - axis=-1 - ) - - # Process dt with softplus - dt = mx.softplus(dt + self.dt_bias) # (batch, seqlen, n_heads) - - # Apply convolution to xBC - xBC_transposed = mx.transpose(xBC, (0, 2, 1)) # (batch, d, seqlen) - xBC_conv = self.conv1d(xBC_transposed) - xBC_conv = mx.transpose(xBC_conv, (0, 2, 1)) # (batch, seqlen, d) - xBC = silu(xBC_conv[:, :u.shape[1], :]) # Ensure we only keep seqlen elements - - # Split xBC into x, B, C - x, B, C = mx.split( - xBC, - [d_inner, d_inner + d_state], - axis=-1 - ) - - # Reshape x for heads - batch, seqlen = x.shape[0], x.shape[1] - x_reshaped = x.reshape(batch, seqlen, self.n_heads, self.d_head) - - # Reshape B and C for SSM - B = B.reshape(batch, seqlen, 1, d_state) - C = C.reshape(batch, seqlen, 1, d_state) - - # Apply SSM with SSD algorithm - A = -mx.exp(self.A_log) # (n_heads,) - A_dt = A * dt # (batch, seqlen, n_heads) - - y, ssm_state = ssd( - x_reshaped * mx.expand_dims(dt, -1), # Scale x by dt - A_dt, - B, - C, - self.chunk_size - ) - - # Apply D and reshape - y = y + x_reshaped * 
mx.reshape(self.D, (1, 1, self.n_heads, 1)) - y = y.reshape(batch, seqlen, d_inner) - - # Apply norm and gating - y = self.norm(y, z) - - # Final projection - y = self.out_proj(y) - - # Create cache for inference - if seqlen == 1 and cache is not None: - conv_state = mx.zeros((batch, d_inner + 2 * d_state, self.d_conv)) - conv_state = mx.update_slice(conv_state, xBC.reshape(batch, -1, 1), (0, 0, self.d_conv - 1)) - cache[0] = conv_state - cache[1] = ssm_state - - return y, cache - - def step(self, u, cache): - """Take an inference step for the current input and cache - - Arguments - u: (batch, seqlen, d_model) - can be multiple tokens - cache: tuple of (conv_state, ssm_state) - - Return (y, cache) - y: (batch, seqlen, d_model) - cache: updated cache object - """ - batch, seqlen = u.shape[0], u.shape[1] - - # Initialize cache if it's None - if cache[0] is None or cache[1] is None: - d_state = self.n_groups * self.d_state - conv_dim = self.d_inner + 2 * d_state - conv_state = mx.zeros((batch, conv_dim, self.d_conv)) - - # Fix: use correct state size per head - state_per_head = d_state // self.n_heads - ssm_state = mx.zeros((batch, self.n_heads, self.d_head, state_per_head)) - else: - conv_state, ssm_state = cache[0], cache[1] - - # Project input - zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj) - - # Split projections - d_inner = self.d_inner - d_state = self.n_groups * self.d_state - - z, xBC, dt = mx.split( - zxbcdt, - [d_inner, d_inner + 2 * d_state], - axis=-1 - ) - - # Process each token through the convolution sequentially - outputs = [] - for i in range(seqlen): - # Get current token's input - xBC_i = xBC[:, i] # (batch, d_inner + 2*d_state) - dt_i = dt[:, i] # (batch, dt_size) - - # Extract the head-specific dt values - dt_size = dt_i.shape[-1] - - if dt_size % self.n_heads == 0: - # Reshape dt_i to extract the head-specific values - dt_reshaped = dt_i.reshape(batch, self.n_heads, dt_size // self.n_heads) - # Take the first element for each head - dt_heads = dt_reshaped[:, :, 0] - else: - # If we can't reshape, just take the first n_heads elements - dt_heads = dt_i[:, :self.n_heads] - - - # Process dt with softplus - dt_heads = nn.softplus(dt_heads + self.dt_bias.reshape(1, -1)) # (batch, n_heads) - - # Update convolution state - conv_state = mx.roll(conv_state, shift=-1, axis=-1) - - # Use slice_update instead of update_slice - # Reshape xBC_i to match the expected shape for the update - xBC_reshaped = xBC_i.reshape(batch, -1, 1) - # Create start_indices for the update - start_indices = mx.array([0, 0, self.d_conv - 1]) - # Update the conv_state - conv_state = mx.slice_update( - conv_state, - xBC_reshaped, - start_indices, - axes=(0, 1, 2) - ) - - # Apply convolution step - weight = self.conv1d.weight - bias = self.conv1d.bias if self.args.use_conv_bias else None - - xBC_conv = mx.sum(conv_state * weight.reshape(1, -1, self.d_conv), axis=-1) - if bias is not None: - xBC_conv = xBC_conv + bias - - xBC_conv = silu(xBC_conv) - - # Split xBC - x_i, B_i, C_i = mx.split( - xBC_conv, - [d_inner, d_inner + d_state], - axis=-1 - ) - - # Apply SSM step - A = -mx.exp(self.A_log) # (n_heads,) - dA = mx.exp(dt_heads * A) # (batch, n_heads) - - # Reshape x for heads - x_i = x_i.reshape(batch, self.n_heads, self.d_head) - - # Reshape B and C for SSM with correct dimensions - state_per_head = d_state // self.n_heads - B_i_reshaped = B_i.reshape(batch, self.n_heads, state_per_head) - C_i_reshaped = C_i.reshape(batch, self.n_heads, state_per_head) - - # Calculate dBx with the correctly shaped B 
- dBx = mx.einsum("bhn,bhp->bhpn", B_i_reshaped, x_i * mx.expand_dims(dt_heads, -1)) - - # Update SSM state - ssm_state = ssm_state * mx.reshape(dA, (batch, self.n_heads, 1, 1)) + dBx - - # Calculate output with the correctly shaped C - y_i = mx.einsum("bhpn,bhn->bhp", ssm_state, C_i_reshaped) - - # Apply D and reshape - y_i = y_i + x_i * mx.reshape(self.D, (1, self.n_heads, 1)) - - # Reshape y - y_i = y_i.reshape(batch, d_inner) - - # Apply norm and gating (SwiGLU-like activation) - y_i = self.norm(y_i) # Just normalize without gating - y_i = y_i * nn.sigmoid(z[:, i]) # Apply gating separately - - # Final projection - y_i = self.out_proj(y_i) - - outputs.append(y_i) - - # Stack outputs along sequence dimension - y = mx.stack(outputs, axis=1) # (batch, seqlen, d_model) - - # Update cache - cache[0] = conv_state - cache[1] = ssm_state - - return y - - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.residual_in_fp32 = args.residual_in_fp32 - self.mixer = Mamba2Block(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, x: mx.array, cache): - if self.residual_in_fp32: - x = x.astype(mx.float32) - normed = self.norm(x) - output = self.mixer(normed, cache) - return output + x - - -class Mamba2(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - - def __call__(self, x: mx.array, cache): - x = self.embeddings(x) - if cache is None: - cache = [None] * len(self.layers) - - hidden = x - for layer, c in zip(self.layers, cache): - hidden = layer(hidden, c) - return self.norm_f(hidden) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = Mamba2(args) - - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__(self, inputs: mx.array, cache=None): - hidden = self.backbone(inputs, cache) - - if self.args.tie_word_embeddings: - logits = self.backbone.embeddings.as_linear(hidden) - else: - logits = self.lm_head(hidden) - - return logits - - def make_cache(self): - return [MambaCache() for _ in range(len(self.layers))] - - @property - def layers(self): - return self.backbone.layers - - - - - - - - - - -######################################################## - - - - - -import math -from dataclasses import dataclass, field -from typing import Tuple, Union -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs -from .cache import MambaCache - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - num_heads: int - head_dim: int - vocab_size: int - hidden_size: int - state_size: int - num_hidden_layers: int - layer_norm_epsilon: float - expand: int - conv_kernel: int - n_groups: int - use_bias: bool - use_conv_bias: bool - initializer_range: float - residual_in_fp32: bool - chunk_size: int - tie_word_embeddings: bool - time_step_limit: Tuple[float, float] - time_step_rank: Union[int, str] - time_step_min: float - time_step_max: float - time_step_floor: float - norm_before_gate: bool = True - - def __post_init__(self): - if not hasattr(self, "intermediate_size"): - self.intermediate_size = int(self.expand * self.hidden_size) - if not hasattr(self, "head_dim"): - self.head_dim = 
self.hidden_size // self.num_heads - if self.time_step_rank == "auto": - self.time_step_rank = math.ceil(self.hidden_size / 16) - - -def segsum(x): - """Stable segment sum calculation. - - `exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM. - """ - T = x.shape[-1] - x = mx.expand_dims(x, -1) - x = mx.repeat(x, T, axis=-1) - mask = mx.tril(mx.ones((T, T), dtype=mx.bool_), k=-1) - x = mx.where(mask, x, 0) - x_segsum = mx.cumsum(x, axis=-2) - mask = mx.tril(mx.ones((T, T), dtype=mx.bool_), k=0) - x_segsum = mx.where(mask, x_segsum, -mx.inf) - return x_segsum - -def ssd(x, A, B, C, chunk_size, initial_states=None): - """Structured State Space Duality (SSD) - the core of Mamba-2 - - Arguments - x: (batch, seqlen, n_heads, d_head) - A: (batch, seqlen, n_heads) - B: (batch, seqlen, n_heads, d_state) - C: (batch, seqlen, n_heads, d_state) - - Return - y: (batch, seqlen, n_heads, d_head) - final_state: final state for inference - """ - assert x.shape[1] % chunk_size == 0 - - # Rearrange into chunks - def rearrange_to_chunks(m): - shape = list(m.shape) - shape[1:2] = [shape[1] // chunk_size, chunk_size] - return m.reshape(shape) - - x_chunked = rearrange_to_chunks(x) - A_chunked = rearrange_to_chunks(A) - B_chunked = rearrange_to_chunks(B) - C_chunked = rearrange_to_chunks(C) - - # Transpose A for easier cumsum - A_chunked = mx.transpose(A_chunked, (0, 3, 1, 2)) # b c l h -> b h c l - A_cumsum = mx.cumsum(A_chunked, axis=-1) - - # 1. Compute the output for each intra-chunk (diagonal blocks) - L = mx.exp(segsum(A_chunked)) - Y_diag = mx.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C_chunked, B_chunked, L, x_chunked) - - # 2. Compute the state for each intra-chunk - decay_states = mx.exp(A_cumsum[:, :, :, -1:] - A_cumsum) - states = mx.einsum("bclhn,bhcl,bclhp->bchpn", B_chunked, decay_states, x_chunked) - - # 3. Compute the inter-chunk SSM recurrence - if initial_states is None: - initial_states = mx.zeros_like(states[:, :1]) - states = mx.concatenate([initial_states, states], axis=1) - - A_cumsum_last = A_cumsum[:, :, :, -1] - A_cumsum_padded = mx.pad(A_cumsum_last, [(0, 0), (0, 0), (1, 0)]) - decay_chunk = mx.exp(segsum(A_cumsum_padded)) - new_states = mx.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) - states, final_state = new_states[:, :-1], new_states[:, -1] - - # 4. 
Compute state -> output conversion per chunk - state_decay_out = mx.exp(A_cumsum) - Y_off = mx.einsum("bclhn,bchpn,bhcl->bclhp", C_chunked, states, state_decay_out) - - # Add output of intra-chunk and inter-chunk terms - Y_combined = Y_diag + Y_off - - # Reshape back to original sequence shape - batch, chunks, chunk_len, heads, head_dim = Y_combined.shape - Y = Y_combined.reshape(batch, chunks * chunk_len, heads, head_dim) - - return Y, final_state - -def silu(x): - """Applies the Sigmoid Linear Unit (SiLU), element-wise.""" - return x * mx.sigmoid(x) - - -class Mamba2Block(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.d_model = args.hidden_size - self.d_state = args.state_size - self.d_conv = args.conv_kernel - self.expand = args.expand - self.d_inner = int(self.expand * self.d_model) - self.n_groups = args.n_groups - self.n_heads = args.num_heads - self.d_head = self.d_inner // self.n_heads - self.chunk_size = args.chunk_size - - d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads - self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias) - - self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range - self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range - self.D = mx.random.normal((self.n_heads,)) * args.initializer_range - - # Use standard Conv1d with groups for depthwise convolution - conv_dim = self.d_inner + 2 * self.n_groups * self.d_state - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - kernel_size=self.d_conv, - groups=conv_dim, - padding=self.d_conv-1, - bias=args.use_conv_bias - ) - - self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon) - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias) - - def __call__(self, u, cache=None): - """ - Arguments - u: (batch, seqlen, d_model) input - cache: Optional tuple of (conv_state, ssm_state) for inference - - Return (y, cache) - y: (batch, seqlen, d_model) output - cache: updated tuple of (conv_state, ssm_state) for inference - """ - if cache is not None: - return self.step(u, cache) - - # Initialize cache if needed - if cache is None: - cache = [None, None] # Initialize with None values - - # Compute projections - zxbcdt = self.in_proj(u) - - # Split projections - d_inner = self.d_inner - d_state = self.n_groups * self.d_state - - z, xBC, dt = mx.split( - zxbcdt, - [d_inner, d_inner + 2 * d_state], - axis=-1 - ) - - # Process dt with softplus - dt = mx.softplus(dt + self.dt_bias) # (batch, seqlen, n_heads) - - # Apply convolution to xBC - xBC_transposed = mx.transpose(xBC, (0, 2, 1)) # (batch, d, seqlen) - xBC_conv = self.conv1d(xBC_transposed) - xBC_conv = mx.transpose(xBC_conv, (0, 2, 1)) # (batch, seqlen, d) - xBC = silu(xBC_conv[:, :u.shape[1], :]) # Ensure we only keep seqlen elements - - # Split xBC into x, B, C - x, B, C = mx.split( - xBC, - [d_inner, d_inner + d_state], - axis=-1 - ) - - # Reshape x for heads - batch, seqlen = x.shape[0], x.shape[1] - x_reshaped = x.reshape(batch, seqlen, self.n_heads, self.d_head) - - # Reshape B and C for SSM - B = B.reshape(batch, seqlen, 1, d_state) - C = C.reshape(batch, seqlen, 1, d_state) - - # Apply SSM with SSD algorithm - A = -mx.exp(self.A_log) # (n_heads,) - A_dt = A * dt # (batch, seqlen, n_heads) - - y, ssm_state = ssd( - x_reshaped * mx.expand_dims(dt, -1), # Scale x by dt - A_dt, - B, - C, - self.chunk_size - ) - - # Apply D and reshape - y = y + x_reshaped * mx.reshape(self.D, (1, 1, 
self.n_heads, 1)) - y = y.reshape(batch, seqlen, d_inner) - - # Apply norm and gating - y = self.norm(y, z) - - # Final projection - y = self.out_proj(y) - - # Create cache for inference - if seqlen == 1 and cache is not None: - conv_state = mx.zeros((batch, d_inner + 2 * d_state, self.d_conv)) - conv_state = mx.update_slice(conv_state, xBC.reshape(batch, -1, 1), (0, 0, self.d_conv - 1)) - cache[0] = conv_state - cache[1] = ssm_state - - return y - - def step(self, u, cache): - """Take an inference step for the current input and cache - - Arguments - u: (batch, seqlen, d_model) - can be multiple tokens - cache: tuple of (conv_state, ssm_state) - - Return (y, cache) - y: (batch, seqlen, d_model) - cache: updated cache object - """ - batch, seqlen = u.shape[0], u.shape[1] - - # Initialize cache if it's None - if cache[0] is None or cache[1] is None: - d_state = self.n_groups * self.d_state - conv_dim = self.d_inner + 2 * d_state - conv_state = mx.zeros((batch, conv_dim, self.d_conv)) - - # Fix: use correct state size per head - state_per_head = d_state // self.n_heads - ssm_state = mx.zeros((batch, self.n_heads, self.d_head, state_per_head)) - else: - conv_state, ssm_state = cache[0], cache[1] - - # Project input - zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj) - - # Split projections - d_inner = self.d_inner - d_state = self.n_groups * self.d_state - - z, xBC, dt = mx.split( - zxbcdt, - [d_inner, d_inner + 2 * d_state], - axis=-1 - ) - - # Process dt with softplus once for all tokens - dt_heads = dt.reshape(batch, seqlen, -1)[:, :, :self.n_heads] - dt_heads = nn.softplus(dt_heads + self.dt_bias.reshape(1, 1, -1)) - - # Pre-compute dA for all tokens - A = -mx.exp(self.A_log) # (n_heads,) - dA = mx.exp(dt_heads * A.reshape(1, 1, -1)) # (batch, seqlen, n_heads) - - # Get convolution weights - weight = self.conv1d.weight # shape: (out_channels, 1, kernel_size) - bias = self.conv1d.bias if self.args.use_conv_bias else None - - # Process each token through the convolution sequentially - outputs = [] - for i in range(seqlen): - # Get current token's input - xBC_i = xBC[:, i] # (batch, d_inner + 2*d_state) - - # Update convolution state - conv_state = mx.roll(conv_state, shift=-1, axis=-1) - - # Update the last column of conv_state - conv_state = mx.slice_update( - conv_state, - xBC_i.reshape(batch, -1, 1), - mx.array([0, 0, self.d_conv - 1]), - axes=(0, 1, 2) - ) - - # Apply convolution step - manually handle the depthwise conv - # For a depthwise conv, we need to process each channel separately - # conv_state shape: (batch, channels, kernel_size) - # weight shape: (channels, 1, kernel_size) for depthwise conv - - # Reshape weight to match conv_state for element-wise multiplication - # and then sum along the kernel dimension - weight_reshaped = weight.reshape(conv_state.shape[1], self.d_conv) - xBC_conv = mx.sum(conv_state * weight_reshaped.reshape(1, -1, self.d_conv), axis=-1) - - if bias is not None: - xBC_conv = xBC_conv + bias - - xBC_conv = silu(xBC_conv) - - # Split xBC - x_i, BC_rest = mx.split(xBC_conv, [d_inner], axis=-1) - B_i, C_i = mx.split(BC_rest, [d_state], axis=-1) - - # Reshape x for heads - x_i = x_i.reshape(batch, self.n_heads, self.d_head) - - # Reshape B and C for SSM - state_per_head = d_state // self.n_heads - B_i_reshaped = B_i.reshape(batch, self.n_heads, state_per_head) - C_i_reshaped = C_i.reshape(batch, self.n_heads, state_per_head) - - # Get current token's dt and dA - dt_i = dt_heads[:, i] # (batch, n_heads) - dA_i = dA[:, i] # (batch, n_heads) - - # Calculate 
dBx
-            dBx = mx.einsum("bhn,bhp->bhpn", B_i_reshaped, x_i * mx.expand_dims(dt_i, -1))
-
-            # Update SSM state
-            ssm_state = ssm_state * mx.reshape(dA_i, (batch, self.n_heads, 1, 1)) + dBx
-
-            # Calculate output with the correctly shaped C
-            y_i = mx.einsum("bhpn,bhn->bhp", ssm_state, C_i_reshaped)
-
-            # Apply D and reshape
-            y_i = y_i + x_i * self.D.reshape(1, self.n_heads, 1)
-
-            # Reshape y
-            y_i = y_i.reshape(batch, d_inner)
-
-            # Apply norm and gating
-            y_i = self.norm(y_i) * nn.sigmoid(z[:, i])
-
-            # Final projection
-            y_i = self.out_proj(y_i)
-
-            outputs.append(y_i)
-
-        # Stack outputs along sequence dimension
-        y = mx.stack(outputs, axis=1)
-
-        # Update cache
-        cache[0] = conv_state
-        cache[1] = ssm_state
-
-        return y
-
-
-class ResidualBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.residual_in_fp32 = args.residual_in_fp32
-        self.mixer = Mamba2Block(args)
-        self.norm = nn.RMSNorm(args.hidden_size)
-
-    def __call__(self, x: mx.array, cache):
-        if self.residual_in_fp32:
-            x = x.astype(mx.float32)
-        normed = self.norm(x)
-        output = self.mixer(normed, cache)
-        return output + x
-
-
-class Mamba2(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.args = args
-        self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
-        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
-        self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
-
-    def __call__(self, x: mx.array, cache):
-        x = self.embeddings(x)
-        if cache is None:
-            cache = [None] * len(self.layers)
-
-        hidden = x
-        for layer, c in zip(self.layers, cache):
-            hidden = layer(hidden, c)
-        return self.norm_f(hidden)
-
-
-class Model(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.args = args
-        self.model_type = args.model_type
-        self.backbone = Mamba2(args)
-
-        if not args.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
-
-    def __call__(self, inputs: mx.array, cache=None):
-        hidden = self.backbone(inputs, cache)
-
-        if self.args.tie_word_embeddings:
-            logits = self.backbone.embeddings.as_linear(hidden)
-        else:
-            logits = self.lm_head(hidden)
-
-        return logits
-
-    def make_cache(self):
-        return [MambaCache() for _ in range(len(self.layers))]
-
-    @property
-    def layers(self):
-        return self.backbone.layers
\ No newline at end of file
diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 72179c02..6adc6469 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -42,29 +42,6 @@ class ModelArgs(BaseModelArgs):
             self.time_step_rank = math.ceil(self.hidden_size / 16)
 
 
-class DepthWiseConv1d(nn.Module):
-    def __init__(self, channels, kernel_size, bias=True, padding=0):
-        super().__init__()
-        self.channels = channels
-        self.kernel_size = kernel_size
-        self.padding = padding
-        self.weight = mx.random.normal((channels, kernel_size, 1))
-        self.bias = mx.zeros((channels,)) if bias else None
-
-    def __call__(self, x, cache=None):
-        B, L, C = x.shape
-        _, K, _ = self.weight.shape
-
-        if cache is not None:
-            x = mx.concatenate([cache, x], axis=1)
-        else:
-            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
-
-        y = mx.conv_general(x, self.weight, groups=C)
-        y = y + self.bias
-        return y, x[:, -K + 1:, :]
-
-
 def ssd_forward_attn(
     x: mx.array,
     dt: mx.array,
@@ -118,10 +95,8 @@ def ssd_forward_attn(
 
     # Initialize or update state
     if prev_state is not None:
-        # Simply use the previous state if it exists
-        # This is a simplified approach - just use the new state
-        # In a real implementation, you'd want to properly update based on your SSM formulation
-        next_state = new_state_contribution
+        decayed_prev_state = prev_state * decay[:, :, -1, :].reshape(b, h, 1, 1)
+        next_state = decayed_prev_state + new_state_contribution
     else:
         next_state = new_state_contribution
 
@@ -169,11 +144,13 @@ class Mamba2Block(nn.Module):
         self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
         self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
 
-        self.conv1d = DepthWiseConv1d(
-            channels=self.d_inner + 2 * self.n_groups * self.d_state,
+        conv_channels = self.d_inner + 2 * self.n_groups * self.d_state
+        self.conv1d = nn.Conv1d(
+            in_channels=conv_channels,
+            out_channels=conv_channels,
             kernel_size=self.d_conv,
-            bias=args.use_conv_bias,
-            padding=self.d_conv-1
+            groups=conv_channels,
+            padding=self.d_conv - 1,
         )
 
         self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon)
@@ -195,10 +172,36 @@ class Mamba2Block(nn.Module):
             axis=-1
         )
 
-        xBC, conv_state = self.conv1d(xBC, conv_state)
-        xBC = xBC * mx.sigmoid(xBC)
-        xBC = xBC[:, :seq_len, :]
-
+        # Handle convolution with caching
+        xBC = mx.swapaxes(xBC, 1, 2)  # [B, L, C] -> [B, C, L]
+
+        if conv_state is not None and seq_len > 0:
+            # Concatenate cached state with current input
+            xBC_with_cache = mx.concatenate([conv_state, xBC], axis=2)
+        elif seq_len > 0:
+            # For the first call, pad with zeros
+            padding = mx.zeros((batch_size, xBC.shape[1], self.d_conv - 1))
+            xBC_with_cache = mx.concatenate([padding, xBC], axis=2)
+        else:
+            xBC_with_cache = conv_state if conv_state is not None else mx.zeros((batch_size, xBC.shape[1], 0))
+
+        # Save state for next iteration
+        if seq_len > 0:
+            next_conv_state = xBC_with_cache[:, :, -(self.d_conv - 1):]
+        else:
+            next_conv_state = conv_state
+
+        # Apply regular convolution using nn.Conv1d
+        if seq_len > 0:
+            # Use the standard Conv1d module for the actual computation
+            xBC_conv = self.conv1d(xBC_with_cache)
+            xBC = xBC_conv[:, :, -seq_len:]  # Take only the relevant output positions
+            xBC = mx.swapaxes(xBC, 1, 2)  # [B, C, L] -> [B, L, C]
+            xBC = xBC * mx.sigmoid(xBC)
+        else:
+            # Handle empty sequence case
+            xBC = mx.swapaxes(xBC, 1, 2)  # [B, C, L] -> [B, L, C]
+
         x, B, C = mx.split(
             xBC,
             [self.d_inner, self.d_inner + self.d_state * self.n_groups],
@@ -209,7 +212,6 @@ class Mamba2Block(nn.Module):
         B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1))
        C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))
 
-        A = -mx.exp(self.A_log)
         y, next_ssm_state = ssd_forward_attn(
             x=x,
             dt=dt,
@@ -232,7 +234,7 @@ class Mamba2Block(nn.Module):
 
         y = self.out_proj(y)
 
-        cache[0] = conv_state
+        cache[0] = next_conv_state
         cache[1] = next_ssm_state
 
         return y