From 58b448dc0bb6f2461563526869dcf59dbd4a4a8a Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 30 Oct 2024 21:23:13 +0100 Subject: [PATCH] updates --- llms/mlx_lm/models/cache.py | 134 +---- llms/mlx_lm/models/mamba2 copy.py | 844 +++++++++++------------------ llms/mlx_lm/models/mamba2-prch.py | 861 ++++++++++++++++-------------- llms/mlx_lm/models/mamba2.py | 453 ++++++++-------- 4 files changed, 1007 insertions(+), 1285 deletions(-) diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index 05dcfee7..2764ce44 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -324,6 +324,7 @@ class RotatingKVCache(_BaseCache): class MambaCache(_BaseCache): def __init__(self): self.cache = [None, None] + self.offset = 0 def __setitem__(self, idx, value): self.cache[idx] = value @@ -341,129 +342,12 @@ class MambaCache(_BaseCache): class Mamba2Cache: - batch_size: int - intermediate_size: int - state_size: int - conv_kernel: int - num_heads: int - head_dim: int - - def __init__( - self, - batch_size: int, - intermediate_size: int, - state_size: int, - conv_kernel: int, - num_heads: int, - head_dim: int - ): - self.batch_size = batch_size - self.intermediate_size = intermediate_size - self.state_size = state_size - self.conv_kernel = conv_kernel - self.num_heads = num_heads - self.head_dim = head_dim - - # Initialize conv state with proper dimensions - self.conv_dim = self.intermediate_size + 2 * self.state_size - self.conv_state = mx.zeros((batch_size, self.conv_dim, conv_kernel - 1)) - - # Initialize SSM state - self.ssm_state = mx.zeros(( - batch_size, - num_heads, - head_dim, - state_size - )) + def __init__(self, batch_size, conv_dim, kernel_size, num_heads, head_dim, state_size): + self.conv_states = mx.zeros((batch_size, conv_dim, kernel_size - 1)) + self.ssm_states = mx.zeros((batch_size, num_heads, head_dim, state_size)) + self.seqlen_offset = 0 - def update_conv_state(self, x: mx.array) -> mx.array: - """ - Update convolution state for incremental inference. - Args: - x: Input tensor containing projected values (B, conv_in_dim) - Returns: - Combined state tensor of shape (batch_size, conv_dim, kernel_size) - """ - # Handle input shape - if x.ndim == 1: - x = mx.expand_dims(x, 0) # Add batch dimension if needed - - # Ensure batch size matches - assert x.shape[0] == self.batch_size, f"Batch size mismatch: {x.shape[0]} vs {self.batch_size}" - - # Reshape x to match conv_dim - # The input x contains intermediate_size + 2 * state_size dimensions - x_reshaped = mx.reshape(x, (self.batch_size, -1)) - x_padded = mx.pad( - x_reshaped, - [(0, 0), (0, self.conv_dim - x_reshaped.shape[1])], - mode='constant', - constant_values=0 - ) - - # Expand dims for concatenation - x_expanded = mx.expand_dims(x_padded, -1) # Shape: (batch_size, conv_dim, 1) - - # Roll the existing state left by 1 - rolled_state = mx.roll(self.conv_state, shift=-1, axis=-1) - - # Create update mask for the last position - update_pos = self.conv_kernel - 2 - state_idx = mx.arange(self.conv_kernel - 1) - update_mask = state_idx == update_pos - - # Broadcast mask to match state dimensions - update_mask = mx.broadcast_to( - mx.reshape(update_mask, (1, 1, -1)), - rolled_state.shape - ) - - # Update state with padded input - x_broadcast = mx.broadcast_to(x_expanded, (self.batch_size, self.conv_dim, 1)) - self.conv_state = mx.where( - update_mask, - x_broadcast, - rolled_state - ) - - # Return concatenated state for convolution - return mx.concatenate([self.conv_state, x_expanded], axis=-1) - - def update_ssm_state(self, dA: mx.array, dBx: mx.array) -> mx.array: - """ - Update SSM state for incremental inference. - Args: - dA: State transition tensor of shape (batch_size, num_heads) - dBx: Input projection tensor of shape (batch_size, num_heads, head_dim, state_size) - Returns: - Updated SSM state of shape (batch_size, num_heads, head_dim, state_size) - """ - # Add necessary dimensions to dA for broadcasting - # dA shape: (batch_size, num_heads) -> (batch_size, num_heads, 1, 1) - dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) - - # Ensure dBx has the correct shape - assert dBx.shape[-1] == self.state_size, f"dBx state dimension mismatch: {dBx.shape[-1]} vs {self.state_size}" - assert dBx.shape[-2] == self.head_dim, f"dBx head dimension mismatch: {dBx.shape[-2]} vs {self.head_dim}" - - # Update state: state = dA * state + dBx - self.ssm_state = dA * self.ssm_state + dBx - - return self.ssm_state - - @classmethod - def get_cache( - cls, - args, - batch_size: int, - max_seq_length: Optional[int] - ) -> "Mamba2Cache": - """Create a new cache instance with the given parameters.""" - return cls( - batch_size=batch_size, - intermediate_size=args.intermediate_size, - state_size=args.state_size, - conv_kernel=args.conv_kernel, - num_heads=args.num_heads, - head_dim=args.head_dim - ) \ No newline at end of file + def update(self, new_conv_state, new_ssm_state): + self.conv_states = new_conv_state + self.ssm_states = new_ssm_state + self.seqlen_offset += 1 \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2 copy.py b/llms/mlx_lm/models/mamba2 copy.py index 0d4dedb2..fc3f23d8 100644 --- a/llms/mlx_lm/models/mamba2 copy.py +++ b/llms/mlx_lm/models/mamba2 copy.py @@ -1,275 +1,7 @@ import math from dataclasses import dataclass, field -from typing import Tuple, Union -import mlx.core as mx -import mlx.nn as nn +from typing import Optional, Tuple, Union -from .base import BaseModelArgs -from .cache import MambaCache - -@dataclass -class ModelArgs(BaseModelArgs): - num_heads: int - head_dim: int - vocab_size: int - hidden_size: int - state_size: int - num_hidden_layers: int - layer_norm_epsilon: float - expand: int - conv_kernel: int - n_groups: int - use_bias: bool - use_conv_bias: bool - initializer_range: float - residual_in_fp32: bool - time_step_min: float - time_step_max: float - time_step_floor: float - rescale_prenorm_residual: bool - use_cache: bool - rms_norm: bool - chunk_size: int - tie_word_embeddings: bool - time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) - time_step_rank: Union[int, str] = "auto" - model_type: str = "mamba2" - - def __post_init__(self): - if not hasattr(self, "intermediate_size"): - self.intermediate_size = int(self.expand * self.hidden_size) - if not hasattr(self, "head_dim"): - self.head_dim = self.hidden_size // self.num_heads - if self.time_step_rank == "auto": - self.time_step_rank = math.ceil(self.hidden_size / 16) - - -class MambaRMSNormGated(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = mx.ones((hidden_size,)) - self.variance_epsilon = eps - - def __call__(self, hidden_states, gate=None): - if gate is not None: - hidden_states = hidden_states * nn.silu(gate) - variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True) - hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states - - -class DepthWiseConv1d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.padding = padding - self.groups = groups if groups is not None else in_channels - - # Ensure in_channels and out_channels are the same for depthwise conv - assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution" - # Ensure groups is equal to in_channels for depthwise conv - assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution" - - # Initialize weight with shape (out_channels, kernel_size, 1) - self.weight = mx.random.normal((out_channels, kernel_size, 1)) - self.bias = mx.zeros((out_channels,)) if bias else None - - def __call__(self, x, cache=None): - B, L, C = x.shape - _, K, _ = self.weight.shape - - if cache is not None: - x = mx.concatenate([cache, x], axis=1) - else: - x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]) - - y = mx.conv_general(x, self.weight, groups=self.groups) - - if self.bias is not None: - y = y + self.bias - - return y, x[:, -K + 1 :, :] - - -class Mamba2Block(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.intermediate_size = args.intermediate_size - self.time_step_rank = args.time_step_rank - self.conv_kernel_size = args.conv_kernel - self.hidden_size = args.hidden_size - self.state_size = args.state_size - self.num_heads = args.num_heads - self.head_dim = args.hidden_size // args.num_heads - self.n_groups = args.n_groups - - # projection_size = 2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads - projection_size = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads - self.in_proj = nn.Linear( - args.hidden_size, - projection_size, - bias=args.use_bias - ) - - # self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size - self.conv_dim = args.intermediate_size + 2 * args.state_size - self.conv1d = DepthWiseConv1d( - in_channels=self.conv_dim, - out_channels=self.conv_dim, - kernel_size=args.conv_kernel, - bias=args.use_conv_bias, - groups=self.conv_dim, - padding=args.conv_kernel - 1 - ) - - self.A_log = mx.zeros(args.num_heads) - self.D = mx.ones((args.num_heads,)) - self.dt_bias = mx.zeros(args.num_heads) - - self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) - self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon) - - def ssm_step(self, x, state, dt): - A = -mx.exp(self.A_log) - D = self.D - dt = nn.softplus(dt + self.dt_bias) - - B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1) - - batch_size = B.shape[0] - B = B.reshape(batch_size, self.n_groups, self.state_size) - C = C.reshape(batch_size, -1, self.state_size) - - dt = dt.reshape(batch_size, self.num_heads, 1) - A = A.reshape(1, self.num_heads, 1) - - if state is None: - new_state = dt * B - else: - new_state = dt * (B + state * mx.exp(dt * A)) - - y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2)) - y = y + D * x[:, :self.num_heads] - return y, new_state - - def __call__(self, x, cache): - B, T, D = x.shape - if cache is None: - cache = [None, None] - - outputs = [] - for t in range(T): - xt = x[:, t, :] - zxbcdt = self.in_proj(xt) - - z, xBC, dt = mx.split( - zxbcdt, - # indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size], - indices_or_sections=[ - self.intermediate_size, - self.intermediate_size + 2 * self.state_size, - self.num_heads - ], - axis=-1 - ) - - # Use the new DepthWiseConv1d with caching - conv_out, cache[0] = self.conv1d(mx.expand_dims(z, 1), cache[0]) - z = conv_out.squeeze(1) - z = nn.silu(z) - y_t, cache[1] = self.ssm_step(z, cache[1], dt) - xBC = nn.silu(xBC) - - # Element-wise multiplication - output_t = y_t[:, :, None] * xBC[:, None, :] - - output_t = self.norm(output_t) - output_t = output_t.sum(axis=1) - output_t = self.out_proj(output_t) - outputs.append(output_t) - - output = mx.stack(outputs, axis=1) - return output - - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = Mamba2Block(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, x: mx.array, cache): - return self.mixer(self.norm(x), cache) + x - - -class Mamba2(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - - def __call__(self, x: mx.array, cache): - x = self.embeddings(x) - if cache is None: - cache = [None] * len(self.layers) - for layer, c in zip(self.layers, cache): - x = layer(x, c) - return self.norm_f(x) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - - self.backbone = Mamba2(args) - # self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__(self, inputs: mx.array, cache=None): - B, T = inputs.shape - - x = self.backbone(inputs, cache) - - if self.args.tie_word_embeddings: - logits = self.backbone.embeddings.as_linear(x) - else: - logits = self.lm_head(x) - - return logits - - def sanitize(self, weights): - for k, v in weights.items(): - if "conv1d.weight" in k and v.ndim == 3: - weights[k] = v.moveaxis(2, 1) - return weights - - def make_cache(self): - return [MambaCache() for _ in range(len(self.layers))] - - @property - def layers(self): - return self.backbone.layers - - - - - -# ------ - - - -import math -from dataclasses import dataclass, field -from typing import Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -296,130 +28,79 @@ class ModelArgs(BaseModelArgs): time_step_max: float time_step_floor: float rescale_prenorm_residual: bool + use_cache: bool rms_norm: bool chunk_size: int tie_word_embeddings: bool - use_cache: bool = True + intermediate_size: int = None time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) time_step_rank: Union[int, str] = "auto" model_type: str = "mamba2" def __post_init__(self): - if not hasattr(self, "intermediate_size"): - self.intermediate_size = int(self.expand * self.hidden_size) + self.intermediate_size = int(self.expand * self.hidden_size) # E*D = ED + if not hasattr(self, "head_dim"): self.head_dim = self.hidden_size // self.num_heads if self.time_step_rank == "auto": self.time_step_rank = math.ceil(self.hidden_size / 16) + +def selective_scan(x, A, B, C, chunk_size): + """ + Selective scan implementation for training. -class MambaRMSNormGated(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = mx.ones((hidden_size,)) - self.variance_epsilon = eps + Arguments + x: (batch, seqlen, n_heads, d_head) + A: (batch, seqlen, n_heads) + B: (batch, seqlen, n_heads, d_state) + C: (batch, seqlen, n_heads, d_state) - def __call__(self, hidden_states, gate=None): - if gate is not None: - hidden_states = hidden_states * nn.silu(gate) - variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True) - hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states + Return + y: (batch, seqlen, n_heads, d_head) + """ + assert x.shape[1] % chunk_size == 0 - -def silu(x): - return x * mx.sigmoid(x) - -def ssd(x, A, B, C, chunk_size): - # Replace einsum operations with explicit reshape and matrix multiply - batch, seqlen, nheads, dim = x.shape - B = mx.expand_dims(B, axis=2) - C = mx.expand_dims(C, axis=2) + # Reshape into chunks + def chunk_reshape(m): + shape = list(m.shape) + shape[1:2] = [shape[1] // chunk_size, chunk_size] + return m.reshape(shape) - state = mx.zeros((batch, nheads, dim, B.shape[-1])) - outputs = [] + x, A, B, C = map(chunk_reshape, (x, A, B, C)) + A = mx.transpose(A, [0, 3, 1, 2]) - for i in range(0, seqlen, chunk_size): - chunk = slice(i, min(i + chunk_size, seqlen)) - dA = mx.exp(mx.expand_dims(A[chunk], axis=0)) - - # Replace einsum with explicit operations - x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim] - x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size] - B_chunk = B[:, chunk] # [batch, chunk_size, state_size] - dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size] - - state = state * mx.expand_dims(dA, axis=-1) + dBx - - # Replace einsum with explicit operations - C_chunk = C[:, chunk] # [batch, chunk_size, state_size] - y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size] - y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim] - outputs.append(y) + # Compute cumulative sums + A_cumsum = mx.cumsum(A, axis=-1) - return mx.concatenate(outputs, axis=1), state + # Process chunks + L = mx.exp(selective_cumsum(A)) + Y_diag = mx.einsum('bclhn,bcshn,bhcls,bcshp->bclhp', C, B, L, x) + decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum) + states = mx.einsum('bclhn,bhcl,bclhp->bchpn', B, decay_states, x) + + initial_states = mx.zeros_like(states[:, :1]) + states = mx.concatenate([initial_states, states], axis=1) + decay_chunk = mx.exp(selective_cumsum(mx.pad(A_cumsum[..., -1], ((0,0), (0,0), (1,0))))) + new_states = mx.einsum('bhzc,bchpn->bzhpn', decay_chunk, states) + states = new_states[:, :-1] + + state_decay_out = mx.exp(A_cumsum) + Y_off = mx.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out) + + Y = (Y_diag + Y_off).reshape((-1, x.shape[1] * chunk_size, *Y_diag.shape[-2:])) + return Y -class DepthWiseConv1d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.padding = padding - self.groups = groups if groups is not None else in_channels - - assert in_channels == out_channels, "In and out channels must be same for depthwise convolution" - assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution" - - self.weight = mx.random.normal((in_channels, 1, kernel_size)) - self.bias = mx.zeros((out_channels,)) if bias else None - - def __call__(self, x: mx.array, cache=None) -> mx.array: - B, L, C = x.shape - K = self.kernel_size - - assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}" - - if cache is not None and 'conv_states' in cache: - conv_states = cache['conv_states'] - if conv_states is not None: - assert conv_states.shape[0] == B, "Cache batch size mismatch" - assert conv_states.shape[2] == C, "Cache channel count mismatch" - x = mx.concatenate([conv_states, x], axis=1) - - # Process each channel independently - outputs = [] - for c in range(C): - x_c = x[:, :, c] - x_c = mx.expand_dims(x_c, axis=1) - - w_c = self.weight[c] - if w_c.ndim == 2: - w_c = mx.expand_dims(w_c, axis=0) - elif w_c.ndim == 1: - w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0) - - # Apply convolution - y_c = mx.conv_general( - x_c, - w_c, - stride=1, - padding=0 - ) - - if self.bias is not None: - y_c = y_c + self.bias[c] - - outputs.append(mx.squeeze(y_c, axis=1)) - - y = mx.stack(outputs, axis=-1) - - # Update cache - if cache is not None: - cache['conv_states'] = x[:, -K+1:, :] if x.shape[1] >= K else x - - return y +def selective_cumsum(x: mx.array) -> mx.array: + """Stable selective cumulative sum calculation.""" + T = x.shape[-1] + x = mx.repeat(x[..., None], T, axis=-1) + mask = mx.tril(mx.ones((T, T)), k=-1) + x = x * mask + x_cumsum = mx.cumsum(x, axis=-2) + mask = mx.tril(mx.ones((T, T)), k=0) + return mx.where(mask, x_cumsum, float('-inf')) class Mamba2Block(nn.Module): @@ -427,165 +108,250 @@ class Mamba2Block(nn.Module): super().__init__() self.args = args - d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads - self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias) + # Project input to get various components [z, x, B, C, dt] + projection_size = (2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads) + self.in_proj = nn.Linear( + args.hidden_size, + projection_size, + bias=args.use_bias + ) - conv_dim = args.intermediate_size + 2 * args.state_size - self.conv1d = DepthWiseConv1d( + # Convolution layer + conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size + self.conv1d = nn.Conv1d( in_channels=conv_dim, out_channels=conv_dim, kernel_size=args.conv_kernel, groups=conv_dim, - bias=args.use_conv_bias, - padding=args.conv_kernel - 1 + padding=args.conv_kernel - 1, + bias=args.use_conv_bias ) - self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range - self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range - self.D = mx.random.normal((args.num_heads,)) * args.initializer_range + # SSM parameters + self.dt_bias = mx.zeros(args.num_heads) + self.A_log = mx.zeros(args.num_heads) + self.D = mx.ones(args.num_heads) - self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon) + # Output projections + self.norm = nn.RMSNorm(args.intermediate_size, eps=args.layer_norm_epsilon) self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) - if args.rescale_prenorm_residual: - layer_scale = math.sqrt(1.0 / args.num_hidden_layers) - self.out_proj.weight = self.out_proj.weight * layer_scale + def __call__(self, u: mx.array, cache=None) -> mx.array: + # return self.forward_training(x) if x.shape[1] > 1 else self.forward_inference(x, cache) - def __call__(self, x: mx.array, cache=None): - if cache is not None: - return self.step(x, cache) - - # Regular forward pass code remains the same... - d_model = self.args.intermediate_size - d_state = self.args.state_size - n_heads = self.args.num_heads + # def forward_training(self, u: mx.array) -> mx.array: + # # Reset cache during training + # self.cache = None - A = -mx.exp(self.A_log) - zxbcdt = self.in_proj(x) - - splits = [d_model, d_model + 2 * d_state, n_heads] - z = zxbcdt[:, :, :splits[0]] - xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]] - dt = zxbcdt[:, :, -splits[2]:] + # # Input projection and splitting + # zxbcdt = self.in_proj(u) + # z, xBC, dt = mx.split( + # zxbcdt, + # [ + # self.args.hidden_size, + # self.args.hidden_size + 2 * self.args.state_size + # ], + # axis=-1 + # ) + # # Time step processing + # dt = mx.clip( + # nn.softplus(dt + self.dt_bias), + # self.args.time_step_min, + # self.args.time_step_max + # ) + + # # Convolution processing + # xBC_t = mx.transpose(xBC, [0, 2, 1]) + # conv_out = self.conv1d(xBC_t) + # xBC = mx.transpose(conv_out, [0, 2, 1])[:, :u.shape[1]] + # xBC = mx.sigmoid(xBC) * xBC # SiLU + + # # Split states + # x, B, C = mx.split( + # xBC, + # [self.args.hidden_size, self.args.state_size], + # axis=-1 + # ) + + # # Reshape for selective scan + # x = x.reshape((-1, x.shape[1], self.args.num_heads, self.args.head_dim)) + # A = -mx.exp(self.A_log) + + # # Apply selective scan + # y = selective_scan( + # x * dt[..., None], + # A * dt, + # B[..., None, :], + # C[..., None, :], + # self.args.chunk_size + # ) + + # # Output processing + # y = y + x * self.D[None, None, :, None] + # y = y.reshape((-1, y.shape[1], self.args.hidden_size)) + # y = self.norm(y, z) + # y = self.out_proj(y) + + # return y + + # def forward_inference(self, u: mx.array, cache=None) -> mx.array: + # """ + # u: (B, 1, D) + # cache: (h_cache, conv_cache) + # """ + # """Single token processing during inference.""" + # assert u.shape[1] == 1, "Inference mode expects single token" + + # batch_size = u.shape[0] + # # Use provided cache or create new one + # self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None) + + # # Project input + # zxbcdt = self.in_proj(u.squeeze(1)) # (B, 2D) + # d_mlp = (zxbcdt.shape[-1] - 2 * self.args.hidden_size - 2 * self.args.n_groups * self.args.state_size - self.args.num_heads) // 2 + + # # (1, 768) (1, 0) (1, 0) (1, 256) (1, 0) (1, 3328) + # y0, z0, x0, z, xBC, dt = mx.split( + # zxbcdt, + # [ + # d_mlp, + # d_mlp, + # self.args.hidden_size, + # self.args.hidden_size + 2 * self.args.n_groups * self.args.state_size, + # self.args.num_heads + # ], + # axis=-1 + # ) + + # # Update convolution state and apply + # conv_state = self.cache.update_conv_state(xBC) + # xBC = mx.sum(conv_state[:, :, -1] * mx.transpose(self.conv1d.weight, [1, 0, 2]), axis=-1) # (B, D) (4, 1792) + + # if self.args.use_conv_bias: + # xBC = xBC + self.conv1d.bias + + # xBC = mx.sigmoid(xBC) * xBC # SiLU (4, 1792) + + # # Split states and ensure proper shapes + # a0, x, B, C = mx.split( + # xBC, # (4, 1792) + # [ + # self.args.hidden_size, + # self.args.n_groups * self.args.state_size, + # self.args.n_groups * self.args.state_size + # ], + # axis=-1 + # ) + + # # SSM step with explicit shapes + # A = -mx.exp(self.A_log) # (num_heads) (24,) + # print(A.shape) # (24,) + # print(dt.shape) # (1, 3328) + # dA = mx.exp(dt * A[None, :]) # Shape: (batch_size, num_heads) <------- her eis the error + + # # Reshape x considering intermediate size + # # x shape should be (batch_size * num_heads, head_dim) + # x = mx.reshape(x, (batch_size, self.args.num_heads, -1)) + # assert x.shape[-1] == self.args.head_dim, f"Head dimension mismatch: {x.shape[-1]} vs {self.args.head_dim}" + + # B = mx.reshape(B, (batch_size, -1)) # Should be (batch_size, state_size) + # C = mx.reshape(C, (batch_size, -1)) # Should be (batch_size, state_size) + + # # Compute dBx with explicit shapes + # dBx = mx.einsum('bh,bs,bhd->bhds', dt, B, x) + + # ssm_state = self.cache.update_ssm_state(dA, dBx) + + # y = mx.einsum('bhds,bs->bhd', ssm_state, C) + # y = y + x * self.D[None, :, None] + # y = mx.reshape(y, (batch_size, self.args.hidden_size)) + + # # Output processing + # y = self.norm(y, z) + + # if d_mlp > 0: + # y = mx.cat([nn.silu(z0) * x0, y], axis=-1) + + # y = self.out_proj(y) + + # return mx.expand_dims(y, 1) + + assert u.shape[1] == 1, "Inference mode expects single token" + + batch_size = u.shape[0] + # Use provided cache or create new one + self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None) + + # Project input + zxbcdt = self.in_proj(u.squeeze(1)) # (B, projection_size) + + # Calculate splits based on model dimensions + d_mlp = self.args.intermediate_size + d_state = self.args.state_size * self.args.n_groups + + # Split the projection into its components + splits = [ + d_mlp, # y0 + d_mlp, # z0 + self.args.hidden_size, # x0 + self.args.hidden_size, # z + d_state * 2, # xBC (includes both B and C) + self.args.num_heads # dt + ] + + y0, z0, x0, z, xBC, dt = mx.split(zxbcdt, splits[:-1], axis=-1) + + # Update convolution state and apply + conv_state = self.cache.update_conv_state(xBC) + xBC = mx.sum(conv_state[:, :, -1] * mx.transpose(self.conv1d.weight, [1, 0, 2]), axis=-1) + + if self.args.use_conv_bias: + xBC = xBC + self.conv1d.bias + + xBC = mx.sigmoid(xBC) * xBC # SiLU + + # Split states and reshape + x, BC = mx.split(xBC, [self.args.intermediate_size], axis=-1) + B, C = mx.split(BC, [d_state], axis=-1) + + # Reshape for SSM computation + x = mx.reshape(x, (batch_size, self.args.num_heads, -1)) # (B, H, head_dim) + B = mx.reshape(B, (batch_size, self.args.num_heads, -1)) # (B, H, state_per_head) + C = mx.reshape(C, (batch_size, self.args.num_heads, -1)) # (B, H, state_per_head) + + # Process dt to match expected shape + dt = mx.reshape(dt, (batch_size, self.args.num_heads)) # (B, H) dt = mx.clip( nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max ) - dt = mx.maximum(dt, self.args.time_step_floor) - - xBC = silu(self.conv1d(xBC)) - - x = xBC[:, :, :d_model] - B = xBC[:, :, d_model:d_model + d_state] - C = xBC[:, :, -d_state:] - - b, l, hp = x.shape - h = self.args.num_heads - p = hp // h - x = mx.reshape(x, (b, l, h, p)) - - y, ssm_state = ssd(x * mx.expand_dims(dt, -1), A * dt, B, C, self.args.chunk_size) - y = y + x * mx.expand_dims(self.D, -1) - y = mx.reshape(y, (b, l, h * p)) - - y = self.norm(y + z) + + # SSM step + A = -mx.exp(self.A_log) # (H,) + dA = mx.exp(dt * A[None, :]) # (B, H) + + # Compute dBx + dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, x) + + # Update SSM state and compute output + ssm_state = self.cache.update_ssm_state(dA, dBx) + y = mx.einsum('bhds,bhs->bhd', ssm_state, C) + y = y + x * self.D[None, :, None] + + # Reshape output + y = mx.reshape(y, (batch_size, self.args.hidden_size)) + + # Final output processing + y = self.norm(y, z) + + if d_mlp > 0: + y = mx.concat([nn.silu(z0) * x0, y], axis=-1) + y = self.out_proj(y) - - if self.args.residual_in_fp32: - y = y.astype(mx.float32) - - return y - - def step(self, u: mx.array, cache): - batch_size = u.shape[0] - seq_len = u.shape[1] - outputs = [] - - # Initialize cache if needed - if cache.conv_states is None: - conv_dim = self.args.intermediate_size + 2 * self.args.state_size - cache.conv_states = mx.zeros(( - batch_size, - self.args.conv_kernel - 1, - conv_dim - )) - - if cache.ssm_state is None: - cache.ssm_state = mx.zeros(( - batch_size, - self.args.num_heads, - self.args.head_dim, - self.args.state_size - )) - - for pos in range(seq_len): - u_t = u[:, pos:pos+1, :] - zxbcdt = self.in_proj(u_t) - - d_model = self.args.intermediate_size - d_state = self.args.state_size - n_heads = self.args.num_heads - - z = zxbcdt[:, :, :d_model] - xBC = zxbcdt[:, :, d_model:d_model + 2*d_state + d_model] - dt = zxbcdt[:, :, -(n_heads):] - - dt = mx.reshape(dt, (batch_size, n_heads)) - dt = mx.clip( - nn.softplus(dt + self.dt_bias), - self.args.time_step_min, - self.args.time_step_max - ) - dt = mx.maximum(dt, self.args.time_step_floor) - - # Create a temporary cache dictionary for the convolution - conv_cache = {'conv_states': cache.conv_states} - xBC = self.conv1d(xBC, cache=conv_cache) - cache.conv_states = conv_cache['conv_states'] - - xBC = silu(xBC) - - x = xBC[:, :, :d_model] - B = xBC[:, :, d_model:d_model + d_state] - C = xBC[:, :, -d_state:] - - x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim)) - x = mx.squeeze(x, axis=1) - - B = mx.reshape(B, (batch_size, 1, d_state)) - B = mx.broadcast_to(B, (batch_size, n_heads, d_state)) - B = mx.expand_dims(B, axis=2) - - C = mx.reshape(C, (batch_size, 1, d_state)) - C = mx.broadcast_to(C, (batch_size, n_heads, d_state)) - C = mx.expand_dims(C, axis=3) - - A = -mx.exp(self.A_log) - dA = mx.exp(dt * mx.expand_dims(A, 0)) - dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) - - x = mx.expand_dims(x, axis=3) - dBx = mx.matmul(x, B) - - cache.ssm_state = cache.ssm_state * dA + dBx - - y = mx.matmul(cache.ssm_state, C) - y = mx.squeeze(y, axis=-1) - - y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1) - - y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim)) - y = self.norm(y + z) - y = self.out_proj(y) - - if self.args.residual_in_fp32: - y = y.astype(mx.float32) - - outputs.append(y) - - return mx.concatenate(outputs, axis=1) + + return mx.expand_dims(y, 1) # (B, 1, D) class ResidualBlock(nn.Module): @@ -594,11 +360,12 @@ class ResidualBlock(nn.Module): self.mixer = Mamba2Block(args) self.norm = nn.RMSNorm(args.hidden_size) - def __call__(self, x: mx.array, cache): - return self.mixer(self.norm(x), cache) + x + def __call__(self, x: mx.array, cache=None) -> mx.array: + # x : (B, L, D) + return self.mixer(self.norm(x), cache) + x # (B, L, D) -class Mamba2(nn.Module): +class Mamba2Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args @@ -606,12 +373,15 @@ class Mamba2(nn.Module): self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - def __call__(self, x: mx.array, cache): + def __call__(self, x: mx.array, cache=None) -> mx.array: + # x : (B, L) x = self.embeddings(x) + # x : (B, L, D) if cache is None: cache = [None] * len(self.layers) - for layer, c in zip(self.layers, cache): - x = layer(x, c) + + for layer, layer_cache in zip(self.layers, cache): + x = layer(x, layer_cache) return self.norm_f(x) @@ -619,14 +389,13 @@ class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args - self.model_type = args.model_type + self.backbone = Mamba2Model(args) - self.backbone = Mamba2(args) - if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - def __call__(self, inputs: mx.array, cache=None): + def __call__(self, inputs: mx.array, cache=None) -> mx.array: + # inputs : (B, L) B, T = inputs.shape x = self.backbone(inputs, cache) @@ -637,24 +406,19 @@ class Model(nn.Module): logits = self.lm_head(x) return logits - - def make_cache(self): - return [Mamba2Cache() for _ in range(len(self.layers))] + + def make_cache(self, batch_size=1): + return [Mamba2Cache( + batch_size=batch_size, + hidden_size=self.args.hidden_size, + state_size=self.args.state_size, + conv_kernel=self.args.conv_kernel, + num_heads=self.args.num_heads, + head_dim=self.args.head_dim + ) for _ in range(len(self.backbone.layers))] def sanitize(self, weights): - sanitized = {} for k, v in weights.items(): - if "conv1d.weight" in k: - # Ensure weights are in correct shape (channels, 1, kernel_size) - if v.ndim == 2: - v = mx.expand_dims(v, axis=1) - elif v.ndim == 1: - v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0) - sanitized[k] = v - else: - sanitized[k] = v - return sanitized - - @property - def layers(self): - return self.backbone.layers + if "conv1d.weight" in k and v.ndim == 3: + weights[k] = v.moveaxis(2, 1) + return weights \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2-prch.py b/llms/mlx_lm/models/mamba2-prch.py index f988a825..84bf2174 100644 --- a/llms/mlx_lm/models/mamba2-prch.py +++ b/llms/mlx_lm/models/mamba2-prch.py @@ -1,437 +1,490 @@ -""" -mamba2-minimal -============== +# coding=utf-8 +# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch MAMBA2 model.""" -A minimal, single-file implementation of the Mamba-2 model in PyTorch. - -> **Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality** -> Authors: Tri Dao, Albert Gu -> Paper: https://arxiv.org/abs/2405.21060 -""" - -import json +import math from dataclasses import dataclass -from typing import Iterable, NamedTuple, TypeAlias, cast +from typing import Optional, Tuple, Union import torch -import torch.nn.functional as F -from einops import rearrange, repeat -from torch import LongTensor, Tensor, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss -Device: TypeAlias = str | torch.device | None +logger = logging.get_logger(__name__) -@dataclass -class Mamba2Config: - d_model: int # model dimension (D) - n_layer: int = 24 # number of Mamba-2 layers in the language model - d_state: int = 128 # state dimension (N) - d_conv: int = 4 # convolution kernel size - expand: int = 2 # expansion factor (E) - headdim: int = 64 # head dimension (P) - chunk_size: int = 64 # matrix partition size (Q) - vocab_size: int = 50277 - pad_vocab_size_multiple: int = 16 +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) - def __post_init__(self): - self.d_inner = self.expand * self.d_model - assert self.d_inner % self.headdim == 0 - self.nheads = self.d_inner // self.headdim - if self.vocab_size % self.pad_vocab_size_multiple != 0: - self.vocab_size += ( - self.pad_vocab_size_multiple - - self.vocab_size % self.pad_vocab_size_multiple - ) + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) -class InferenceCache(NamedTuple): - conv_state: Tensor # (batch, d_inner + 2 * d_state, d_conv) - ssm_state: Tensor # (batch, nheads, headdim, d_state) +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. - @staticmethod - def alloc(batch_size: int, args: Mamba2Config, device: Device = None): - return InferenceCache( - torch.zeros( - batch_size, args.d_inner + 2 * args.d_state, args.d_conv, device=device - ), - torch.zeros( - batch_size, args.nheads, args.headdim, args.d_state, device=device - ), + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] ) -class Mamba2LMHeadModel(nn.Module): - def __init__(self, args: Mamba2Config, device: Device = None): +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + + +class Mamba2Cache: + """ + Arguments: + config: ModelArgs + batch_size: int + dtype: torch.dtype + device: torch.device + + Attributes: + seqlen_offset: int + dtype: torch.dtype + conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel_size] + ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size] + """ + + def __init__( + self, config: ModelArgs, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None + ): + self.seqlen_offset = 0 + self.dtype = dtype + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = int(config.expand * config.hidden_size) + + self.conv_states = { + i: torch.zeros( + batch_size, + self.intermediate_size + 2 * config.n_groups * config.state_size, + self.conv_kernel_size, + device=device, + dtype=dtype, + ) + for i in range(config.num_hidden_layers) + } + self.ssm_states = { + i: torch.zeros( + batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype + ) + for i in range(config.num_hidden_layers) + } + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +class MambaRMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): super().__init__() - self.args = args - self.device = device + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps - self.backbone = nn.ModuleDict( - dict( - embedding=nn.Embedding(args.vocab_size, args.d_model, device=device), - layers=nn.ModuleList( - [ - nn.ModuleDict( - dict( - mixer=Mamba2(args, device=device), - norm=RMSNorm(args.d_model, device=device), - ) - ) - for _ in range(args.n_layer) - ] - ), - norm_f=RMSNorm(args.d_model, device=device), + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states + + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + return self.weight * hidden_states + + +class Mamba2Mixer(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.num_heads = config.num_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = int(config.expand * self.hidden_size) + self.time_step_rank = int(config.time_step_rank) + self.use_conv_bias = config.use_conv_bias + self.act = nn.silu + + self.layer_norm_epsilon = config.layer_norm_epsilon + self.rms_norm = config.rms_norm + + self.n_groups = config.n_groups + self.head_dim = config.head_dim + self.chunk_size = config.chunk_size + + self.time_step_limit = config.time_step_limit + self.time_step_min = config.time_step_min + self.time_step_max = config.time_step_max + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, + groups=self.conv_dim, + padding=config.conv_kernel - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=config.use_bias, + ) + + self.dt_bias = torch.ones(self.num_heads) + A = torch.arange(1, self.num_heads + 1) + self.A_log = torch.log(A) + self.D = torch.ones(self.num_heads) + + self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + + def forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # Gated MLP's linear projection + projected_states = self.in_proj(input_states.squeeze(1)) + d_mlp = ( + projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2 + _, _, gate, hidden_states, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = ssm_state.to(hidden_states.device) + if cache_params.seqlen_offset > 0: + conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] + conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + # handle batched generation - states are copied through + conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states)[:, None, ...] # [batch, 1, intermediate_size] : decoding + else: + hidden_states = hidden_states.transpose(1,2) + conv_state = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] + else: + ssm_state = torch.zeros( + (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), + device=hidden_states.device ) - ) - self.lm_head = nn.Linear( - args.d_model, args.vocab_size, bias=False, device=device - ) - self.lm_head.weight = self.backbone.embedding.weight + hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) - @staticmethod - def from_pretrained(huggingface_model_id: str, device: Device = None): - from transformers.utils import CONFIG_NAME, WEIGHTS_NAME - from transformers.utils.hub import cached_file + hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) + A = -torch.exp(self.A_log.float()) # [num_heads] - config_path = cached_file(huggingface_model_id, CONFIG_NAME) - assert config_path, "Failed to get huggingface config file" - state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME) - assert state_dict_path, "Failed to get huggingface state dict file" + if cache_params is not None and cache_params.seqlen_offset > 0: + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] + dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) - config = json.load(open(config_path)) - args = Mamba2Config( - d_model=config["d_model"], - n_layer=config["n_layer"], - vocab_size=config["vocab_size"], - pad_vocab_size_multiple=config["pad_vocab_size_multiple"], - ) + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = torch.exp(dt[..., None] * A) - map_location = "cpu" if device is None else device - state_dict = torch.load( - state_dict_path, weights_only=True, map_location=map_location, mmap=True - ) - model = Mamba2LMHeadModel(args, device=device) - model.load_state_dict(state_dict) - model.eval() - return model + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = dB * hidden_states[..., None] + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] + else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_min) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = self.chunk_size - (seq_len % self.chunk_size) + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # First, contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + + # Step 2: Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Step 3: Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3) + + # (right term of low-rank factorization of off-diagonal blocks; B terms) + + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] + # permute back B * decay states + states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) + if cache_params is not None and cache_params.seqlen_offset > 0: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + + states_permuted = states.permute(0, 2, 1, 3, 4) + result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2) + new_states = result.permute(0, 2, 1, 3, 4) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + # compute Yoff + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output) # [batch, seq_len, hidden_size] + return contextualized_states + + +class Mamba2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states + + +class Mamba2Block(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.mixer = Mamba2Mixer(config) def forward( - self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None - ) -> tuple[LongTensor, list[InferenceCache]]: - """ - Arguments - input_ids: (batch, seqlen) tokens from `EleutherAI/gpt-neox-20b` tokenizer - h: hidden states for inference step. If present the constant-time - (wrt sequence length) inference path will be taken, input_ids - should have shape (batch, 1) containing the next batch of prompt - token. - - Return (logits, h) - logits: (batch, seqlen, vocab_size) - h: updated inference cache after processing `input_ids` - """ - seqlen = input_ids.shape[1] - - if h is None: - h = [None for _ in range(self.args.n_layer)] - - x = self.backbone.embedding(input_ids) - for i, layer in enumerate(self.backbone.layers): - y, h[i] = layer.mixer(layer.norm(x), h[i]) - x = y + x - - x = self.backbone.norm_f(x) - logits = self.lm_head(x) - return logits[:, :seqlen], cast(list[InferenceCache], h) - - def generate( self, - input_ids: LongTensor, - max_new_length: int = 20, - temperature: float = 1.0, - top_k: int = 50, - top_p: float = 1.0, - eos_token_id: int = 0, - ) -> Iterable[tuple[int, list[InferenceCache]]]: - prefix, tokens = input_ids[:-1], input_ids[-1:].unsqueeze(0) + hidden_states, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + ): + x = self.mixer( + self.norm(hidden_states), cache_params=cache_params, cache_position=cache_position + ) + return x + hidden_states - # Process prompt - # The input sequence to forward (non-inference path) must have length multiple that of chunk_size. - # We split out excess tokens so that n_chunked tokens can be processed by one forward call and - # process the rest in multiple inference steps. - n_chunked = (prefix.shape[0] // self.args.chunk_size) * self.args.chunk_size - if n_chunked > 0: - _, h = self(prefix[:n_chunked].unsqueeze(0), None) + +class Mamba2Model(nn.Module): + def __init__(self, config): + super().__init__(config) + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + cache_params: Optional[Mamba2Cache] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ): + inputs_embeds = self.embeddings(input_ids) + + if use_cache: + if cache_params is None: + cache_params = Mamba2Cache( + self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device) else: - h = [ - InferenceCache.alloc(1, self.args, device=self.device) - for _ in range(self.args.n_layer) - ] - for i in range(n_chunked, prefix.shape[0]): - _, h = self(prefix[i : i + 1].unsqueeze(0), h) + cache_params = None - # Generate - for _ in range(max_new_length): - with torch.no_grad(): - out, h = self(tokens, h) - logits = out[0, -1] - if temperature != 1.0: - logits = logits / temperature - if top_k > 0: - indices_to_remove = logits < torch.topk(logits, k=top_k)[0][-1] - logits[indices_to_remove] = -torch.inf - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - sorted_indices_to_remove = cum_probs > 0.5 - sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone() - sorted_indices_to_remove[0] = False - indices_to_remove = sorted_indices[sorted_indices_to_remove] - logits[indices_to_remove] = -torch.inf - probs = F.softmax(logits, dim=-1) - next_token = torch.multinomial(probs, num_samples=1) - if next_token.item() == eos_token_id: - return - tokens = next_token.unsqueeze(0) - yield cast(int, next_token.item()), h + hidden_states = inputs_embeds + for mixer_block in self.layers: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + ) + + if use_cache: + cache_params.seqlen_offset += inputs_embeds.shape[1] + + return self.norm_f(hidden_states), cache_params if use_cache else None -class Mamba2(nn.Module): - def __init__(self, args: Mamba2Config, device: Device = None): - super().__init__() - self.args = args - self.device = device - # Order: (z, x, B, C, dt) - d_in_proj = 2 * args.d_inner + 2 * args.d_state + args.nheads - self.in_proj = nn.Linear(args.d_model, d_in_proj, bias=False, device=device) +class Mamba2ForCausalLM(nn.Module): + def __init__(self, config): + super().__init__(config) + self.backbone = Mamba2Model(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - conv_dim = args.d_inner + 2 * args.d_state - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - kernel_size=args.d_conv, - groups=conv_dim, - padding=args.d_conv - 1, - device=device, + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + cache_params: Optional[Mamba2Cache] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ): + mamba2_outputs = self.backbone( + input_ids, + cache_params=cache_params, + use_cache=use_cache, + cache_position=cache_position, ) + hidden_states = mamba2_outputs[0] - self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device)) - self.A_log = nn.Parameter(torch.empty(args.nheads, device=device)) - self.D = nn.Parameter(torch.empty(args.nheads, device=device)) - self.norm = RMSNorm(args.d_inner, device=device) - self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device) - - def forward(self, u: Tensor, h: InferenceCache | None = None): - """ - Arguments - u: (batch, seqlen, d_model) input. seqlen should be a multiple of chunk_size. - h: hidden states for inference step. Initialized to 0s if not present. - - Return (y, h) - y: (batch, seqlen, d_model) output - h: updated inference cache after processing `u` - """ - if h: - return self.step(u, h) - - A = -torch.exp(self.A_log) # (nheads,) - zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj) - z, xBC, dt = torch.split( - zxbcdt, - [ - self.args.d_inner, - self.args.d_inner + 2 * self.args.d_state, - self.args.nheads, - ], - dim=-1, - ) - dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads) - - # Pad or truncate xBC seqlen to d_conv - conv_state = F.pad( - rearrange(xBC, "b l d -> b d l"), (self.args.d_conv - u.shape[1], 0) - ) - - xBC = silu( - self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :] - ) # (batch, seqlen, d_inner + 2 * d_state)) - x, B, C = torch.split( - xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1 - ) - x = rearrange(x, "b l (h p) -> b l h p", p=self.args.headdim) - y, ssm_state = ssd( - x * dt.unsqueeze(-1), - A * dt, - rearrange(B, "b l n -> b l 1 n"), - rearrange(C, "b l n -> b l 1 n"), - self.args.chunk_size, - device=self.device, - ) - y = y + x * self.D.unsqueeze(-1) - y = rearrange(y, "b l h p -> b l (h p)") - y = self.norm(y, z) - y = self.out_proj(y) - - h = InferenceCache(conv_state, ssm_state) - return y, h - - def step(self, u: Tensor, h: InferenceCache) -> tuple[Tensor, InferenceCache]: - """Take a single inference step for the current input and hidden state - - Unlike attention-based models, RNN-based models (eg Mamba) does not need - to look back at all the past tokens to generate a new token. Instead a - hidden state (initialized to 0s initially) is updated for each input and - passed to the next inference step. This means that the total inference - time is linear with respect to the sequence length instead of quadratic - in attention's case. - - Arguments - u: (batch, 1, d_model) - h: initial/running hidden state - - Return (y, h) - y: (batch, 1, d_model) - h: updated hidden state - """ - assert u.shape[1] == 1, "Only one token can be decoded per inference step" - - zxbcdt = self.in_proj(u.squeeze(1)) # (batch, d_in_proj) - z, xBC, dt = torch.split( - zxbcdt, - [ - self.args.d_inner, - self.args.d_inner + 2 * self.args.d_state, - self.args.nheads, - ], - dim=-1, - ) - - # Advance convolution input - h.conv_state.copy_(torch.roll(h.conv_state, shifts=-1, dims=-1)) - h.conv_state[:, :, -1] = xBC - # Convolution step - xBC = torch.sum( - h.conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1 - ) - xBC += self.conv1d.bias - xBC = silu(xBC) - - x, B, C = torch.split( - xBC, [self.args.d_inner, self.args.d_state, self.args.d_state], dim=-1 - ) - A = -torch.exp(self.A_log) # (nheads,) - - # SSM step - dt = F.softplus(dt + self.dt_bias) # (batch, nheads) - dA = torch.exp(dt * A) # (batch, nheads) - x = rearrange(x, "b (h p) -> b h p", p=self.args.headdim) - dBx = torch.einsum("bh, bn, bhp -> bhpn", dt, B, x) - h.ssm_state.copy_(h.ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) - y = torch.einsum("bhpn, bn -> bhp", h.ssm_state, C) - y = y + rearrange(self.D, "h -> h 1") * x - y = rearrange(y, "b h p -> b (h p)") - y = self.norm(y, z) - y = self.out_proj(y) - - return y.unsqueeze(1), h - - -def segsum(x: Tensor, device: Device = None) -> Tensor: - """Stable segment sum calculation. - - `exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM. - - Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32 - """ - T = x.size(-1) - x = repeat(x, "... d -> ... d e", e=T) - mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1) - x = x.masked_fill(~mask, 0) - x_segsum = torch.cumsum(x, dim=-2) - mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0) - x_segsum = x_segsum.masked_fill(~mask, -torch.inf) - return x_segsum - - -def ssd(x, A, B, C, chunk_size, initial_states=None, device: Device = None): - """Structed State Space Duality (SSD) - the core of Mamba-2 - - This is almost the exact same minimal SSD code from the blog post. - - Arguments - x: (batch, seqlen, n_heads, d_head) - A: (batch, seqlen, n_heads) - B: (batch, seqlen, n_heads, d_state) - C: (batch, seqlen, n_heads, d_state) - - Return - y: (batch, seqlen, n_heads, d_head) - - Source - 1. https://tridao.me/blog/2024/mamba2-part3-algorithm/ - 2. https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L34-L78 - """ - assert x.shape[1] % chunk_size == 0 - - # Rearrange into chunks - # Step 1, 2 and 4 of SSD can be computed in parallel for each chunk across devices (sequence parallel) - # This is not implemented and left as an exercise for the reader 😜 - x, A, B, C = [ - rearrange(m, "b (c l) ... -> b c l ...", l=chunk_size) for m in (x, A, B, C) - ] - - A = rearrange(A, "b c l h -> b h c l") - A_cumsum = torch.cumsum(A, dim=-1) - - # 1. Compute the output for each intra-chunk (diagonal blocks) - L = torch.exp(segsum(A, device=device)) - Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x) - - # 2. Compute the state for each intra-chunk - # (right term of low-rank factorization of off-diagonal blocks; B terms) - decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) - states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x) - - # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries - # (middle term of factorization of off-diag blocks; A terms) - if initial_states is None: - initial_states = torch.zeros_like(states[:, :1]) - states = torch.cat([initial_states, states], dim=1) - decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)), device=device)) - new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states) - states, final_state = new_states[:, :-1], new_states[:, -1] - - # 4. Compute state -> output conversion per chunk - # (left term of low-rank factorization of off-diagonal blocks; C terms) - state_decay_out = torch.exp(A_cumsum) - Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out) - - # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) - Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") - - return Y, final_state - - -class RMSNorm(nn.Module): - def __init__(self, d: int, eps: float = 1e-5, device: Device = None): - """Gated Root Mean Square Layer Normalization - - Paper: https://arxiv.org/abs/1910.07467 - """ - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(d, device=device)) - - def forward(self, x, z=None): - if z is not None: - x = x * silu(z) - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight - - -def silu(x): - """Applies the Sigmoid Linear Unit (SiLU), element-wise. - - Define this manually since torch's version doesn't seem to work on MPS. - """ - return x * F.sigmoid(x) \ No newline at end of file + logits = self.lm_head(hidden_states) + return logits, mamba2_outputs.cache_params, mamba2_outputs.hidden_states \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index cb78f316..bd0f17ee 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -32,259 +32,272 @@ class ModelArgs(BaseModelArgs): rms_norm: bool chunk_size: int tie_word_embeddings: bool + intermediate_size: int = None time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) time_step_rank: Union[int, str] = "auto" model_type: str = "mamba2" def __post_init__(self): - if not hasattr(self, "intermediate_size"): - self.intermediate_size = int(self.expand * self.hidden_size) + self.intermediate_size = int(self.expand * self.hidden_size) # E*D = ED + if not hasattr(self, "head_dim"): self.head_dim = self.hidden_size // self.num_heads if self.time_step_rank == "auto": self.time_step_rank = math.ceil(self.hidden_size / 16) - -def selective_scan(x, A, B, C, chunk_size): - """ - Selective scan implementation for training. - Arguments - x: (batch, seqlen, n_heads, d_head) - A: (batch, seqlen, n_heads) - B: (batch, seqlen, n_heads, d_state) - C: (batch, seqlen, n_heads, d_state) +class MambaRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = mx.ones(hidden_size) + self.variance_epsilon = eps - Return - y: (batch, seqlen, n_heads, d_head) - """ - assert x.shape[1] % chunk_size == 0 - - # Reshape into chunks - def chunk_reshape(m): - shape = list(m.shape) - shape[1:2] = [shape[1] // chunk_size, chunk_size] - return m.reshape(shape) - - x, A, B, C = map(chunk_reshape, (x, A, B, C)) - A = mx.transpose(A, [0, 3, 1, 2]) - - # Compute cumulative sums - A_cumsum = mx.cumsum(A, axis=-1) - - # Process chunks - L = mx.exp(selective_cumsum(A)) - Y_diag = mx.einsum('bclhn,bcshn,bhcls,bcshp->bclhp', C, B, L, x) - - decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum) - states = mx.einsum('bclhn,bhcl,bclhp->bchpn', B, decay_states, x) - - initial_states = mx.zeros_like(states[:, :1]) - states = mx.concatenate([initial_states, states], axis=1) - decay_chunk = mx.exp(selective_cumsum(mx.pad(A_cumsum[..., -1], ((0,0), (0,0), (1,0))))) - new_states = mx.einsum('bhzc,bchpn->bzhpn', decay_chunk, states) - states = new_states[:, :-1] - - state_decay_out = mx.exp(A_cumsum) - Y_off = mx.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out) - - Y = (Y_diag + Y_off).reshape((-1, x.shape[1] * chunk_size, *Y_diag.shape[-2:])) - return Y + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(mx.float32) -def selective_cumsum(x: mx.array) -> mx.array: - """Stable selective cumulative sum calculation.""" - T = x.shape[-1] - x = mx.repeat(x[..., None], T, axis=-1) - mask = mx.tril(mx.ones((T, T)), k=-1) - x = x * mask - x_cumsum = mx.cumsum(x, axis=-2) - mask = mx.tril(mx.ones((T, T)), k=0) - return mx.where(mask, x_cumsum, float('-inf')) + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate.to(mx.float32)) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * math.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + - -class Mamba2Block(nn.Module): +class Mamba2Mixer(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - self.args = args + # Model dimensions + self.hidden_size = args.hidden_size + self.num_heads = args.num_heads + self.head_dim = args.head_dim + self.ssm_state_size = args.state_size + self.n_groups = args.n_groups + self.intermediate_size = int(args.expand * args.hidden_size) - # Internal cache state - self.conv_state = None - self.ssm_state = None + # Convolution parameters + self.conv_kernel = args.conv_kernel + self.use_conv_bias = args.use_conv_bias - # Project input to get various components - d_in_proj = (2 * args.intermediate_size + 2 * self.args.n_groups * args.state_size + args.num_heads) + # Time step parameters + self.time_step_rank = int(args.time_step_rank) + self.time_step_min = args.time_step_min + self.time_step_max = args.time_step_max + + # Processing parameters + self.chunk_size = args.chunk_size + self.layer_norm_epsilon = args.layer_norm_epsilon + + # Calculate dimensions + self.conv_dim = (self.intermediate_size + + 2 * self.n_groups * self.ssm_state_size) + projection_size = (self.intermediate_size + + self.conv_dim + + self.num_heads) + + # Initialize layers self.in_proj = nn.Linear( - args.hidden_size, - d_in_proj, + self.hidden_size, + projection_size, + bias=args.use_bias + ) + + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + kernel_size=self.conv_kernel, + groups=self.conv_dim, + padding=self.conv_kernel - 1, + bias=self.use_conv_bias + ) + + # Initialize parameters + self.dt_bias = mx.ones(self.num_heads) + A = mx.arange(1, self.num_heads + 1) + self.A_log = mx.log(A) + self.D = mx.ones(self.num_heads) + + # Output layers + self.norm = MambaRMSNormGated( + self.intermediate_size, + eps=self.layer_norm_epsilon + ) + self.out_proj = nn.Linear( + self.intermediate_size, + self.hidden_size, bias=args.use_bias ) - # Convolution layer - conv_dim = args.intermediate_size + 2 * self.args.n_groups * args.state_size - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - kernel_size=args.conv_kernel, - groups=conv_dim, - padding=args.conv_kernel - 1, - bias=args.use_conv_bias - ) - - # SSM parameters - dt_init_floor = math.log(args.time_step_floor) - self.dt_bias = mx.zeros((args.num_heads,)) * args.initializer_range - self.A_log = mx.zeros((args.num_heads,)) * args.initializer_range - self.D = mx.zeros((args.num_heads,)) * args.initializer_range - - # Output projections - self.norm = nn.RMSNorm(args.intermediate_size, eps=args.layer_norm_epsilon) - self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) - - def __call__(self, x: mx.array, cache=None) -> mx.array: - return self.forward_training(x) if x.shape[1] > 1 else self.forward_inference(x, cache) - - def forward_training(self, u: mx.array) -> mx.array: - # Reset cache during training - self.cache = None + def reshape_into_chunks(self, tensor, pad_size, chunk_size): + if pad_size > 0: + pad_shape = list(tensor.shape) + pad_shape[1] = pad_size + padding = mx.zeros(pad_shape, dtype=tensor.dtype) + tensor = mx.concatenate([tensor, padding], axis=1) - # Input projection and splitting - zxbcdt = self.in_proj(u) - z, xBC, dt = mx.split( - zxbcdt, - [ - self.args.intermediate_size, - self.args.intermediate_size + 2 * self.args.state_size - ], - axis=-1 - ) + chunk_shape = list(tensor.shape) + chunk_shape[1] = -1 + chunk_shape.insert(2, chunk_size) + return tensor.reshape(chunk_shape) - # Time step processing + def segment_sum(self, x): + return mx.cumsum(x, axis=-1) + + def process_single_token(self, hidden_states, B, C, dt, cache): + batch_size = hidden_states.shape[0] + + # Process convolution state + if cache is not None: + conv_state = cache.conv_states + # Roll the conv state and update the last position + conv_state = mx.roll(conv_state, shift=-1, axis=-1) + # Create new conv state with updated last position + new_conv_state = mx.array(conv_state) + new_conv_state = new_conv_state.at[:, :, -1].add(hidden_states) + conv_state = new_conv_state + + # Compute convolution + conv_out = mx.sum(conv_state * self.conv1d.weight[:, 0, :], axis=-1) + if self.use_conv_bias: + conv_out = conv_out + self.conv1d.bias + + # Apply SiLU activation + conv_out = mx.sigmoid(conv_out) * conv_out + + else: + # Initialize new cache + conv_state = mx.zeros((batch_size, self.conv_dim, self.conv_kernel - 1)) + conv_out = self.conv1d(hidden_states) + conv_out = mx.sigmoid(conv_out) * conv_out + + # Process SSM dt = mx.clip( nn.softplus(dt + self.dt_bias), - self.args.time_step_min, - self.args.time_step_max + self.time_step_min, + self.time_step_max ) - - # Convolution processing - xBC_t = mx.transpose(xBC, [0, 2, 1]) - conv_out = self.conv1d(xBC_t) - xBC = mx.transpose(conv_out, [0, 2, 1])[:, :u.shape[1]] - xBC = mx.sigmoid(xBC) * xBC # SiLU - - # Split states - x, B, C = mx.split( - xBC, - [self.args.intermediate_size, self.args.state_size], - axis=-1 - ) - - # Reshape for selective scan - x = x.reshape((-1, x.shape[1], self.args.num_heads, self.args.head_dim)) + A = -mx.exp(self.A_log) + dA = mx.exp(dt * A[None, :]) + + if cache is not None: + ssm_state = cache.ssm_states + else: + ssm_state = mx.zeros( + (batch_size, self.num_heads, self.head_dim, self.ssm_state_size) + ) + + # Compute SSM updates + dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, hidden_states) + next_state = ssm_state * dA[:, :, None, None] + dBx + y = mx.einsum('bhds,bhs->bhd', next_state, C) + + # Add skip connection + y = y + hidden_states * self.D[None, :, None] + + return y, conv_state, next_state - # Apply selective scan - y = selective_scan( - x * dt[..., None], - A * dt, - B[..., None, :], - C[..., None, :], - self.args.chunk_size + def process_long_sequence(self, hidden_states, B, C, dt, ssm_state): + batch_size, seq_len = hidden_states.shape[:2] + pad_size = self.chunk_size - (seq_len % self.chunk_size) + + # Reshape into chunks + x_chunks = self.reshape_into_chunks(hidden_states, pad_size, self.chunk_size) + B_chunks = self.reshape_into_chunks(B, pad_size, self.chunk_size) + C_chunks = self.reshape_into_chunks(C, pad_size, self.chunk_size) + + # Process time steps + dt = nn.softplus(dt + self.dt_bias) + dt = mx.clip(dt, self.time_step_min) + + # Prepare matrices + A = -mx.exp(self.A_log) + A = A * dt[:, None] + + # Process chunks + A_chunks = self.reshape_into_chunks( + mx.broadcast_to(A, (batch_size, seq_len + pad_size, self.num_heads)), + pad_size, + self.chunk_size ) - - # Output processing - y = y + x * self.D[None, None, :, None] - y = y.reshape((-1, y.shape[1], self.args.intermediate_size)) - y = self.norm(y, z) - y = self.out_proj(y) - return y - - def forward_inference(self, u: mx.array, cache=None) -> mx.array: - """Single token processing during inference.""" - assert u.shape[1] == 1, "Inference mode expects single token" + # Compute cumulative sums + A_cumsum = mx.cumsum(A_chunks, axis=-1) + L = mx.exp(self.segment_sum(A_chunks)) - batch_size = u.shape[0] - # Use provided cache or create new one - self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None) + # Process diagonal blocks + G = mx.einsum('...lhn,...shn->...lsh', C_chunks, B_chunks) + M = G * L[..., None, :] + Y_diag = mx.einsum('...lsh,...sh->...lh', M, x_chunks) + + # Process off-diagonal blocks + decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum) + B_decay = B_chunks * decay_states[..., None] + states = mx.einsum('...shn,...sh->...hn', B_decay, x_chunks) + + # Combine results + y = Y_diag + states + + # Remove padding if necessary + if pad_size > 0: + y = y[:, :seq_len] + + return y, ssm_state + + def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array: + batch_size, seq_len, _ = x.shape # Project input - zxbcdt = self.in_proj(mx.squeeze(u, 1)) - parts = mx.split( - zxbcdt, - [ - self.args.intermediate_size, - self.args.intermediate_size + 2 * self.args.state_size - ], - axis=-1 - ) - z, xBC = parts[0], parts[1] - dt = zxbcdt[:, -self.args.num_heads:] # Extract dt separately - - # Update convolution state and apply - conv_state = self.cache.update_conv_state(xBC) - xBC = mx.sum( - conv_state * mx.transpose(self.conv1d.weight, [1, 0, 2]), - axis=-1 - ) - if self.args.use_conv_bias: - xBC = xBC + self.conv1d.bias - xBC = mx.sigmoid(xBC) * xBC # SiLU - - # Split states and ensure proper shapes - x_splits = mx.split( - xBC, - [self.args.intermediate_size, self.args.state_size], - axis=-1 - ) - x, B, C = x_splits[0], x_splits[1], x_splits[2] + projected_states = self.in_proj(x.squeeze(1)) - # Process time steps - ensure proper broadcasting - dt = mx.reshape(dt, (batch_size, self.args.num_heads)) - dt = mx.clip( - nn.softplus(dt + self.dt_bias[None, :]), - self.args.time_step_min, - self.args.time_step_max - ) + # Calculate d_mlp based on projection size + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * + self.n_groups * self.ssm_state_size - self.num_heads) // 2 - # SSM step with explicit shapes - A = -mx.exp(self.A_log) - dA = mx.exp(dt * A[None, :]) # Shape: (batch_size, num_heads) + # Split projections with corrected dimensions + splits = [ + d_mlp, # z0 + d_mlp, # x0 + self.intermediate_size, # gate + self.conv_dim, # hidden_states + self.num_heads # dt + ] - # Reshape x considering intermediate size - # x shape should be (batch_size * num_heads, head_dim) - x = mx.reshape(x, (batch_size, self.args.num_heads, -1)) - assert x.shape[-1] == self.args.head_dim, f"Head dimension mismatch: {x.shape[-1]} vs {self.args.head_dim}" + z0, x0, x1, gate, hidden_states, dt = projected_states.split(splits, axis=-1) - # Reshape B and C for ssm computation - B = mx.reshape(B, (batch_size, -1)) # Should be (batch_size, state_size) - C = mx.reshape(C, (batch_size, -1)) # Should be (batch_size, state_size) + # Split hidden states into components + x_conv, BC = mx.split(hidden_states, [self.intermediate_size], axis=-1) + B, C = mx.split(BC, [self.n_groups * self.ssm_state_size], axis=-1) - # Compute dBx with explicit shapes - dBx = mx.einsum('bh,bs,bhd->bhds', dt, B, x) + # Process based on sequence length + if seq_len > 1 and cache is None: + y, next_state = self.process_long_sequence( + x_conv, B, C, dt, + mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size)) + ) + else: + # Reshape for single token processing + x_conv = x_conv.reshape(batch_size, -1, self.head_dim) + B = B.reshape(batch_size, self.num_heads, -1) + C = C.reshape(batch_size, self.num_heads, -1) + y, conv_state, next_state = self.process_single_token(x_conv, B, C, dt, cache) + + if cache is not None: + cache.update(conv_state, next_state) - ssm_state = self.cache.update_ssm_state(dA, dBx) - - y = mx.einsum('bhds,bs->bhd', ssm_state, C) - y = y + x * self.D[None, :, None] - y = mx.reshape(y, (batch_size, self.args.intermediate_size)) - - # Output processing - y = self.norm(y, z) - y = self.out_proj(y) - - return mx.expand_dims(y, 1) - + # Apply normalization and final projection + y = self.norm(y) * gate + return self.out_proj(y) + class ResidualBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - self.mixer = Mamba2Block(args) + self.mixer = Mamba2Mixer(args) self.norm = nn.RMSNorm(args.hidden_size) - def __call__(self, x: mx.array, cache=None) -> mx.array: + def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array: return self.mixer(self.norm(x), cache) + x - class Mamba2Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -295,19 +308,20 @@ class Mamba2Model(nn.Module): def __call__(self, x: mx.array, cache=None) -> mx.array: x = self.embeddings(x) + if cache is None: cache = [None] * len(self.layers) + for layer, layer_cache in zip(self.layers, cache): x = layer(x, layer_cache) - return self.norm_f(x) + return self.norm_f(x) class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args self.backbone = Mamba2Model(args) - if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) @@ -324,17 +338,24 @@ class Model(nn.Module): return logits def make_cache(self, batch_size=1): - return [Mamba2Cache( - batch_size=batch_size, - intermediate_size=self.args.intermediate_size, - state_size=self.args.state_size, - conv_kernel=self.args.conv_kernel, - num_heads=self.args.num_heads, - head_dim=self.args.head_dim - ) for _ in range(len(self.backbone.layers))] + return [ + Mamba2Cache( + batch_size=batch_size, + conv_dim=self.args.intermediate_size + 2 * self.args.n_groups * self.args.state_size, + kernel_size=self.args.conv_kernel, + num_heads=self.args.num_heads, + head_dim=self.args.head_dim, + state_size=self.args.state_size + ) + for _ in range(len(self.backbone.layers)) + ] def sanitize(self, weights): for k, v in weights.items(): if "conv1d.weight" in k and v.ndim == 3: weights[k] = v.moveaxis(2, 1) - return weights \ No newline at end of file + return weights + + @property + def layers(self): + return self.backbone.layers