From 906f972d36be8a5d9fa9e244542190b5d9c7542e Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Wed, 6 Nov 2024 16:35:46 +0100
Subject: [PATCH] save push

---
 llms/mlx_lm/models/cache.py                   |  33 +-
 llms/mlx_lm/models/mamba2 copy.py             | 424 -----------
 llms/mlx_lm/models/mamba2-other.py            | 288 --------
 llms/mlx_lm/models/mamba2-prch-minimal.py     | 449 ++++++++++++
 llms/mlx_lm/models/mamba2-prch.py             | 673 ++++++++++++++++--
 llms/mlx_lm/models/mamba2-works-hella-alow.py | 337 +++++++++
 llms/mlx_lm/models/mamba2.py                  | 482 ++++++-------
 llms/mlx_lm/models/mamba22.py                 | 316 ++++++++
 llms/mlx_lm/models/mamba23.py                 | 357 ++++++++++
 llms/mlx_lm/models/mamba24.py                 | 430 +++++++++++
 10 files changed, 2777 insertions(+), 1012 deletions(-)
 delete mode 100644 llms/mlx_lm/models/mamba2 copy.py
 delete mode 100644 llms/mlx_lm/models/mamba2-other.py
 create mode 100644 llms/mlx_lm/models/mamba2-prch-minimal.py
 create mode 100644 llms/mlx_lm/models/mamba2-works-hella-alow.py
 create mode 100644 llms/mlx_lm/models/mamba22.py
 create mode 100644 llms/mlx_lm/models/mamba23.py
 create mode 100644 llms/mlx_lm/models/mamba24.py

diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index 2764ce44..b66ede89 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -341,13 +341,28 @@ class MambaCache(_BaseCache):
         self.cache = v
 
 
-class Mamba2Cache:
-    def __init__(self, batch_size, conv_dim, kernel_size, num_heads, head_dim, state_size):
-        self.conv_states = mx.zeros((batch_size, conv_dim, kernel_size - 1))
-        self.ssm_states = mx.zeros((batch_size, num_heads, head_dim, state_size))
-        self.seqlen_offset = 0
-    def update(self, new_conv_state, new_ssm_state):
-        self.conv_states = new_conv_state
-        self.ssm_states = new_ssm_state
-        self.seqlen_offset += 1
\ No newline at end of file
+class Mamba2Cache(_BaseCache):
+    def __init__(
+        self,
+        batch_size,
+        conv_kernel
+    ):
+        self.conv_kernel: mx.array = conv_kernel
+        self.conv_states: mx.array = [None]
+        self.ssm_states = [None]
+        self.seqlen_offset = 0
+
+    def reset(self):
+        self.conv_states = None
+        self.ssm_state = None
+
+    def update(self, layer_idx: int, new_conv_state: mx.array, cache_position: mx.array) -> mx.array:
+        conv_state = self.conv_states[layer_idx]
+        cache_position = cache_position.clamp(0, self.conv_kernel - 1)
+
+        conv_state = conv_state.roll(shifts=-1, dims=-1)
+        conv_state[:, :, cache_position] = new_conv_state
+        self.conv_states[layer_idx].zero_()
+        self.conv_states[layer_idx] += conv_state
+        return self.conv_states[layer_idx]
\ No newline at end of file
diff --git a/llms/mlx_lm/models/mamba2 copy.py b/llms/mlx_lm/models/mamba2 copy.py
deleted file mode 100644
index fc3f23d8..00000000
--- a/llms/mlx_lm/models/mamba2 copy.py
+++ /dev/null
@@ -1,424 +0,0 @@
-import math
-from dataclasses import dataclass, field
-from typing import Optional, Tuple, Union
-
-import mlx.core as mx
-import mlx.nn as nn
-
-from .base import BaseModelArgs
-from .cache import Mamba2Cache
-
-@dataclass
-class ModelArgs(BaseModelArgs):
-    num_heads: int
-    head_dim: int
-    vocab_size: int
-    hidden_size: int
-    state_size: int
-    num_hidden_layers: int
-    layer_norm_epsilon: float
-    expand: int
-    conv_kernel: int
-    n_groups: int
-    use_bias: bool
-    use_conv_bias: bool
-    initializer_range: float
-    residual_in_fp32: bool
-    time_step_min: float
-    time_step_max: float
-    time_step_floor: float
-    rescale_prenorm_residual: bool
-    use_cache: bool
-    rms_norm: bool
-    chunk_size: int
-    tie_word_embeddings: bool
-    intermediate_size: int = None
-    time_step_limit: Tuple[float,
float] = field(default_factory=lambda: (0.0, float("inf")))
-    time_step_rank: Union[int, str] = "auto"
-    model_type: str = "mamba2"
-
-    def __post_init__(self):
-        self.intermediate_size = int(self.expand * self.hidden_size) # E*D = ED
-
-        if not hasattr(self, "head_dim"):
-            self.head_dim = self.hidden_size // self.num_heads
-        if self.time_step_rank == "auto":
-            self.time_step_rank = math.ceil(self.hidden_size / 16)
-
-
-def selective_scan(x, A, B, C, chunk_size):
-    """
-    Selective scan implementation for training.
-
-    Arguments
-        x: (batch, seqlen, n_heads, d_head)
-        A: (batch, seqlen, n_heads)
-        B: (batch, seqlen, n_heads, d_state)
-        C: (batch, seqlen, n_heads, d_state)
-
-    Return
-        y: (batch, seqlen, n_heads, d_head)
-    """
-    assert x.shape[1] % chunk_size == 0
-
-    # Reshape into chunks
-    def chunk_reshape(m):
-        shape = list(m.shape)
-        shape[1:2] = [shape[1] // chunk_size, chunk_size]
-        return m.reshape(shape)
-
-    x, A, B, C = map(chunk_reshape, (x, A, B, C))
-    A = mx.transpose(A, [0, 3, 1, 2])
-
-    # Compute cumulative sums
-    A_cumsum = mx.cumsum(A, axis=-1)
-
-    # Process chunks
-    L = mx.exp(selective_cumsum(A))
-    Y_diag = mx.einsum('bclhn,bcshn,bhcls,bcshp->bclhp', C, B, L, x)
-
-    decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum)
-    states = mx.einsum('bclhn,bhcl,bclhp->bchpn', B, decay_states, x)
-
-    initial_states = mx.zeros_like(states[:, :1])
-    states = mx.concatenate([initial_states, states], axis=1)
-    decay_chunk = mx.exp(selective_cumsum(mx.pad(A_cumsum[..., -1], ((0,0), (0,0), (1,0)))))
-    new_states = mx.einsum('bhzc,bchpn->bzhpn', decay_chunk, states)
-    states = new_states[:, :-1]
-
-    state_decay_out = mx.exp(A_cumsum)
-    Y_off = mx.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
-
-    Y = (Y_diag + Y_off).reshape((-1, x.shape[1] * chunk_size, *Y_diag.shape[-2:]))
-    return Y
-
-def selective_cumsum(x: mx.array) -> mx.array:
-    """Stable selective cumulative sum calculation."""
-    T = x.shape[-1]
-    x = mx.repeat(x[..., None], T, axis=-1)
-    mask = mx.tril(mx.ones((T, T)), k=-1)
-    x = x * mask
-    x_cumsum = mx.cumsum(x, axis=-2)
-    mask = mx.tril(mx.ones((T, T)), k=0)
-    return mx.where(mask, x_cumsum, float('-inf'))
-
-
-class Mamba2Block(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.args = args
-
-        # Project input to get various components [z, x, B, C, dt]
-        projection_size = (2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads)
-        self.in_proj = nn.Linear(
-            args.hidden_size,
-            projection_size,
-            bias=args.use_bias
-        )
-
-        # Convolution layer
-        conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
-        self.conv1d = nn.Conv1d(
-            in_channels=conv_dim,
-            out_channels=conv_dim,
-            kernel_size=args.conv_kernel,
-            groups=conv_dim,
-            padding=args.conv_kernel - 1,
-            bias=args.use_conv_bias
-        )
-
-        # SSM parameters
-        self.dt_bias = mx.zeros(args.num_heads)
-        self.A_log = mx.zeros(args.num_heads)
-        self.D = mx.ones(args.num_heads)
-
-        # Output projections
-        self.norm = nn.RMSNorm(args.intermediate_size, eps=args.layer_norm_epsilon)
-        self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
-
-    def __call__(self, u: mx.array, cache=None) -> mx.array:
-        # return self.forward_training(x) if x.shape[1] > 1 else self.forward_inference(x, cache)
-
-        # def forward_training(self, u: mx.array) -> mx.array:
-        #     # Reset cache during training
-        #     self.cache = None
-
-        #     # Input projection and splitting
-        #     zxbcdt = self.in_proj(u)
-        #     z, xBC, dt = mx.split(
-        #         zxbcdt,
-        #         [
-        #
self.args.hidden_size, - # self.args.hidden_size + 2 * self.args.state_size - # ], - # axis=-1 - # ) - - # # Time step processing - # dt = mx.clip( - # nn.softplus(dt + self.dt_bias), - # self.args.time_step_min, - # self.args.time_step_max - # ) - - # # Convolution processing - # xBC_t = mx.transpose(xBC, [0, 2, 1]) - # conv_out = self.conv1d(xBC_t) - # xBC = mx.transpose(conv_out, [0, 2, 1])[:, :u.shape[1]] - # xBC = mx.sigmoid(xBC) * xBC # SiLU - - # # Split states - # x, B, C = mx.split( - # xBC, - # [self.args.hidden_size, self.args.state_size], - # axis=-1 - # ) - - # # Reshape for selective scan - # x = x.reshape((-1, x.shape[1], self.args.num_heads, self.args.head_dim)) - # A = -mx.exp(self.A_log) - - # # Apply selective scan - # y = selective_scan( - # x * dt[..., None], - # A * dt, - # B[..., None, :], - # C[..., None, :], - # self.args.chunk_size - # ) - - # # Output processing - # y = y + x * self.D[None, None, :, None] - # y = y.reshape((-1, y.shape[1], self.args.hidden_size)) - # y = self.norm(y, z) - # y = self.out_proj(y) - - # return y - - # def forward_inference(self, u: mx.array, cache=None) -> mx.array: - # """ - # u: (B, 1, D) - # cache: (h_cache, conv_cache) - # """ - # """Single token processing during inference.""" - # assert u.shape[1] == 1, "Inference mode expects single token" - - # batch_size = u.shape[0] - # # Use provided cache or create new one - # self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None) - - # # Project input - # zxbcdt = self.in_proj(u.squeeze(1)) # (B, 2D) - # d_mlp = (zxbcdt.shape[-1] - 2 * self.args.hidden_size - 2 * self.args.n_groups * self.args.state_size - self.args.num_heads) // 2 - - # # (1, 768) (1, 0) (1, 0) (1, 256) (1, 0) (1, 3328) - # y0, z0, x0, z, xBC, dt = mx.split( - # zxbcdt, - # [ - # d_mlp, - # d_mlp, - # self.args.hidden_size, - # self.args.hidden_size + 2 * self.args.n_groups * self.args.state_size, - # self.args.num_heads - # ], - # axis=-1 - # ) - - # # Update convolution state and apply - # conv_state = self.cache.update_conv_state(xBC) - # xBC = mx.sum(conv_state[:, :, -1] * mx.transpose(self.conv1d.weight, [1, 0, 2]), axis=-1) # (B, D) (4, 1792) - - # if self.args.use_conv_bias: - # xBC = xBC + self.conv1d.bias - - # xBC = mx.sigmoid(xBC) * xBC # SiLU (4, 1792) - - # # Split states and ensure proper shapes - # a0, x, B, C = mx.split( - # xBC, # (4, 1792) - # [ - # self.args.hidden_size, - # self.args.n_groups * self.args.state_size, - # self.args.n_groups * self.args.state_size - # ], - # axis=-1 - # ) - - # # SSM step with explicit shapes - # A = -mx.exp(self.A_log) # (num_heads) (24,) - # print(A.shape) # (24,) - # print(dt.shape) # (1, 3328) - # dA = mx.exp(dt * A[None, :]) # Shape: (batch_size, num_heads) <------- her eis the error - - # # Reshape x considering intermediate size - # # x shape should be (batch_size * num_heads, head_dim) - # x = mx.reshape(x, (batch_size, self.args.num_heads, -1)) - # assert x.shape[-1] == self.args.head_dim, f"Head dimension mismatch: {x.shape[-1]} vs {self.args.head_dim}" - - # B = mx.reshape(B, (batch_size, -1)) # Should be (batch_size, state_size) - # C = mx.reshape(C, (batch_size, -1)) # Should be (batch_size, state_size) - - # # Compute dBx with explicit shapes - # dBx = mx.einsum('bh,bs,bhd->bhds', dt, B, x) - - # ssm_state = self.cache.update_ssm_state(dA, dBx) - - # y = mx.einsum('bhds,bs->bhd', ssm_state, C) - # y = y + x * self.D[None, :, None] - # y = mx.reshape(y, (batch_size, self.args.hidden_size)) - - # # Output processing 
- # y = self.norm(y, z) - - # if d_mlp > 0: - # y = mx.cat([nn.silu(z0) * x0, y], axis=-1) - - # y = self.out_proj(y) - - # return mx.expand_dims(y, 1) - - assert u.shape[1] == 1, "Inference mode expects single token" - - batch_size = u.shape[0] - # Use provided cache or create new one - self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None) - - # Project input - zxbcdt = self.in_proj(u.squeeze(1)) # (B, projection_size) - - # Calculate splits based on model dimensions - d_mlp = self.args.intermediate_size - d_state = self.args.state_size * self.args.n_groups - - # Split the projection into its components - splits = [ - d_mlp, # y0 - d_mlp, # z0 - self.args.hidden_size, # x0 - self.args.hidden_size, # z - d_state * 2, # xBC (includes both B and C) - self.args.num_heads # dt - ] - - y0, z0, x0, z, xBC, dt = mx.split(zxbcdt, splits[:-1], axis=-1) - - # Update convolution state and apply - conv_state = self.cache.update_conv_state(xBC) - xBC = mx.sum(conv_state[:, :, -1] * mx.transpose(self.conv1d.weight, [1, 0, 2]), axis=-1) - - if self.args.use_conv_bias: - xBC = xBC + self.conv1d.bias - - xBC = mx.sigmoid(xBC) * xBC # SiLU - - # Split states and reshape - x, BC = mx.split(xBC, [self.args.intermediate_size], axis=-1) - B, C = mx.split(BC, [d_state], axis=-1) - - # Reshape for SSM computation - x = mx.reshape(x, (batch_size, self.args.num_heads, -1)) # (B, H, head_dim) - B = mx.reshape(B, (batch_size, self.args.num_heads, -1)) # (B, H, state_per_head) - C = mx.reshape(C, (batch_size, self.args.num_heads, -1)) # (B, H, state_per_head) - - # Process dt to match expected shape - dt = mx.reshape(dt, (batch_size, self.args.num_heads)) # (B, H) - dt = mx.clip( - nn.softplus(dt + self.dt_bias), - self.args.time_step_min, - self.args.time_step_max - ) - - # SSM step - A = -mx.exp(self.A_log) # (H,) - dA = mx.exp(dt * A[None, :]) # (B, H) - - # Compute dBx - dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, x) - - # Update SSM state and compute output - ssm_state = self.cache.update_ssm_state(dA, dBx) - y = mx.einsum('bhds,bhs->bhd', ssm_state, C) - y = y + x * self.D[None, :, None] - - # Reshape output - y = mx.reshape(y, (batch_size, self.args.hidden_size)) - - # Final output processing - y = self.norm(y, z) - - if d_mlp > 0: - y = mx.concat([nn.silu(z0) * x0, y], axis=-1) - - y = self.out_proj(y) - - return mx.expand_dims(y, 1) # (B, 1, D) - - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = Mamba2Block(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, x: mx.array, cache=None) -> mx.array: - # x : (B, L, D) - return self.mixer(self.norm(x), cache) + x # (B, L, D) - - -class Mamba2Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - - def __call__(self, x: mx.array, cache=None) -> mx.array: - # x : (B, L) - x = self.embeddings(x) - # x : (B, L, D) - if cache is None: - cache = [None] * len(self.layers) - - for layer, layer_cache in zip(self.layers, cache): - x = layer(x, layer_cache) - return self.norm_f(x) - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.backbone = Mamba2Model(args) - - if not args.tie_word_embeddings: - self.lm_head = 
nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__(self, inputs: mx.array, cache=None) -> mx.array: - # inputs : (B, L) - B, T = inputs.shape - - x = self.backbone(inputs, cache) - - if self.args.tie_word_embeddings: - logits = self.backbone.embeddings.as_linear(x) - else: - logits = self.lm_head(x) - - return logits - - def make_cache(self, batch_size=1): - return [Mamba2Cache( - batch_size=batch_size, - hidden_size=self.args.hidden_size, - state_size=self.args.state_size, - conv_kernel=self.args.conv_kernel, - num_heads=self.args.num_heads, - head_dim=self.args.head_dim - ) for _ in range(len(self.backbone.layers))] - - def sanitize(self, weights): - for k, v in weights.items(): - if "conv1d.weight" in k and v.ndim == 3: - weights[k] = v.moveaxis(2, 1) - return weights \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2-other.py b/llms/mlx_lm/models/mamba2-other.py deleted file mode 100644 index 22064021..00000000 --- a/llms/mlx_lm/models/mamba2-other.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright © 2024 Apple Inc. - -import math -from dataclasses import dataclass, field -from typing import Tuple, Union - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str = "mamba2" - num_heads: int = 128 - head_dim: int = 64 - vocab_size: int = 32768 - hidden_size: int = 4096 - state_size: int = 128 - num_hidden_layers: int = 64 - layer_norm_epsilon: float = 1e-5 - expand: int = 2 - conv_kernel: int = 4 - n_groups: int = 8 - use_bias: bool = False - use_conv_bias: bool = True - initializer_range: float = 0.1 - residual_in_fp32: bool = True - time_step_rank: Union[int, str] = "auto" - time_step_min: float = 0.001 - time_step_max: float = 0.1 - time_step_floor: float = 1e-4 - time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) - rescale_prenorm_residual: bool = False - use_cache: bool = True - rms_norm: bool = True - chunk_size: int = 256 - tie_word_embeddings: bool = False - - def __post_init__(self): - if not hasattr(self, "intermediate_size"): - self.intermediate_size = int(self.expand * self.hidden_size) - if not hasattr(self, "head_dim"): - self.head_dim = self.hidden_size // self.num_heads - if self.time_step_rank == "auto": - self.time_step_rank = math.ceil(self.hidden_size / 16) - - -class Mamba2Cache: - def __init__(self): - self.cache = [None, None] - - def __setitem__(self, idx, value): - self.cache[idx] = value - - def __getitem__(self, idx): - return self.cache[idx] - - @property - def state(self): - return self.cache - - -class MambaRMSNormGated(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = mx.ones((hidden_size,)) - self.variance_epsilon = eps - - def __call__(self, hidden_states, gate=None): - if gate is not None: - hidden_states = hidden_states * nn.silu(gate) - variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True) - hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states - -class DepthWiseConv1d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.padding = padding - self.groups = groups if groups is not None else in_channels - - # Ensure in_channels and out_channels are the same for depthwise conv - assert in_channels == 
out_channels, "In and out channels must be the same for depthwise convolution" - # Ensure groups is equal to in_channels for depthwise conv - assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution" - - # Initialize weight with shape (out_channels, kernel_size, 1) - self.weight = mx.random.normal((out_channels, kernel_size, 1)) - self.bias = mx.zeros((out_channels,)) if bias else None - - def __call__(self, x, cache=None): - B, L, C = x.shape - _, K, _ = self.weight.shape - - if cache is not None: - x = mx.concatenate([cache, x], axis=1) - else: - x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]) - - y = mx.conv_general(x, self.weight, groups=self.groups) - - if self.bias is not None: - y = y + self.bias - - return y, x[:, -K + 1 :, :] - - -class Mamba2Mixer(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.intermediate_size = args.intermediate_size - self.time_step_rank = args.time_step_rank - self.conv_kernel_size = args.conv_kernel - self.hidden_size = args.hidden_size - self.state_size = args.state_size - self.num_heads = args.num_heads - self.head_dim = args.hidden_size // args.num_heads - self.n_groups = args.n_groups - - self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size - self.conv1d = DepthWiseConv1d( - in_channels=self.conv_dim, - out_channels=self.conv_dim, - bias=args.use_conv_bias, - kernel_size=args.conv_kernel, - groups=self.conv_dim, - padding=args.conv_kernel - 1 - ) - - projection_size = self.intermediate_size + self.conv_dim + self.num_heads - self.in_proj = nn.Linear( - self.hidden_size, - projection_size, - bias=args.use_bias - ) - - self.dt_bias = mx.ones((self.num_heads,)) - self.A_log = mx.log(mx.arange(1, self.num_heads + 1)) - self.D = mx.ones((self.num_heads,)) - - self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon) - - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - - def ssm_step(self, x, state, dt_proj): - A = -mx.exp(self.A_log) - D = self.D - delta = nn.softplus(dt_proj + self.dt_bias) - - B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1) - - B = B.reshape(-1, self.n_groups, self.state_size) - C = C.reshape(-1, self.n_groups, self.state_size) - - if state is None: - new_state = mx.expand_dims(delta, -1) * B - else: - new_state = mx.expand_dims(delta, -1) * (B + state * mx.exp(mx.expand_dims(delta, -1) * A)) - - y = mx.sum(new_state * C, axis=-1) - y = y + D * x[:, :self.num_heads] - return y, new_state - - def __call__(self, x, cache): - B, T, D = x.shape - if cache is None: - cache = [None, None] - - outputs = [] - for t in range(T): - xt = x[:, t, :] - xz = self.in_proj(xt) - - x_t, z_t, dt_proj = mx.split( - xz, - indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size], - axis=-1 - ) - - conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0]) - x_t = conv_out.squeeze(1) - x_t = nn.silu(x_t) - y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj) - z_t = nn.silu(z_t) - - # Print shapes for debugging - print(f"y_t shape: {y_t.shape}") - print(f"z_t shape: {z_t.shape}") - - # Reshape y_t to (B, num_heads, head_dim) - y_t_reshaped = y_t.reshape(B, self.num_heads, -1) - - # Reshape z_t to (B, num_heads, intermediate_size // num_heads) - z_t_reshaped = z_t.reshape(B, self.num_heads, -1) - - print(f"y_t_reshaped shape: {y_t_reshaped.shape}") - print(f"z_t_reshaped shape: {z_t_reshaped.shape}") - - # Element-wise 
multiplication (broadcasting across the last dimension) - output_t = y_t_reshaped * z_t_reshaped - - # Reshape to match the expected input of out_proj - output_t = output_t.reshape(B, -1) - - print(f"output_t shape before out_proj: {output_t.shape}") - print(f"out_proj weight shape: {self.out_proj.weight.shape}") - - output_t = self.out_proj(output_t) - outputs.append(output_t) - - output = mx.stack(outputs, axis=1) - return output - - -class Mamba2Block(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = Mamba2Mixer(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, x: mx.array, cache): - return self.mixer(self.norm(x), cache) + x - - -class Mamba2(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [Mamba2Block(args) for idx in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - - def __call__( - self, - inputs: mx.array, - cache=None - ): - hidden_states = self.embeddings(inputs) - - if cache is None: - cache = Mamba2Cache(len(self.layers)) - - for i, layer in enumerate(self.layers): - hidden_states = layer(hidden_states, cache[i]) - - hidden_states = self.norm_f(hidden_states) - return hidden_states - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = Mamba2(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__(self, inputs: mx.array, cache=None): - B, T = inputs.shape - - x = self.backbone(inputs, cache) - - if self.args.tie_word_embeddings: - logits = self.backbone.embeddings.as_linear(x) - else: - logits = self.lm_head(x) - - return logits - - def sanitize(self, weights): - for k, v in weights.items(): - if "conv1d.weight" in k and v.ndim == 3: - weights[k] = v.moveaxis(2, 1) - return weights - - def make_cache(self, batch_size: int = 1): - return [Mamba2Cache() for _ in range(len(self.layers))] - - @property - def layers(self): - return self.backbone.layers diff --git a/llms/mlx_lm/models/mamba2-prch-minimal.py b/llms/mlx_lm/models/mamba2-prch-minimal.py new file mode 100644 index 00000000..52d27f00 --- /dev/null +++ b/llms/mlx_lm/models/mamba2-prch-minimal.py @@ -0,0 +1,449 @@ +# coding=utf-8 +# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch MAMBA2 model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +logger = logging.get_logger(__name__) + + +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. 
apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + + +class Mamba2Cache: + """ + Arguments: + config: ModelArgs + batch_size: int + dtype: torch.dtype + device: torch.device + + Attributes: + seqlen_offset: int + dtype: torch.dtype + conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel] + ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size] + """ + + def __init__( + self, config: ModelArgs, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None + ): + self.seqlen_offset = 0 + self.dtype = dtype + self.conv_kernel = config.conv_kernel + self.intermediate_size = int(config.expand * config.hidden_size) + + self.conv_states = { + i: torch.zeros( + batch_size, + self.intermediate_size + 2 * config.n_groups * config.state_size, + self.conv_kernel, + device=device, + dtype=dtype, + ) + for i in range(config.num_hidden_layers) + } + self.ssm_states = { + i: torch.zeros( + batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype + ) + for i in range(config.num_hidden_layers) + } + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +class MambaRMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states + + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + return self.weight * hidden_states + + +class Mamba2Mixer(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.num_heads = config.num_heads + self.hidden_size = config.hidden_size + self.state_size = config.state_size + self.conv_kernel = config.conv_kernel + self.intermediate_size = int(config.expand * self.hidden_size) + self.time_step_rank = int(config.time_step_rank) + self.use_conv_bias = config.use_conv_bias + + self.layer_norm_epsilon = config.layer_norm_epsilon + + self.n_groups = config.n_groups + self.head_dim = config.head_dim + self.chunk_size = config.chunk_size + + self.time_step_limit = config.time_step_limit + self.time_step_min = config.time_step_min + self.time_step_max = config.time_step_max + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, + groups=self.conv_dim, + padding=config.conv_kernel - 1, + ) + + # projection of 
the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=config.use_bias, + ) + + self.dt_bias = torch.ones(self.num_heads) + A = torch.arange(1, self.num_heads + 1) + self.A_log = torch.log(A) + self.D = torch.ones(self.num_heads) + + self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + + def forward(self, input_states, cache: Optional[Mamba2Cache]=None): + batch_size, seq_len, _ = input_states.shape + # Gated MLP's linear projection + projected_states = self.in_proj(input_states.squeeze(1)) + d_mlp = ( + projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.state_size- self.num_heads) // 2 + _, _, gate, hidden_states, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # Convolution sequence transformation + ssm_state = cache.ssm_states[self.layer_idx].clone() + ssm_state = ssm_state.to(hidden_states.device) + + if cache.seqlen_offset > 0: + conv_state = cache.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel] + conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + + # handle batched generation - states are copied through + conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states + cache.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = nn.silu(hidden_states)[:, None, ...] # [batch, 1, intermediate_size] : decoding + else: + hidden_states = hidden_states.transpose(1,2) + conv_state = nn.functional.pad( + hidden_states, + (self.conv_kernel - hidden_states.shape[-1], 0) + ) + cache.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = nn.silu(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] + + hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.state_size, self.n_groups * self.state_size], dim=-1) + A = -torch.exp(self.A_log.float()) # [num_heads] + + if cache is not None and cache.seqlen_offset > 0: + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] 
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = torch.exp(dt[..., None] * A) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = dB * hidden_states[..., None] + + # State calculation + cache.ssm_states[self.layer_idx].copy_( + cache.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] + else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_min) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = self.chunk_size - (seq_len % self.chunk_size) + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. 
Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # First, contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + + # Step 2: Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Step 3: Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3) + + # (right term of low-rank factorization of off-diagonal blocks; B terms) + + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] + # permute back B * decay states + states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) + if cache is not None and cache.seqlen_offset > 0: + previous_states = cache.ssm_states[self.layer_idx][:, None, ...] + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + + states_permuted = states.permute(0, 2, 1, 3, 4) + result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2) + new_states = result.permute(0, 2, 1, 3, 4) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + # compute Yoff + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + if ssm_state is not None and cache is not None: + cache.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = self.norm(y, gate) + # end ssd naive + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_output) # [batch, seq_len, hidden_size] + return contextualized_states + + +class Mamba2Block(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.mixer = Mamba2Mixer(config) + + def forward( + self, + hidden_states, + cache: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + ): + x = self.mixer( + self.norm(hidden_states), cache=cache, cache_position=cache_position + ) + return x + hidden_states + + +class Mamba2Model(nn.Module): + def __init__(self, config): + super().__init__(config) + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + cache: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + ): + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + for mixer_block in self.layers: + hidden_states = mixer_block( + hidden_states, + cache=cache, + cache_position=cache_position, + ) + + cache.seqlen_offset += inputs_embeds.shape[1] + return self.norm_f(hidden_states), cache + + + +class Mamba2ForCausalLM(nn.Module): + def __init__(self, config): + super().__init__(config) + self.backbone = Mamba2Model(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + cache: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.Tensor] = None, + ): + out, cache = self.backbone( + input_ids, + cache=cache, + cache_position=cache_position, + ) + logits = self.lm_head(out) + return logits, cache \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2-prch.py b/llms/mlx_lm/models/mamba2-prch.py index 84bf2174..69390ea9 100644 --- a/llms/mlx_lm/models/mamba2-prch.py +++ b/llms/mlx_lm/models/mamba2-prch.py @@ -23,9 +23,42 @@ import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss +from ...activations import ACT2FN +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available +from .configuration_mamba2 import Mamba2Config + + logger = logging.get_logger(__name__) +if is_mamba_2_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +else: + selective_state_update = None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + +_CHECKPOINT_FOR_DOC = "mistralai/mamba-codestral-7B-v0.1" +_CONFIG_FOR_DOC = "Mamba2Config" + + +# Helper methods for segment sum computation + + def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): """ Padding x tensor with `pad_size` on the seq_len dim 
(dim=1) @@ -80,7 +113,7 @@ def segment_sum(input_tensor): class Mamba2Cache: """ Arguments: - config: ModelArgs + config: Mamba2Config batch_size: int dtype: torch.dtype device: torch.device @@ -93,7 +126,7 @@ class Mamba2Cache: """ def __init__( - self, config: ModelArgs, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None + self, config: Mamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None ): self.seqlen_offset = 0 self.dtype = dtype @@ -116,6 +149,8 @@ class Mamba2Cache: ) for i in range(config.num_hidden_layers) } + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] def update_conv_state( self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor @@ -142,18 +177,25 @@ class MambaRMSNormGated(torch.nn.Module): def forward(self, hidden_states, gate=None): input_dtype = hidden_states.dtype - hidden_states = hidden_states + hidden_states = hidden_states.to(torch.float32) if gate is not None: - hidden_states = hidden_states * nn.functional.silu(gate) + hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states + return self.weight * hidden_states.to(input_dtype) class Mamba2Mixer(nn.Module): - def __init__(self, config: ModelArgs): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: Mamba2Config, layer_idx: int): super().__init__() self.num_heads = config.num_heads self.hidden_size = config.hidden_size @@ -161,8 +203,10 @@ class Mamba2Mixer(nn.Module): self.conv_kernel_size = config.conv_kernel self.intermediate_size = int(config.expand * self.hidden_size) self.time_step_rank = int(config.time_step_rank) + self.layer_idx = layer_idx self.use_conv_bias = config.use_conv_bias - self.act = nn.silu + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] self.layer_norm_epsilon = config.layer_norm_epsilon self.rms_norm = config.rms_norm @@ -192,23 +236,178 @@ class Mamba2Mixer(nn.Module): projection_size, bias=config.use_bias, ) + # selective projection used to make dt, B and C input dependant - self.dt_bias = torch.ones(self.num_heads) + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. 
Keeps the memory bounded A = torch.arange(1, self.num_heads + 1) - self.A_log = torch.log(A) - self.D = torch.ones(self.num_heads) - + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True - def forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None): + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.use_bias = config.use_bias + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # set up dimensions for reshapes later + + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads + + # getting projected states from cache if it exists + if cache_params is not None and cache_params.seqlen_offset > 0: + in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D) + d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 + split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] + _, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1) + + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + A = -torch.exp(self.A_log.float()) # (nheads,) + + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + hidden_states = self.norm(hidden_states, gate) + out = self.out_proj(hidden_states)[:, None, ...] 
+ # if no cache is found, calling the kernel + else: + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states) + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + if self.training and cache_params is None: + out, ssm_state = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=True, + **dt_limit_kwargs, + ) + + else: + gate, hidden_states_B_C, time_step = torch.split( + projected_states, + [self.intermediate_size, self.conv_dim, self.num_heads], + dim=-1, + ) + + time_step = nn.functional.softplus(time_step + self.dt_bias) + # 1D Convolution + if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] + ) # (B, L, self.d_inner + 2 * ngroups * d_state) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2)[:, :seq_len] + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + time_step, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + **dt_limit_kwargs, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + scan_output = scan_output.view(batch_size, seq_len, -1) + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + out = self.out_proj(scan_output) + return out + + # fmt: off + def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype - # Gated MLP's linear projection projected_states = self.in_proj(input_states.squeeze(1)) - d_mlp = ( - projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2 + d_mlp = (projected_states.shape[-1] - 2 * 
self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2 _, _, gate, hidden_states, dt = projected_states.split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 ) @@ -223,10 +422,10 @@ class Mamba2Mixer(nn.Module): # handle batched generation - states are copied through conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states cache_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: hidden_states += self.conv1d.bias - hidden_states = self.act(hidden_states)[:, None, ...] # [batch, 1, intermediate_size] : decoding + hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding else: hidden_states = hidden_states.transpose(1,2) conv_state = nn.functional.pad( @@ -235,16 +434,18 @@ class Mamba2Mixer(nn.Module): ) cache_params.conv_states[self.layer_idx].copy_(conv_state) hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) else: ssm_state = torch.zeros( (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), - device=hidden_states.device + device=hidden_states.device, dtype=dtype ) hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) - hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and cache_params.seqlen_offset > 0: # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation @@ -384,8 +585,25 @@ class Mamba2Mixer(nn.Module): # end ssd naive # 4. 
Final linear projection - contextualized_states = self.out_proj(scan_output) # [batch, seq_len, hidden_size] + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) class Mamba2RMSNorm(nn.Module): @@ -399,47 +617,258 @@ class Mamba2RMSNorm(nn.Module): def forward(self, hidden_states): input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states + return self.weight * hidden_states.to(input_dtype) class Mamba2Block(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx): super().__init__() self.config = config + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.mixer = Mamba2Mixer(config) + self.mixer = Mamba2Mixer(config, layer_idx=layer_idx) def forward( self, hidden_states, cache_params: Optional[Mamba2Cache] = None, cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, ): - x = self.mixer( - self.norm(hidden_states), cache_params=cache_params, cache_position=cache_position + residual = hidden_states + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + hidden_states = self.mixer( + hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask ) - return x + hidden_states + hidden_states = residual + hidden_states + return hidden_states -class Mamba2Model(nn.Module): +class Mamba2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = Mamba2Config + base_model_prefix = "backbone" + _no_split_modules = ["Mamba2Block"] + supports_gradient_checkpointing = True + _is_stateful = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, Mamba2Mixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt = torch.exp( + torch.rand(self.config.num_heads) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + module.dt_bias.copy_(inv_dt) + module.dt_bias._no_reinit = True + + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.num_hidden_layers) + + +@dataclass +# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2 +class Mamba2Output(ModelOutput): + """ + Class for the MAMBA2 model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`Mamba2Cache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[Mamba2Cache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->Mamba2 +class Mamba2CausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`Mamba2Cache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[Mamba2Cache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +MAMBA2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Mamba2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +MAMBA2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + Indices of input sequence tokens in the vocabulary. + + If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. 
+ cache_params (`Mamba2Cache`, *optional*): + If passed along, the model uses the previous state in all the blocks (which will give the output for the + `input_ids` provided as if the model add `state_input_ids + input_ids` as context). + use_cache (`bool`, *optional*): + If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare MAMBA2 Model transformer outputting raw hidden-states without any specific head on top.", + MAMBA2_START_DOCSTRING, +) +class Mamba2Model(Mamba2PreTrainedModel): def __init__(self, config): super().__init__(config) + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + # Initialize weights and apply final processing + self._register_load_state_dict_pre_hook(self.load_hook) + self.post_init() + def load_hook(self, state_dict, prefix, *args): + for k in state_dict: + if "embedding." in k: + state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) + break + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + @add_start_docstrings_to_model_forward(MAMBA2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Mamba2Output, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, cache_params: Optional[Mamba2Cache] = None, use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ): - inputs_embeds = self.embeddings(input_ids) + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, Mamba2Output]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False if use_cache: if cache_params is None: @@ -447,44 +876,206 @@ class Mamba2Model(nn.Module): self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype ) cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device) + elif cache_position is None: + # cases when we do manual forward instead of using `model.generate` which will initiate + # `cache_position` and makes sure it is not None, throw error here instead of doing some + # hack to conjecture the current 
cache position + raise ValueError( + "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, " + "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will " + "be initialized for you automatically" + ) else: cache_params = None hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None for mixer_block in self.layers: - hidden_states = mixer_block( - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - ) + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask + ) + else: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) if use_cache: cache_params.seqlen_offset += inputs_embeds.shape[1] - return self.norm_f(hidden_states), cache_params if use_cache else None + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return Mamba2Output( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + ) +@add_start_docstrings( + """ + The MAMBA2 Model transformer with a language modeling head on top (linear layer with weights not tied to the input + embeddings). + """, + MAMBA2_START_DOCSTRING, +) +class Mamba2ForCausalLM(Mamba2PreTrainedModel): + _tied_weights_keys = [] -class Mamba2ForCausalLM(nn.Module): def __init__(self, config): super().__init__(config) self.backbone = Mamba2Model(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Initialize weights and apply final processing + self.post_init() + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def prepare_inputs_for_generation( + self, + input_ids, + inputs_embeds=None, + use_cache=None, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + if inputs_embeds is not None: + past_len = inputs_embeds.shape[1] + input_ids.shape[1] + else: + past_len = input_ids.shape[1] + if use_cache: + # `cache_position` should have been initialized in `generate` + if cache_position is None: + raise ValueError( + "`cache_position` should not be None as it should have been initialized in " + "`model.generate`, you are responsible for passing in a valid `cache_position` if " + "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" + ) + # how do we detect that we are in decoding without cache? 
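            # Note (hedged, editorial): the check below appears to rely on the HF convention
            # that `generate` fills `cache_position` over the whole prompt at prefill and then
            # advances it by one per decoded token, so a non-zero first entry signals
            # single-token decoding, where only the newest token (and its mask column) is fed.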
+ if cache_position[0] > 0: + input_ids = input_ids[:, -1][..., None] + attention_mask = attention_mask[:, -1][..., None] + else: + # we initialize the `cache_position` to full size of `conv_states` at prefill stage + # considering padding will be applied when input length is shorter, and truncation + # will be applied when it is longer, so it will be equivalent to always have it match + # the length of `cache_params.conv_states`, which is `config.conv_kernel` + cache_position = torch.arange(0, past_len, device=input_ids.device) + # if the cache is not used, we also do have to extend the attention mask here + # TODO there is likely a cleverer way to do this + extended_mask = torch.ones( + attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device + ) + attention_mask = torch.cat([attention_mask, extended_mask], dim=1) + cache_params = None + + if attention_mask.shape[1] < past_len: + # we have to update manually the attention mask if + # we are in decoding without cache + # and we don't have position_ids here + # TODO but we should be able to use cache_position though at a later time + extended_mask = torch.ones( + attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device + ) + attention_mask = torch.cat([attention_mask, extended_mask], dim=1) + if inputs_embeds is not None and cache_params is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "attention_mask": attention_mask, + "cache_params": cache_params, + "use_cache": use_cache, + "cache_position": cache_position, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(MAMBA2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Mamba2CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, cache_params: Optional[Mamba2Cache] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.Tensor] = None, - ): + attention_mask: Optional[torch.Tensor] = None, + **kwargs, # for now we need this for generation + ) -> Union[Tuple, Mamba2CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + mamba2_outputs = self.backbone( input_ids, cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, use_cache=use_cache, cache_position=cache_position, + attention_mask=attention_mask, ) hidden_states = mamba2_outputs[0] - logits = self.lm_head(hidden_states) - return logits, mamba2_outputs.cache_params, mamba2_outputs.hidden_states \ No newline at end of file + logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + mamba2_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Mamba2CausalLMOutput( + loss=loss, + logits=logits, + cache_params=mamba2_outputs.cache_params, + hidden_states=mamba2_outputs.hidden_states, + ) diff --git a/llms/mlx_lm/models/mamba2-works-hella-alow.py b/llms/mlx_lm/models/mamba2-works-hella-alow.py new file mode 100644 index 00000000..4468432f --- /dev/null +++ b/llms/mlx_lm/models/mamba2-works-hella-alow.py @@ -0,0 +1,337 @@ +import math +from dataclasses import dataclass, field +from typing import Tuple, Union +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs +from .cache import Mamba2Cache + +@dataclass +class ModelArgs(BaseModelArgs): + num_heads: int + head_dim: int + vocab_size: int + hidden_size: int + state_size: int + num_hidden_layers: int + layer_norm_epsilon: float + expand: int + conv_kernel: int + n_groups: int + use_bias: bool + use_conv_bias: bool + initializer_range: float + residual_in_fp32: bool + time_step_min: float + time_step_max: float + time_step_floor: float + rescale_prenorm_residual: bool + rms_norm: bool + chunk_size: int + tie_word_embeddings: bool + use_cache: bool = True + time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) + time_step_rank: Union[int, str] = "auto" + model_type: str = "mamba2" + + def __post_init__(self): + if not hasattr(self, "intermediate_size"): + self.intermediate_size = int(self.expand * self.hidden_size) + if not hasattr(self, "head_dim"): + self.head_dim = self.hidden_size // self.num_heads + if self.time_step_rank == "auto": + self.time_step_rank = math.ceil(self.hidden_size / 16) + + +class MambaRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = mx.ones((hidden_size,)) + self.variance_epsilon = eps + + def __call__(self, hidden_states, gate=None): + if gate is not None: + hidden_states = hidden_states * nn.silu(gate) + variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True) + hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states + + +def silu(x): + return x * mx.sigmoid(x) + +def ssd(x, A, B, C, chunk_size): + # Replace einsum operations with 
explicit reshape and matrix multiply + batch, seqlen, nheads, dim = x.shape + B = mx.expand_dims(B, axis=2) + C = mx.expand_dims(C, axis=2) + + state = mx.zeros((batch, nheads, dim, B.shape[-1])) + outputs = [] + + for i in range(0, seqlen, chunk_size): + chunk = slice(i, min(i + chunk_size, seqlen)) + dA = mx.exp(mx.expand_dims(A[chunk], axis=0)) + + # Replace einsum with explicit operations + x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim] + x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size] + B_chunk = B[:, chunk] # [batch, chunk_size, state_size] + dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size] + + state = state * mx.expand_dims(dA, axis=-1) + dBx + + # Replace einsum with explicit operations + C_chunk = C[:, chunk] # [batch, chunk_size, state_size] + y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size] + y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim] + outputs.append(y) + + return mx.concatenate(outputs, axis=1), state + + +class DepthWiseConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.padding = padding + self.groups = groups if groups is not None else in_channels + + assert in_channels == out_channels, "In and out channels must be same for depthwise convolution" + assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution" + + self.weight = mx.random.normal((in_channels, 1, kernel_size)) + self.bias = mx.zeros((out_channels,)) if bias else None + + def __call__(self, x: mx.array, cache=None) -> mx.array: + B, L, C = x.shape + K = self.kernel_size + + assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}" + + if cache is not None and cache.conv_states[0] is not None: + # Convert None to proper array if needed + if isinstance(cache.conv_states[0], type(None)): + cache.conv_states[0] = mx.zeros((B, K-1, C)) + x = mx.concatenate([cache.conv_states[0], x], axis=1) + + # Process each channel independently + outputs = [] + for c in range(C): + x_c = x[:, :, c] + x_c = mx.expand_dims(x_c, axis=1) + + w_c = self.weight[c] + if w_c.ndim == 2: + w_c = mx.expand_dims(w_c, axis=0) + elif w_c.ndim == 1: + w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0) + + # Apply convolution + y_c = mx.conv_general( + x_c, + w_c, + stride=1, + padding=0 + ) + + if self.bias is not None: + y_c = y_c + self.bias[c] + + outputs.append(mx.squeeze(y_c, axis=1)) + + y = mx.stack(outputs, axis=-1) + + # Update cache + if cache is not None: + cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x + + return y + + +class Mamba2Block(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads + self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias) + + conv_dim = args.intermediate_size + 2 * args.state_size + self.conv1d = DepthWiseConv1d( + in_channels=conv_dim, + out_channels=conv_dim, + kernel_size=args.conv_kernel, + groups=conv_dim, + bias=args.use_conv_bias, + padding=args.conv_kernel - 1 + ) + + self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range + self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range + self.D = 
mx.random.normal((args.num_heads,)) * args.initializer_range + + self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon) + self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) + + if args.rescale_prenorm_residual: + layer_scale = math.sqrt(1.0 / args.num_hidden_layers) + self.out_proj.weight = self.out_proj.weight * layer_scale + + def __call__(self, u: mx.array, cache=None): + batch_size = u.shape[0] + seq_len = u.shape[1] + outputs = [] + + # Initialize states if needed + if cache.conv_states[0] is None: + conv_dim = self.args.intermediate_size + 2 * self.args.state_size + cache.conv_states[0] = mx.zeros(( + batch_size, + self.args.conv_kernel - 1, + conv_dim + )) + + if cache.ssm_states[0] is None: + cache.ssm_states[0] = mx.zeros(( + batch_size, + self.args.num_heads, + self.args.head_dim, + self.args.state_size + )) + + for pos in range(seq_len): + u_t = u[:, pos:pos+1, :] + zxbcdt = self.in_proj(u_t) + + n_heads = self.args.num_heads + + z = zxbcdt[:, :, :self.args.intermediate_size] + xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size] + dt = zxbcdt[:, :, -(n_heads):] + + dt = mx.reshape(dt, (batch_size, n_heads)) + dt = mx.clip( + nn.softplus(dt + self.dt_bias), + self.args.time_step_min, + self.args.time_step_max + ) + dt = mx.maximum(dt, self.args.time_step_floor) + + xBC = self.conv1d(xBC, cache=cache) + xBC = silu(xBC) + + x = xBC[:, :, :self.args.intermediate_size] + B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size] + C = xBC[:, :, -self.args.state_size:] + + x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim)) + x = mx.squeeze(x, axis=1) + + B = mx.reshape(B, (batch_size, 1, self.args.state_size)) + B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size)) + B = mx.expand_dims(B, axis=2) + + C = mx.reshape(C, (batch_size, 1, self.args.state_size)) + C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size)) + C = mx.expand_dims(C, axis=3) + + A = -mx.exp(self.A_log) + dA = mx.exp(dt * mx.expand_dims(A, 0)) + dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) + + x = mx.expand_dims(x, axis=3) + dBx = mx.matmul(x, B) + + cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx + + y = mx.matmul(cache.ssm_states[0], C) + y = mx.squeeze(y, axis=-1) + + y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1) + + y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim)) + y = self.norm(y + z) + y = self.out_proj(y) + outputs.append(y) + + return mx.concatenate(outputs, axis=1) + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.residual_in_fp32 = args.residual_in_fp32 + + self.mixer = Mamba2Block(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, x: mx.array, cache): + if self.residual_in_fp32: + x = x.astype(mx.float32) + return self.mixer(self.norm(x), cache) + x + + +class Mamba2(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + + def __call__(self, x: mx.array, cache): + x = self.embeddings(x) + if cache is None: + cache = [None] * len(self.layers) + for layer, c in zip(self.layers, cache): + x = layer(x, c) + return self.norm_f(x) + + +class 
Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + + self.backbone = Mamba2(args) + + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None): + B, T = inputs.shape + + x = self.backbone(inputs, cache) + + if self.args.tie_word_embeddings: + logits = self.backbone.embeddings.as_linear(x) + else: + logits = self.lm_head(x) + + return logits + + def make_cache(self, batch_size=1): + return [Mamba2Cache(batch_size, self.args.conv_kernel) for _ in range(len(self.layers))] + + def sanitize(self, weights): + sanitized = {} + for k, v in weights.items(): + if "conv1d.weight" in k: + # Ensure weights are in correct shape (channels, 1, kernel_size) + if v.ndim == 2: + v = mx.expand_dims(v, axis=1) + elif v.ndim == 1: + v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0) + sanitized[k] = v + else: + sanitized[k] = v + return sanitized + + @property + def layers(self): + return self.backbone.layers \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index bd0f17ee..e305fae8 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -1,7 +1,6 @@ import math from dataclasses import dataclass, field -from typing import Optional, Tuple, Union - +from typing import Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -28,18 +27,17 @@ class ModelArgs(BaseModelArgs): time_step_max: float time_step_floor: float rescale_prenorm_residual: bool - use_cache: bool rms_norm: bool chunk_size: int tie_word_embeddings: bool - intermediate_size: int = None + use_cache: bool = True time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) time_step_rank: Union[int, str] = "auto" model_type: str = "mamba2" def __post_init__(self): - self.intermediate_size = int(self.expand * self.hidden_size) # E*D = ED - + if not hasattr(self, "intermediate_size"): + self.intermediate_size = int(self.expand * self.hidden_size) if not hasattr(self, "head_dim"): self.head_dim = self.hidden_size // self.num_heads if self.time_step_rank == "auto": @@ -49,256 +47,241 @@ class ModelArgs(BaseModelArgs): class MambaRMSNormGated(nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() - self.weight = mx.ones(hidden_size) + self.weight = mx.ones((hidden_size,)) self.variance_epsilon = eps - def forward(self, hidden_states, gate=None): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(mx.float32) - + def __call__(self, hidden_states, gate=None): if gate is not None: - hidden_states = hidden_states * nn.functional.silu(gate.to(mx.float32)) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * math.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - + hidden_states = hidden_states * nn.silu(gate) + variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True) + hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states -class Mamba2Mixer(nn.Module): + +def silu(x): + return x * mx.sigmoid(x) + +def ssd(x, A, B, C, chunk_size): + # Replace einsum operations with explicit reshape and matrix multiply + batch, seqlen, nheads, dim = x.shape + B = mx.expand_dims(B, axis=2) + C = mx.expand_dims(C, axis=2) + + state = mx.zeros((batch, nheads, dim, B.shape[-1])) + outputs = [] + + for i in 
range(0, seqlen, chunk_size): + chunk = slice(i, min(i + chunk_size, seqlen)) + dA = mx.exp(mx.expand_dims(A[chunk], axis=0)) + + # Replace einsum with explicit operations + x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim] + x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size] + B_chunk = B[:, chunk] # [batch, chunk_size, state_size] + dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size] + + state = state * mx.expand_dims(dA, axis=-1) + dBx + + # Replace einsum with explicit operations + C_chunk = C[:, chunk] # [batch, chunk_size, state_size] + y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size] + y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim] + outputs.append(y) + + return mx.concatenate(outputs, axis=1), state + + +class DepthWiseConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.padding = padding + self.groups = groups if groups is not None else in_channels + + assert in_channels == out_channels, "In and out channels must be same for depthwise convolution" + assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution" + + # Initialize weight with correct shape [C_out, 1, kernel_size] + self.weight = mx.random.normal((out_channels, 1, kernel_size)) + self.bias = mx.zeros((out_channels,)) if bias else None + + def __call__(self, x: mx.array, cache=None) -> mx.array: + B, L, C = x.shape + K = self.kernel_size + + assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}" + + # Handle caching for sequential processing + if cache is not None and cache.conv_states[0] is not None: + if isinstance(cache.conv_states[0], type(None)): + cache.conv_states[0] = mx.zeros((B, K-1, C)) + x = mx.concatenate([cache.conv_states[0], x], axis=1) + + # Process each channel independently + outputs = [] + for c in range(C): + # Extract and reshape the channel + x_c = x[:, :, c] # [B, L] + x_c = mx.expand_dims(x_c, axis=1) # [B, 1, L] + + # Get weight for this channel - already in correct shape [1, 1, K] + w_c = mx.expand_dims(self.weight[c], axis=0) # Ensure [1, 1, K] + + # Apply convolution + y_c = mx.conv_general( + x_c, + w_c, + stride=1, + padding=self.padding + ) + + if self.bias is not None: + y_c = y_c + self.bias[c] + + outputs.append(mx.squeeze(y_c, axis=1)) + + y = mx.stack(outputs, axis=-1) + + # Update cache + if cache is not None: + cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x + + return y + + +class Mamba2Block(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - # Model dimensions - self.hidden_size = args.hidden_size - self.num_heads = args.num_heads - self.head_dim = args.head_dim - self.ssm_state_size = args.state_size - self.n_groups = args.n_groups - self.intermediate_size = int(args.expand * args.hidden_size) - - # Convolution parameters - self.conv_kernel = args.conv_kernel - self.use_conv_bias = args.use_conv_bias - - # Time step parameters - self.time_step_rank = int(args.time_step_rank) - self.time_step_min = args.time_step_min - self.time_step_max = args.time_step_max - - # Processing parameters + self.args = args + self.chunk_size = args.chunk_size - self.layer_norm_epsilon = args.layer_norm_epsilon - - # Calculate dimensions - self.conv_dim = (self.intermediate_size + - 2 * 
self.n_groups * self.ssm_state_size) - projection_size = (self.intermediate_size + - self.conv_dim + - self.num_heads) - - # Initialize layers - self.in_proj = nn.Linear( - self.hidden_size, - projection_size, - bias=args.use_bias - ) - - self.conv1d = nn.Conv1d( + + d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads + self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias) + + self.conv_dim = args.intermediate_size + 2 * args.state_size + self.conv1d = DepthWiseConv1d( in_channels=self.conv_dim, out_channels=self.conv_dim, - kernel_size=self.conv_kernel, + kernel_size=args.conv_kernel, groups=self.conv_dim, - padding=self.conv_kernel - 1, - bias=self.use_conv_bias + bias=args.use_conv_bias, + padding=args.conv_kernel - 1 ) - # Initialize parameters - self.dt_bias = mx.ones(self.num_heads) - A = mx.arange(1, self.num_heads + 1) - self.A_log = mx.log(A) - self.D = mx.ones(self.num_heads) - - # Output layers - self.norm = MambaRMSNormGated( - self.intermediate_size, - eps=self.layer_norm_epsilon - ) - self.out_proj = nn.Linear( - self.intermediate_size, - self.hidden_size, - bias=args.use_bias - ) + self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range + self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range + self.D = mx.random.normal((args.num_heads,)) * args.initializer_range - def reshape_into_chunks(self, tensor, pad_size, chunk_size): - if pad_size > 0: - pad_shape = list(tensor.shape) - pad_shape[1] = pad_size - padding = mx.zeros(pad_shape, dtype=tensor.dtype) - tensor = mx.concatenate([tensor, padding], axis=1) - - chunk_shape = list(tensor.shape) - chunk_shape[1] = -1 - chunk_shape.insert(2, chunk_size) - return tensor.reshape(chunk_shape) + self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon) + self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) - def segment_sum(self, x): - return mx.cumsum(x, axis=-1) + if args.rescale_prenorm_residual: + layer_scale = math.sqrt(1.0 / args.num_hidden_layers) + self.out_proj.weight = self.out_proj.weight * layer_scale - def process_single_token(self, hidden_states, B, C, dt, cache): - batch_size = hidden_states.shape[0] - - # Process convolution state - if cache is not None: - conv_state = cache.conv_states - # Roll the conv state and update the last position - conv_state = mx.roll(conv_state, shift=-1, axis=-1) - # Create new conv state with updated last position - new_conv_state = mx.array(conv_state) - new_conv_state = new_conv_state.at[:, :, -1].add(hidden_states) - conv_state = new_conv_state - - # Compute convolution - conv_out = mx.sum(conv_state * self.conv1d.weight[:, 0, :], axis=-1) - if self.use_conv_bias: - conv_out = conv_out + self.conv1d.bias - - # Apply SiLU activation - conv_out = mx.sigmoid(conv_out) * conv_out - - else: - # Initialize new cache - conv_state = mx.zeros((batch_size, self.conv_dim, self.conv_kernel - 1)) - conv_out = self.conv1d(hidden_states) - conv_out = mx.sigmoid(conv_out) * conv_out - - # Process SSM + def __call__(self, u: mx.array, cache=None): + # Expect input shape: [batch_size, 1, hidden_size] + batch_size, seq_len, _ = u.shape + pad_size = self.chunk_size - (seq_len % self.chunk_size) + + # Initialize states if needed + if cache.conv_states[0] is None: + cache.conv_states[0] = mx.zeros(( + batch_size, + self.args.conv_kernel - 1, + self.conv_dim + )) + + if cache.ssm_states[0] is None: + cache.ssm_states[0] = mx.zeros(( + batch_size, + 
self.args.num_heads, + self.args.head_dim, + self.args.state_size + )) + + # Project input + zxbcdt = self.in_proj(u) + + # Split projections + z = zxbcdt[:, :, :self.args.intermediate_size] + xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size] + dt = zxbcdt[:, :, -(self.args.num_heads):] + + # Process delta time + dt = mx.reshape(dt, (batch_size, seq_len, self.args.num_heads)) + dt = mx.squeeze(dt, axis=0) # Remove sequence dimension for single token dt = mx.clip( nn.softplus(dt + self.dt_bias), - self.time_step_min, - self.time_step_max + self.args.time_step_min, + self.args.time_step_max ) - - A = -mx.exp(self.A_log) - dA = mx.exp(dt * A[None, :]) - - if cache is not None: - ssm_state = cache.ssm_states - else: - ssm_state = mx.zeros( - (batch_size, self.num_heads, self.head_dim, self.ssm_state_size) - ) - - # Compute SSM updates - dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, hidden_states) - next_state = ssm_state * dA[:, :, None, None] + dBx - y = mx.einsum('bhds,bhs->bhd', next_state, C) - - # Add skip connection - y = y + hidden_states * self.D[None, :, None] - - return y, conv_state, next_state + dt = mx.maximum(dt, self.args.time_step_floor) - def process_long_sequence(self, hidden_states, B, C, dt, ssm_state): - batch_size, seq_len = hidden_states.shape[:2] - pad_size = self.chunk_size - (seq_len % self.chunk_size) - - # Reshape into chunks - x_chunks = self.reshape_into_chunks(hidden_states, pad_size, self.chunk_size) - B_chunks = self.reshape_into_chunks(B, pad_size, self.chunk_size) - C_chunks = self.reshape_into_chunks(C, pad_size, self.chunk_size) - - # Process time steps - dt = nn.softplus(dt + self.dt_bias) - dt = mx.clip(dt, self.time_step_min) - - # Prepare matrices + # Convolution step + xBC = self.conv1d(xBC, cache=cache) + xBC = silu(xBC) + + # Split conv output + x = xBC[:, :, :self.args.intermediate_size] + B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size] + C = xBC[:, :, -self.args.state_size:] + + # Reshape for SSM + x = mx.reshape(x, (batch_size, 1, self.args.num_heads, self.args.head_dim)) + x = mx.squeeze(x, axis=1) + + B = mx.reshape(B, (batch_size, 1, self.args.state_size)) + B = mx.broadcast_to(B, (batch_size, self.args.num_heads, self.args.state_size)) + B = mx.expand_dims(B, axis=2) + + C = mx.reshape(C, (batch_size, 1, self.args.state_size)) + C = mx.broadcast_to(C, (batch_size, self.args.num_heads, self.args.state_size)) + C = mx.expand_dims(C, axis=3) + + # SSM state update A = -mx.exp(self.A_log) - A = A * dt[:, None] - - # Process chunks - A_chunks = self.reshape_into_chunks( - mx.broadcast_to(A, (batch_size, seq_len + pad_size, self.num_heads)), - pad_size, - self.chunk_size - ) - - # Compute cumulative sums - A_cumsum = mx.cumsum(A_chunks, axis=-1) - L = mx.exp(self.segment_sum(A_chunks)) - - # Process diagonal blocks - G = mx.einsum('...lhn,...shn->...lsh', C_chunks, B_chunks) - M = G * L[..., None, :] - Y_diag = mx.einsum('...lsh,...sh->...lh', M, x_chunks) - - # Process off-diagonal blocks - decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum) - B_decay = B_chunks * decay_states[..., None] - states = mx.einsum('...shn,...sh->...hn', B_decay, x_chunks) - - # Combine results - y = Y_diag + states - - # Remove padding if necessary + dA = mx.exp(dt * mx.expand_dims(A, 0)) + dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) + + x = mx.expand_dims(x, axis=3) + dBx = mx.matmul(x, B) + + cache.ssm_states[0] = cache.ssm_states[0] * dA + 
dBx + + # Output computation + y = mx.matmul(cache.ssm_states[0], C) + y = mx.squeeze(y, axis=-1) + + # y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1) if pad_size > 0: - y = y[:, :seq_len] - - return y, ssm_state + y = y[:, :seq_len, :, :] + + # Final reshape and projections + y = mx.reshape(y, (batch_size, 1, self.args.num_heads * self.args.head_dim)) + y = self.norm(y + z) - def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array: - batch_size, seq_len, _ = x.shape - - # Project input - projected_states = self.in_proj(x.squeeze(1)) - - # Calculate d_mlp based on projection size - d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * - self.n_groups * self.ssm_state_size - self.num_heads) // 2 - - # Split projections with corrected dimensions - splits = [ - d_mlp, # z0 - d_mlp, # x0 - self.intermediate_size, # gate - self.conv_dim, # hidden_states - self.num_heads # dt - ] - - z0, x0, x1, gate, hidden_states, dt = projected_states.split(splits, axis=-1) - - # Split hidden states into components - x_conv, BC = mx.split(hidden_states, [self.intermediate_size], axis=-1) - B, C = mx.split(BC, [self.n_groups * self.ssm_state_size], axis=-1) - - # Process based on sequence length - if seq_len > 1 and cache is None: - y, next_state = self.process_long_sequence( - x_conv, B, C, dt, - mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size)) - ) - else: - # Reshape for single token processing - x_conv = x_conv.reshape(batch_size, -1, self.head_dim) - B = B.reshape(batch_size, self.num_heads, -1) - C = C.reshape(batch_size, self.num_heads, -1) - y, conv_state, next_state = self.process_single_token(x_conv, B, C, dt, cache) - - if cache is not None: - cache.update(conv_state, next_state) - - # Apply normalization and final projection - y = self.norm(y) * gate return self.out_proj(y) - + class ResidualBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - self.mixer = Mamba2Mixer(args) + self.residual_in_fp32 = args.residual_in_fp32 + + self.mixer = Mamba2Block(args) self.norm = nn.RMSNorm(args.hidden_size) - def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array: + def __call__(self, x: mx.array, cache): + if self.residual_in_fp32: + x = x.astype(mx.float32) return self.mixer(self.norm(x), cache) + x -class Mamba2Model(nn.Module): + +class Mamba2(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args @@ -306,26 +289,27 @@ class Mamba2Model(nn.Module): self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) - def __call__(self, x: mx.array, cache=None) -> mx.array: + def __call__(self, x: mx.array, cache): x = self.embeddings(x) - if cache is None: cache = [None] * len(self.layers) - - for layer, layer_cache in zip(self.layers, cache): - x = layer(x, layer_cache) - + for layer, c in zip(self.layers, cache): + x = layer(x, c) return self.norm_f(x) + class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args - self.backbone = Mamba2Model(args) + self.model_type = args.model_type + + self.backbone = Mamba2(args) + if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - def __call__(self, inputs: mx.array, cache=None) -> mx.array: + def __call__(self, inputs: mx.array, cache=None): B, T = inputs.shape x = self.backbone(inputs, cache) @@ -336,26 +320,24 @@ class Model(nn.Module): logits = 
self.lm_head(x) return logits - + def make_cache(self, batch_size=1): - return [ - Mamba2Cache( - batch_size=batch_size, - conv_dim=self.args.intermediate_size + 2 * self.args.n_groups * self.args.state_size, - kernel_size=self.args.conv_kernel, - num_heads=self.args.num_heads, - head_dim=self.args.head_dim, - state_size=self.args.state_size - ) - for _ in range(len(self.backbone.layers)) - ] - + return [Mamba2Cache(batch_size, self.args.conv_kernel) for _ in range(len(self.layers))] + def sanitize(self, weights): + sanitized = {} for k, v in weights.items(): - if "conv1d.weight" in k and v.ndim == 3: - weights[k] = v.moveaxis(2, 1) - return weights - + if "conv1d.weight" in k: + # Ensure weights are in correct shape (channels, 1, kernel_size) + if v.ndim == 2: + v = mx.expand_dims(v, axis=1) + elif v.ndim == 1: + v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0) + sanitized[k] = v + else: + sanitized[k] = v + return sanitized + @property def layers(self): return self.backbone.layers diff --git a/llms/mlx_lm/models/mamba22.py b/llms/mlx_lm/models/mamba22.py new file mode 100644 index 00000000..c0cbe1d7 --- /dev/null +++ b/llms/mlx_lm/models/mamba22.py @@ -0,0 +1,316 @@ +import math +from dataclasses import dataclass, field +from typing import Tuple, Union +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs +from .cache import Mamba2Cache + + +@dataclass +class ModelArgs(BaseModelArgs): + num_heads: int + head_dim: int + vocab_size: int + hidden_size: int + state_size: int + num_hidden_layers: int + layer_norm_epsilon: float + expand: int + conv_kernel: int + n_groups: int + use_bias: bool + use_conv_bias: bool + initializer_range: float + residual_in_fp32: bool + time_step_min: float + time_step_max: float + time_step_floor: float + rescale_prenorm_residual: bool + rms_norm: bool + chunk_size: int + tie_word_embeddings: bool + use_cache: bool = True + time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) + time_step_rank: Union[int, str] = "auto" + model_type: str = "mamba2" + + def __post_init__(self): + if not hasattr(self, "intermediate_size"): + self.intermediate_size = int(self.expand * self.hidden_size) + if not hasattr(self, "head_dim"): + self.head_dim = self.hidden_size // self.num_heads + if self.time_step_rank == "auto": + self.time_step_rank = math.ceil(self.hidden_size / 16) + + +def silu(x): + return x * mx.sigmoid(x) + + +class MambaRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = mx.ones((hidden_size,)) + self.variance_epsilon = eps + + def __call__(self, hidden_states, gate=None): + # Fuse operations where possible + if gate is not None: + hidden_states = hidden_states * nn.silu(gate) + # Compute variance in fp32 for better numerical stability + hidden_states_fp32 = hidden_states.astype(mx.float32) + variance = mx.mean(hidden_states_fp32 * hidden_states_fp32, axis=-1, keepdims=True) + hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states + + +def ssd_optimized(x, A, B, C, chunk_size): + batch, seqlen, nheads, dim = x.shape + B = mx.expand_dims(B, axis=2) + C = mx.expand_dims(C, axis=2) + + output = mx.zeros((batch, seqlen, nheads, dim)) + state = mx.zeros((batch, nheads, dim, B.shape[-1])) + + for i in range(0, seqlen, chunk_size): + chunk = slice(i, min(i + chunk_size, seqlen)) + chunk_size_actual = min(chunk_size, seqlen - i) + + dA = mx.exp(mx.expand_dims(A[chunk], axis=0)) + x_chunk = 
mx.transpose(x[:, chunk], [0, 2, 3, 1]) + dBx = mx.matmul(x_chunk, B[:, chunk]) + state = state * mx.expand_dims(dA, axis=-1) + dBx + y = mx.matmul(state, mx.transpose(C[:, chunk], [0, 2, 1])) + output[:, i:i+chunk_size_actual] = mx.transpose(y, [0, 3, 1, 2]) + + return output, state + + +def update_conv_cache(x: mx.array, cache, kernel_size: int) -> Tuple[mx.array, mx.array]: + """Update convolution cache for sequential processing.""" + B, L, C = x.shape + + if cache is None: + # Initialize cache with zeros + cache = mx.zeros((B, kernel_size - 1, C)) + + # Concatenate cache with current input + x_with_cache = mx.concatenate([cache, x], axis=1) + + # Update cache with the last (kernel_size - 1) elements + new_cache = x_with_cache[:, -kernel_size+1:] if x_with_cache.shape[1] >= kernel_size else x_with_cache + + return x_with_cache, new_cache + + +class Mamba2Block(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.intermediate_size = int(args.expand * args.hidden_size) + self.state_size = args.state_size + self.num_heads = args.num_heads + self.head_dim = args.head_dim + self.ssm_state_size = args.state_size + self.n_groups = args.n_groups + self.conv_kernel = args.conv_kernel + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + + self.in_proj = nn.Linear(args.hidden_size, projection_size, bias=args.use_bias) + + # Using built-in Conv1d instead of custom DepthwiseConv1d + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + kernel_size=args.conv_kernel, + groups=self.conv_dim, # For depthwise convolution + padding=0, # We'll handle padding manually with the cache + bias=args.use_conv_bias + ) + + self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range + self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range + self.D = mx.random.normal((args.num_heads,)) * args.initializer_range + + self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon) + self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) + + if args.rescale_prenorm_residual: + layer_scale = math.sqrt(1.0 / args.num_hidden_layers) + self.out_proj.weight = self.out_proj.weight * layer_scale + + def __call__(self, u: mx.array, cache): + batch_size, seq_len, _ = u.shape + + projected = self.in_proj(u) + d_conv = self.conv_dim + + z = projected[..., :self.intermediate_size] + xBC = projected[..., self.intermediate_size:self.intermediate_size + d_conv] + dt = projected[..., -self.num_heads:] + + dt = mx.clip( + nn.softplus(dt + self.dt_bias), + self.args.time_step_min, + self.args.time_step_max + ) + dt = mx.maximum(dt, self.args.time_step_floor) + + # Handle convolution with separate cache update + if cache is not None: + # Update cache and get padded input + xBC_padded, new_cache = update_conv_cache(xBC, cache.conv_states, self.conv_kernel) + cache.conv_states = new_cache + + # Prepare input for conv1d: [B, L, C] -> [B, C, L] + xBC_conv = mx.transpose(xBC_padded, [0, 2, 1]) + + # Apply convolution + xBC = self.conv1d(xBC_conv) + + # Transform back: [B, C, L] -> [B, L, C] + xBC = mx.transpose(xBC, [0, 2, 1]) + + # Take only the relevant part corresponding to input length + xBC = xBC[:, :seq_len] + else: + # For training, use regular convolution with padding + xBC = mx.transpose(xBC, [0, 2, 1]) + xBC = self.conv1d(xBC) + xBC = mx.transpose(xBC, [0, 2, 1]) 
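            # Note (hedged, editorial): mlx.nn.Conv1d expects channels-last (B, L, C) input,
            # so the PyTorch-style [B, C, L] transposes above may not be required here, and
            # with padding=0 the conv output comes back conv_kernel - 1 steps shorter than
            # seq_len; this no-cache branch likely still wants explicit left-padding
            # (e.g. mx.pad along the time axis) to stay causal and length-preserving.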
+ + xBC = silu(xBC) + + x = xBC[..., :self.intermediate_size] + BC = xBC[..., self.intermediate_size:] + B = BC[..., :self.state_size] + C = BC[..., self.state_size:] + + x = mx.reshape(x, (-1, seq_len, self.num_heads, self.intermediate_size // self.num_heads)) + + A = -mx.exp(self.A_log) + D_expanded = mx.expand_dims(self.D, -1) + + if cache is not None and cache.ssm_state is None: + cache.ssm_state = mx.zeros(( + batch_size, + self.num_heads, + self.intermediate_size // self.num_heads, + self.state_size + )) + + if cache is not None: + output = mx.zeros((batch_size, seq_len, self.args.hidden_size)) + + for pos in range(seq_len): + x_t = x[:, pos:pos+1] + + dA = mx.exp(dt[:, pos:pos+1] * mx.expand_dims(A, 0)) + dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) + + x_expanded = mx.expand_dims(x_t, axis=3) + dBx = mx.matmul(x_expanded, mx.expand_dims(B[:, pos:pos+1], axis=2)) + + cache.ssm_state = cache.ssm_state * dA + dBx + + y = mx.matmul(cache.ssm_state, mx.expand_dims(C[:, pos:pos+1], axis=3)) + y = mx.squeeze(y, axis=-1) + y = y + x_t * D_expanded + + y = mx.reshape(y, (batch_size, 1, -1)) + y = self.norm(y + z[:, pos:pos+1]) + y = self.out_proj(y) + + if self.args.residual_in_fp32: + y = y.astype(mx.float32) + + output = output.at[:, pos:pos+1].set(y) + else: + y, ssm_state = ssd_optimized( + x * mx.expand_dims(dt, -1), + -mx.exp(self.A_log) * dt, + B, C, + self.args.chunk_size + ) + + y = mx.reshape( + y + x * mx.expand_dims(self.D, -1), + (batch_size, seq_len, -1) + ) + + y = self.norm(y + z) + output = self.out_proj(y) + + if self.args.residual_in_fp32: + output = output.astype(mx.float32) + + return output + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = Mamba2Block(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, x: mx.array, cache): + return self.mixer(self.norm(x), cache) + x + + +class Mamba2(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + + def __call__(self, x: mx.array, cache): + x = self.embeddings(x) + if cache is None: + cache = [None] * len(self.layers) + for layer, c in zip(self.layers, cache): + x = layer(x, c) + return self.norm_f(x) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + + self.backbone = Mamba2(args) + + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None): + B, T = inputs.shape + + x = self.backbone(inputs, cache) + + if self.args.tie_word_embeddings: + logits = self.backbone.embeddings.as_linear(x) + else: + logits = self.lm_head(x) + + return logits + + def make_cache(self, batch_size=1): + return [Mamba2Cache() for _ in range(len(self.layers))] + + def sanitize(self, weights): + for k, v in weights.items(): + if "conv1d.weight" in k and v.ndim == 3: + weights[k] = v.moveaxis(2, 1) + return weights + + @property + def layers(self): + return self.backbone.layers \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba23.py b/llms/mlx_lm/models/mamba23.py new file mode 100644 index 00000000..efbe54a4 --- /dev/null +++ b/llms/mlx_lm/models/mamba23.py @@ -0,0 +1,357 @@ +import math +from dataclasses 
import dataclass, field +from typing import Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs +from .cache import Mamba2Cache + +@dataclass +class ModelArgs(BaseModelArgs): + num_heads: int + head_dim: int + vocab_size: int + hidden_size: int + state_size: int + num_hidden_layers: int + layer_norm_epsilon: float + expand: int + conv_kernel: int + n_groups: int + use_bias: bool + use_conv_bias: bool + initializer_range: float + residual_in_fp32: bool + time_step_min: float + time_step_max: float + time_step_floor: float + rescale_prenorm_residual: bool + use_cache: bool + rms_norm: bool + chunk_size: int + tie_word_embeddings: bool + intermediate_size: int = None + time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) + time_step_rank: Union[int, str] = "auto" + model_type: str = "mamba2" + + def __post_init__(self): + self.intermediate_size = int(self.expand * self.hidden_size) # E*D = ED + + if not hasattr(self, "head_dim"): + self.head_dim = self.hidden_size // self.num_heads + if self.time_step_rank == "auto": + self.time_step_rank = math.ceil(self.hidden_size / 16) + + +class MambaRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = mx.ones(hidden_size) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(mx.float32) + + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate.to(mx.float32)) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * math.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class Mamba2Mixer(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + # Model dimensions + self.hidden_size = args.hidden_size + self.num_heads = args.num_heads + self.head_dim = args.head_dim + self.ssm_state_size = args.state_size + self.n_groups = args.n_groups + self.intermediate_size = int(args.expand * args.hidden_size) + + # Convolution parameters + self.conv_kernel = args.conv_kernel + self.use_conv_bias = args.use_conv_bias + + # Time step parameters + self.time_step_rank = int(args.time_step_rank) + self.time_step_min = args.time_step_min + self.time_step_max = args.time_step_max + + # Processing parameters + self.chunk_size = args.chunk_size + self.layer_norm_epsilon = args.layer_norm_epsilon + + # Calculate dimensions + self.conv_dim = (self.intermediate_size + + 2 * self.n_groups * self.ssm_state_size) + projection_size = (self.intermediate_size + + self.conv_dim + + self.num_heads) + + # Initialize layers + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=args.use_bias + ) + + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + kernel_size=self.conv_kernel, + groups=self.conv_dim, + padding=self.conv_kernel - 1, + bias=self.use_conv_bias + ) + + # Initialize parameters + self.dt_bias = mx.ones(self.num_heads) + A = mx.arange(1, self.num_heads + 1) + self.A_log = mx.log(A) + self.D = mx.ones(self.num_heads) + + # Output layers + self.norm = MambaRMSNormGated( + self.intermediate_size, + eps=self.layer_norm_epsilon + ) + self.out_proj = nn.Linear( + self.intermediate_size, + self.hidden_size, + bias=args.use_bias + ) + + def reshape_into_chunks(self, tensor, pad_size, chunk_size): + if pad_size > 0: + pad_shape = list(tensor.shape) + pad_shape[1] = pad_size + 
padding = mx.zeros(pad_shape, dtype=tensor.dtype) + tensor = mx.concatenate([tensor, padding], axis=1) + + chunk_shape = list(tensor.shape) + chunk_shape[1] = -1 + chunk_shape.insert(2, chunk_size) + return tensor.reshape(chunk_shape) + + def segment_sum(self, x): + return mx.cumsum(x, axis=-1) + + def process_single_token(self, hidden_states, B, C, dt, cache): + batch_size = hidden_states.shape[0] + + # Process convolution state + if cache is not None and cache.conv_states is not None: + conv_state = cache.conv_states + # Roll the conv state and update the last position + conv_state = mx.roll(conv_state, shift=-1, axis=-1) + # Create new conv state with updated last position + new_conv_state = mx.array(conv_state) + new_conv_state = new_conv_state.at[:, :, -1].add(hidden_states) + conv_state = new_conv_state + + # Compute convolution + conv_out = mx.sum(conv_state * self.conv1d.weight[:, 0, :], axis=-1) + if self.use_conv_bias: + conv_out = conv_out + self.conv1d.bias + + # Apply SiLU activation + conv_out = mx.sigmoid(conv_out) * conv_out + + else: + # Initialize new cache and process convolution + conv_state = mx.zeros((batch_size, self.conv_dim, self.conv_kernel - 1)) + + # Reshape hidden_states for conv1d + hidden_states_reshaped = hidden_states.reshape(batch_size, -1, 1) + conv_out = self.conv1d(hidden_states_reshaped) + conv_out = mx.squeeze(conv_out, axis=-1) # Remove the last dimension + conv_out = mx.sigmoid(conv_out) * conv_out + + # Process SSM + dt = mx.clip( + nn.softplus(dt + self.dt_bias), + self.time_step_min, + self.time_step_max + ) + + A = -mx.exp(self.A_log) + dA = mx.exp(dt[:, None] * A[None, :]) + + if cache is not None and cache.ssm_states is not None: + ssm_state = cache.ssm_states + else: + ssm_state = mx.zeros( + (batch_size, self.num_heads, self.head_dim, self.ssm_state_size) + ) + + # Compute SSM updates + dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, hidden_states) + next_state = ssm_state * dA[:, :, None, None] + dBx + y = mx.einsum('bhds,bhs->bhd', next_state, C) + + # Add skip connection + y = y + hidden_states * self.D[None, :, None] + + return y, conv_state, next_state + + def process_long_sequence(self, hidden_states, B, C, dt, ssm_state): + batch_size, seq_len = hidden_states.shape[:2] + pad_size = self.chunk_size - (seq_len % self.chunk_size) + + # Reshape into chunks + x_chunks = self.reshape_into_chunks(hidden_states, pad_size, self.chunk_size) + B_chunks = self.reshape_into_chunks(B, pad_size, self.chunk_size) + C_chunks = self.reshape_into_chunks(C, pad_size, self.chunk_size) + + # Process time steps + dt = nn.softplus(dt + self.dt_bias) + dt = mx.clip(dt, self.time_step_min) + + # Prepare matrices + A = -mx.exp(self.A_log) + A = A * dt[:, None] + + # Process chunks + A_chunks = self.reshape_into_chunks( + mx.broadcast_to(A, (batch_size, seq_len + pad_size, self.num_heads)), + pad_size, + self.chunk_size + ) + + # Compute cumulative sums + A_cumsum = mx.cumsum(A_chunks, axis=-1) + L = mx.exp(self.segment_sum(A_chunks)) + + # Process diagonal blocks + G = mx.einsum('...lhn,...shn->...lsh', C_chunks, B_chunks) + M = G * L[..., None, :] + Y_diag = mx.einsum('...lsh,...sh->...lh', M, x_chunks) + + # Process off-diagonal blocks + decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum) + B_decay = B_chunks * decay_states[..., None] + states = mx.einsum('...shn,...sh->...hn', B_decay, x_chunks) + + # Combine results + y = Y_diag + states + + # Remove padding if necessary + if pad_size > 0: + y = y[:, :seq_len] + + return y, ssm_state + + def __call__(self, x: 
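Note: `segment_sum` here collapses to a plain `mx.cumsum`, which drops the pairwise (target, source) structure the chunked scan relies on, and `mx.clip(dt, self.time_step_min)` in `process_long_sequence` is missing the required upper bound (`mx.clip(dt, self.time_step_min, None)`). A stable segment-sum sketch in MLX, assuming `x` has shape `[..., chunk_size]`:

    def segment_sum(self, x: mx.array) -> mx.array:
        # segsum[..., i, j] = sum_{k=j+1..i} x[..., k], and -inf above the diagonal
        T = x.shape[-1]
        x = mx.repeat(x[..., None], T, axis=-1)
        idx = mx.arange(T)
        strictly_lower = idx[:, None] > idx[None, :]
        x = mx.where(strictly_lower, x, 0)
        segsum = mx.cumsum(x, axis=-2)
        lower = idx[:, None] >= idx[None, :]
        return mx.where(lower, segsum, float("-inf"))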
mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array: + batch_size, seq_len, _ = x.shape + + # Project input + projected_states = self.in_proj(x) + + # Calculate d_mlp based on projection size + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * + self.n_groups * self.ssm_state_size - self.num_heads) // 2 + + # Split projections with corrected dimensions + splits = [ + d_mlp, # z0 + d_mlp, # x0 + self.intermediate_size, # gate + self.conv_dim, # hidden_states + self.num_heads # dt + ] + + z0, x0, x1, gate, hidden_states, dt = projected_states.split(splits, axis=-1) + + # Split hidden states into components + x_conv, BC = mx.split(hidden_states, [self.intermediate_size], axis=-1) + B, C = mx.split(BC, [self.n_groups * self.ssm_state_size], axis=-1) + + # Process based on sequence length + if seq_len > 1 and cache is None: + y, next_state = self.process_long_sequence( + x_conv, B, C, dt, + mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size)) + ) + else: + # Reshape for single token processing + x_conv = x_conv.reshape(batch_size, -1, self.head_dim) + B = B.reshape(batch_size, self.num_heads, -1) + C = C.reshape(batch_size, self.num_heads, -1) + y, conv_state, next_state = self.process_single_token(x_conv, B, C, dt, cache) + + if cache is not None: + cache.update(conv_state, next_state) + + # Apply normalization and final projection + y = self.norm(y) * gate + return self.out_proj(y) + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = Mamba2Mixer(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array: + return self.mixer(self.norm(x), cache) + x + + +class Mamba2Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + + def __call__(self, x: mx.array, cache=None) -> mx.array: + x = self.embeddings(x) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, layer_cache in zip(self.layers, cache): + x = layer(x, layer_cache) + + return self.norm_f(x) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.backbone = Mamba2Model(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None) -> mx.array: + B, T = inputs.shape + + x = self.backbone(inputs, cache) + + if self.args.tie_word_embeddings: + logits = self.backbone.embeddings.as_linear(x) + else: + logits = self.lm_head(x) + + return logits + + def make_cache(self, batch_size=1): + return [Mamba2Cache() for _ in range(len(self.backbone.layers))] + + def sanitize(self, weights): + for k, v in weights.items(): + if "conv1d.weight" in k and v.ndim == 3: + weights[k] = v.moveaxis(2, 1) + return weights + + @property + def layers(self): + return self.backbone.layers diff --git a/llms/mlx_lm/models/mamba24.py b/llms/mlx_lm/models/mamba24.py new file mode 100644 index 00000000..b1ada1df --- /dev/null +++ b/llms/mlx_lm/models/mamba24.py @@ -0,0 +1,430 @@ +import math +from dataclasses import dataclass, field +from typing import Tuple, Union +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs +from 
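Note: both `__call__` implementations pass section *sizes* to `split`, but `mx.split` / `array.split` interpret a list as *cut indices* (numpy semantics), so N entries yield N+1 slices at those offsets rather than N sized chunks. A hedged helper (hypothetical name `split_by_sizes`) that converts sizes into the cut points MLX expects:

def split_by_sizes(x: mx.array, sizes, axis: int = -1):
    # prefix sums of all but the last size give the cut indices mx.split expects
    cuts, total = [], 0
    for s in sizes[:-1]:
        total += s
        cuts.append(total)
    return mx.split(x, cuts, axis=axis)

With the five sizes listed above this yields exactly five pieces (z0, x0, gate, hidden_states, dt); the extra `x1` in the current unpacking only succeeds because the index interpretation silently produces six mis-sized slices.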
.cache import Mamba2Cache + +@dataclass +class ModelArgs(BaseModelArgs): + num_heads: int + head_dim: int + vocab_size: int + hidden_size: int + state_size: int + num_hidden_layers: int + layer_norm_epsilon: float + expand: int + conv_kernel: int + n_groups: int + use_bias: bool + use_conv_bias: bool + initializer_range: float + residual_in_fp32: bool + time_step_min: float + time_step_max: float + time_step_floor: float + rescale_prenorm_residual: bool + rms_norm: bool + chunk_size: int + tie_word_embeddings: bool + use_cache: bool = True + time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) + time_step_rank: Union[int, str] = "auto" + model_type: str = "mamba2" + + def __post_init__(self): + if not hasattr(self, "intermediate_size"): + self.intermediate_size = int(self.expand * self.hidden_size) + if not hasattr(self, "head_dim"): + self.head_dim = self.hidden_size // self.num_heads + if self.time_step_rank == "auto": + self.time_step_rank = math.ceil(self.hidden_size / 16) + + +class MambaRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = mx.ones((hidden_size,)) + self.variance_epsilon = eps + + def __call__(self, hidden_states, gate=None): + if gate is not None: + hidden_states = hidden_states * nn.silu(gate) + variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True) + hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states + + +def pad_tensor_by_size(input_tensor: mx.array, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return mx.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = mx.tril(mx.ones(chunk_size, chunk_size, device=input_tensor.device), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = mx.cumsum(input_tensor, dim=-2) + + # 4. 
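Note: `pad_tensor_by_size` uses torch's flattened pad tuple and a `value=` keyword; `mx.pad` instead takes one `(before, after)` pair per axis plus `constant_values=`. A sketch that pads only the seq_len axis:

def pad_tensor_by_size(input_tensor: mx.array, pad_size: int) -> mx.array:
    # pad the seq_len axis (axis 1) on the right; leave every other axis untouched
    if pad_size == 0:
        return input_tensor
    pad_width = [(0, 0)] * input_tensor.ndim
    pad_width[1] = (0, pad_size)
    return mx.pad(input_tensor, pad_width, constant_values=0)

The `segment_sum` that follows leans on `.size`, `.expand`, `.masked_fill`, and `device=`, which MLX arrays do not provide; the MLX segment-sum sketched earlier applies here unchanged.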
apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = mx.tril(mx.ones(chunk_size, chunk_size, device=input_tensor.device), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -mx.inf) + return tensor_segsum + + +class Mamba2Block(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.args = args + + self.hidden_size = args.hidden_size + self.num_heads = args.num_heads + self.head_dim = args.head_dim + self.state_size = args.state_size + self.n_groups = args.n_groups + self.conv_kernel = args.conv_kernel + self.intermediate_size = int(args.expand * args.hidden_size) + self.time_step_rank = int(args.time_step_rank) + self.time_step_min = args.time_step_min + self.time_step_max = args.time_step_max + self.chunk_size = args.chunk_size + + + # Convolution dimension includes both intermediate sizes + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=args.use_conv_bias, + kernel_size=args.conv_kernel, + groups=self.conv_dim, + padding=args.conv_kernel - 1 + ) + + # Compute input projection dimension + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear(args.hidden_size, projection_size, bias=args.use_bias) + + self.dt_bias = mx.ones(self.num_heads) + A = mx.arange(1, self.num_heads + 1) + self.A_log = mx.log(A) + self.D = mx.ones(self.num_heads) + + self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) + + def __call__(self, input_states: mx.array, cache): + batch_size, seq_len, _ = input_states.shape + + # Gated MLP's linear projection + projected_states = self.in_proj(input_states) # [1, 1, projection_size] + + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - + 2 * self.n_groups * self.state_size - self.num_heads) // 2 + + # Split projected states + *_, gate, hidden_states, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], + axis=-1 + ) + # hidden_states shape: [1, 1, conv_dim] + + # Get SSM state from cache + ssm_state = cache.ssm_states[self.layer_idx] + + if cache.seqlen_offset > 0: + # Handle cached generation case + conv_state = cache.conv_states[self.layer_idx] # [batch, conv_dim, conv_kernel] + conv_state = mx.roll(conv_state, shifts=-1, axis=-1) + + # Handle batched generation - states are copied through + # Properly reshape hidden_states for the conv_state update + conv_state = conv_state.at[:, :, -1].set(hidden_states[:, 0, :]) + cache.conv_states[self.layer_idx] = conv_state + + # Compute convolution output + hidden_states = mx.sum(conv_state * self.conv1d.weight[:, 0, :], axis=-1) + if self.args.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = nn.silu(hidden_states)[:, None, ...] 
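Note: in this cached-generation branch, `mx.roll` takes `shift` (not `shifts`), MLX arrays have no `.at[...].set()` (plain slice assignment works), and the MLX `Conv1d` weight is laid out `(out_channels, kernel_size, in_channels/groups)`, so the cached depthwise step should index `weight[:, :, 0]` rather than `weight[:, 0, :]`. A hedged sketch of that step, assuming `conv_state` is `[batch, conv_dim, conv_kernel]` and an mlx version that provides `mx.roll`:

conv_state = mx.roll(conv_state, -1, axis=-1)           # shift the window left by one slot
conv_state[:, :, -1] = hidden_states[:, 0, :]           # write the new token's features
cache.conv_states[self.layer_idx] = conv_state
# depthwise convolution over the cached window; weight[:, :, 0] has shape (conv_dim, conv_kernel)
hidden_states = mx.sum(conv_state * self.conv1d.weight[:, :, 0], axis=-1)
if self.args.use_conv_bias:
    hidden_states = hidden_states + self.conv1d.bias
hidden_states = nn.silu(hidden_states)[:, None, :]      # [batch, 1, conv_dim]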
# [batch, 1, conv_dim] : decoding + + else: + # Handle normal forward pass + # Properly transpose while preserving the sequence dimension + hidden_states = hidden_states.transpose(0, 2, 1) # [1, conv_dim, 1] + + # Pad the convolution state + padding_size = self.conv_kernel - 1 + conv_state = mx.pad( + hidden_states, + ((0, 0), (0, 0), (padding_size, 0)) + ) + + # Store in cache + cache.conv_states[self.layer_idx] = conv_state + + # Apply convolution with proper padding + hidden_states = self.conv1d(hidden_states) # [1, conv_dim, 1] + hidden_states = hidden_states.transpose(0, 2, 1) # [1, 1, conv_dim] + hidden_states = nn.silu(hidden_states) + + # Split hidden states for SSM computation + hidden_states, B, C = mx.split( + hidden_states, + [self.intermediate_size, self.n_groups * self.state_size, self.n_groups * self.state_size], + axis=-1 + ) + + # Compute A matrix + A = -mx.exp(self.A_log) + + if cache is not None and cache.seqlen_offset > 0: + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] + dt = dt.transpose(0, 2, 1).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = nn.softplus(dt + dt_bias) + dt = mx.clamp(dt, self.time_step_min) #, self.time_step_max) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.state_size) + # [bsz, num_heads, head_dim, state_size] + dA = mx.exp(dt[..., None] * A) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = dB * hidden_states[..., None] + + # State calculation + cache.ssm_states[self.layer_idx].copy_( + cache.ssm_states[self.layer_idx] * dA + dBx + ) + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache.ssm_states[self.layer_idx] # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.state_size, 1) # Shape: [b*h, n, 1] + y = ssm_states_reshaped @ C_reshaped + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] 
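Note: MLX's `Conv1d` expects channels-last input `(N, L, C)`, so the transpose to `[1, conv_dim, 1]` above feeds it the wrong layout (the activations can stay `[batch, seq_len, conv_dim]`), and the single-token SSM update relies on torch-only calls (`.expand`, `.contiguous`, `.view`, `.copy_()`, `mx.clamp`). A hedged MLX sketch of that recurrence, assuming `hidden_states` is `[batch, intermediate_size]`, `B`/`C` are `[batch, n_groups * state_size]`, and `dt` is `[batch, num_heads]`:

dt = nn.softplus(dt + self.dt_bias)                     # [batch, num_heads]
dt = mx.clip(dt, self.time_step_min, None)              # mx.clip, not mx.clamp
A = -mx.exp(self.A_log)                                 # per-head scalar A
dA = mx.exp(dt * A)                                     # [batch, num_heads]

# broadcast grouped B and C up to one copy per head
repeats = self.num_heads // self.n_groups
B = mx.repeat(B.reshape(batch_size, self.n_groups, 1, self.state_size), repeats, axis=2)
B = B.reshape(batch_size, self.num_heads, self.state_size)
C = mx.repeat(C.reshape(batch_size, self.n_groups, 1, self.state_size), repeats, axis=2)
C = C.reshape(batch_size, self.num_heads, self.state_size)

x = hidden_states.reshape(batch_size, self.num_heads, self.head_dim)
dBx = (dt[..., None] * x)[..., None] * B[:, :, None, :]            # [b, h, d, n]
ssm_state = cache.ssm_states[self.layer_idx] * dA[..., None, None] + dBx
cache.ssm_states[self.layer_idx] = ssm_state                       # no .copy_() in mlx

y = (ssm_state * C[:, :, None, :]).sum(axis=-1)                    # [b, h, d]
y = y + x * self.D[None, :, None]                                  # D skip connection
y = y.reshape(batch_size, 1, -1)                                   # [b, 1, intermediate_size]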
+ else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = mx.clamp(dt, self.time_step_min) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim) + B = B.reshape(batch_size, seq_len, -1, self.state_size) + C = C.reshape(batch_size, seq_len, -1, self.state_size) + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = self.chunk_size - (seq_len % self.chunk_size) + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = mx.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = mx.exp(segment_sum(A)) + + # First, contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + + # Step 2: Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Step 3: Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3) + + # (right term of low-rank factorization of off-diagonal blocks; B terms) + + decay_states = mx.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] + # permute back B * decay states + states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) + if cache is not None and cache.seqlen_offset > 0: + previous_states = cache.ssm_states[self.layer_idx][:, None, ...] + else: + previous_states = mx.zeros_like(states[:, :1]) + states = mx.concat([previous_states, states], dim=1) + decay_chunk = mx.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + + states_permuted = states.permute(0, 2, 1, 3, 4) + result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2) + new_states = result.permute(0, 2, 1, 3, 4) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = mx.exp(A_cumsum) + # compute Yoff + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + if ssm_state is not None and cache is not None: + cache.ssm_states[self.layer_idx] = ssm_state + + scan_output = self.norm(y, gate) + # end ssd naive + + # 4. 
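Note: the chunked branch below is still the torch reference almost verbatim: `.repeat`, `.permute`, `dim=` keywords, `mx.clamp`, `mx.concat(..., dim=)`, and `nn.functional.softplus/pad` all need their MLX counterparts. A hedged sketch of the group-to-head expansion (assumes `B` is `[batch, seq_len, n_groups, state_size]`, repeat-interleave semantics matching the single-token path) plus a mapping of the remaining calls:

# tile each group's B/C across its heads, then fold groups*heads_per_group back together
heads_per_group = self.num_heads // self.n_groups
B = mx.repeat(B[:, :, :, None, :], heads_per_group, axis=3)
B = B.reshape(batch_size, seq_len, self.num_heads, self.state_size)

# torch-only calls in this branch and their MLX equivalents:
#   nn.functional.softplus(x)      -> nn.softplus(x)
#   mx.clamp(x, lo)                -> mx.clip(x, lo, None)
#   tensor.permute(...)            -> tensor.transpose(...)
#   mx.cumsum(x, dim=-1)           -> mx.cumsum(x, axis=-1)
#   tensor.sum(dim=k)              -> tensor.sum(axis=k)
#   mx.concat([...], dim=1)        -> mx.concatenate([...], axis=1)
#   nn.functional.pad(x, (1, 0))   -> mx.pad(x, [(0, 0), (0, 0), (1, 0)])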
Final linear projection + return self.out_proj(scan_output) # [batch, seq_len, hidden_size] + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.residual_in_fp32 = args.residual_in_fp32 + self.mixer = Mamba2Block(args, layer_idx) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, x: mx.array, cache): + return self.mixer(self.norm(x), cache) + x + + +class Mamba2(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args, idx) for idx in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + + def __call__(self, x: mx.array, cache): + x = self.embeddings(x) + if cache is None: + cache = [None] * len(self.layers) + for layer, c in zip(self.layers, cache): + x = layer(x, c) + return self.norm_f(x) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + + self.backbone = Mamba2(args) + + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None): + B, T = inputs.shape + + x = self.backbone(inputs, cache) + + if self.args.tie_word_embeddings: + logits = self.backbone.embeddings.as_linear(x) + else: + logits = self.lm_head(x) + + return logits + + def make_cache(self, batch_size=1): + return [Mamba2Cache( + batch_size, + self.args.intermediate_size, + self.args.conv_kernel, + self.args.head_dim, + self.args.num_heads, + self.args.n_groups, + self.args.state_size + ) for _ in range(len(self.layers))] + + def sanitize(self, weights): + for k, v in weights.items(): + if "conv1d.weight" in k and v.ndim == 3: + weights[k] = v.moveaxis(2, 1) + return weights + + @property + def layers(self): + return self.backbone.layers \ No newline at end of file
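Note: `make_cache` builds one cache object per layer, yet the mixer then indexes `cache.conv_states[self.layer_idx]` / `cache.ssm_states[self.layer_idx]` again by layer, and the positional arguments passed here do not line up with the `Mamba2Cache` constructor in `cache.py`. One way to reconcile this (hypothetical class and attribute names; the mixer would then read `cache.conv_state` / `cache.ssm_state` directly) is a self-contained per-layer cache:

class Mamba2LayerCache:
    def __init__(self, batch_size, conv_dim, conv_kernel, num_heads, head_dim, state_size):
        # preallocated rolling window for the depthwise conv and the SSM hidden state
        self.conv_state = mx.zeros((batch_size, conv_dim, conv_kernel))
        self.ssm_state = mx.zeros((batch_size, num_heads, head_dim, state_size))
        self.seqlen_offset = 0

def make_cache(self, batch_size=1):
    conv_dim = self.args.intermediate_size + 2 * self.args.n_groups * self.args.state_size
    return [
        Mamba2LayerCache(
            batch_size,
            conv_dim,
            self.args.conv_kernel,
            self.args.num_heads,
            self.args.head_dim,
            self.args.state_size,
        )
        for _ in range(len(self.layers))
    ]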