diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index 155755b0..84fadd06 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -350,28 +350,11 @@ class MambaCache: return [self.cache[1]] - -class Mamba2Cache(_BaseCache): - def __init__( - self, - batch_size, - conv_kernel - ): - self.conv_kernel: mx.array = conv_kernel - self.conv_states: mx.array = [None] - self.ssm_states = [None] - self.seqlen_offset = 0 +class Mamba2Cache: + def __init__(self): + self.conv_states = [None] # Initialize as None, will be set on first use + self.ssm_states = [None] # Initialize as None, will be set on first use - def reset(self): - self.conv_states = None - self.ssm_state = None - - def update(self, layer_idx: int, new_conv_state: mx.array, cache_position: mx.array) -> mx.array: - conv_state = self.conv_states[layer_idx] - cache_position = cache_position.clamp(0, self.conv_kernel - 1) - - conv_state = conv_state.roll(shifts=-1, dims=-1) - conv_state[:, :, cache_position] = new_conv_state - self.conv_states[layer_idx].zero_() - self.conv_states[layer_idx] += conv_state - return self.conv_states[layer_idx] \ No newline at end of file + @property + def state(self): + return [self.conv_states[0], self.ssm_states[0]] \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index 0fe0d9a8..51820221 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -1,11 +1,11 @@ import math from dataclasses import dataclass, field -from typing import Tuple, Union, Optional +from typing import Tuple, Union import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs -from .cache import MambaCache +from .cache import Mamba2Cache @dataclass class ModelArgs(BaseModelArgs): @@ -56,186 +56,217 @@ class MambaRMSNormGated(nn.Module): variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True) hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states - + def silu(x): return x * mx.sigmoid(x) - def ssd(x, A, B, C, chunk_size): - # Not getting used + # Replace einsum operations with explicit reshape and matrix multiply batch, seqlen, nheads, dim = x.shape B = mx.expand_dims(B, axis=2) C = mx.expand_dims(C, axis=2) - + state = mx.zeros((batch, nheads, dim, B.shape[-1])) outputs = [] - + for i in range(0, seqlen, chunk_size): chunk = slice(i, min(i + chunk_size, seqlen)) dA = mx.exp(mx.expand_dims(A[chunk], axis=0)) - + # Replace einsum with explicit operations x_chunk = x[:, chunk] # [batch, chunk_size, nheads, dim] x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1]) # [batch, nheads, dim, chunk_size] B_chunk = B[:, chunk] # [batch, chunk_size, state_size] dBx = mx.matmul(x_chunk, B_chunk) # [batch, nheads, dim, state_size] - + state = state * mx.expand_dims(dA, axis=-1) + dBx - + # Replace einsum with explicit operations C_chunk = C[:, chunk] # [batch, chunk_size, state_size] y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1])) # [batch, nheads, dim, chunk_size] y = mx.transpose(y, [0, 3, 1, 2]) # [batch, chunk_size, nheads, dim] outputs.append(y) - + return mx.concatenate(outputs, axis=1), state +class DepthWiseConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.padding = padding + self.groups = groups if groups is not None else in_channels + + assert in_channels == 
out_channels, "In and out channels must be same for depthwise convolution" + assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution" + + self.weight = mx.random.normal((in_channels, 1, kernel_size)) + self.bias = mx.zeros((out_channels,)) if bias else None + + def __call__(self, x: mx.array, cache=None) -> mx.array: + B, L, C = x.shape + K = self.kernel_size + + assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}" + + if cache is not None: + if isinstance(cache.conv_states[0], type(None)): + cache.conv_states[0] = mx.zeros((B, K-1, C)) + + x = mx.concatenate([cache.conv_states[0], x], axis=1) + + outputs = [] + for c in range(C): + # Input prep debug + x_c = x[:, :, c] + x_c = mx.expand_dims(x_c, axis=1) + + # Weight prep debug + w_c = self.weight[c] + if w_c.ndim == 2: + w_c = mx.expand_dims(w_c, axis=0) + elif w_c.ndim == 1: + w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0) + + y_c = mx.conv_general( + x_c, + w_c, + stride=1, + padding=0 + ) + if self.bias is not None: + y_c = y_c + self.bias[c] + + y_c = mx.squeeze(y_c, axis=1) + outputs.append(y_c) + + # Output statistics + y = mx.stack(outputs, axis=-1) + + # Cache update debug + if cache is not None: + cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x + + return y + + class Mamba2Block(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args - self.chunk_size = args.chunk_size d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias) - - self.conv_dim = args.intermediate_size + 2 * args.state_size - - # Replace DepthWiseConv1d with grouped nn.Conv1d - self.conv1d = nn.Conv1d( - in_channels=self.conv_dim, - out_channels=self.conv_dim, + + conv_dim = args.intermediate_size + 2 * args.state_size + self.conv1d = DepthWiseConv1d( + in_channels=conv_dim, + out_channels=conv_dim, kernel_size=args.conv_kernel, - groups=self.conv_dim, # Makes it depthwise + groups=conv_dim, bias=args.use_conv_bias, - padding=0 # We'll handle padding via cache + padding=args.conv_kernel - 1 ) - - self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range - self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range - self.D = mx.random.normal((args.num_heads,)) * args.initializer_range - + + # Initialize parameters + self.dt_bias = mx.ones(args.num_heads) + A = mx.arange(1, args.num_heads + 1) + self.A_log = mx.log(A) + self.D = mx.ones(args.num_heads) + self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon) self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) - + if args.rescale_prenorm_residual: layer_scale = math.sqrt(1.0 / args.num_hidden_layers) self.out_proj.weight = self.out_proj.weight * layer_scale - - def __call__(self, u: mx.array, cache: Optional[MambaCache] = None): - batch_size, seq_len, _ = u.shape - pad_size = self.chunk_size - (seq_len % self.chunk_size) - + + def __call__(self, u: mx.array, cache=None): + # Expect input to be shape [batch_size, 1, dim] + batch_size, seq_len, dimension = u.shape + assert seq_len == 1, "Input should be a single token" + # Initialize cache if needed - if cache is None: - cache = MambaCache() - - # Initialize states if needed - if cache[0] is None: # conv state - cache[0] = mx.zeros(( - batch_size, - self.args.conv_kernel - 1, - self.conv_dim - )) - - if cache[1] is None: # ssm state - cache[1] 
= mx.zeros(( + if cache.conv_states[0] is None: + conv_dim = self.args.intermediate_size + 2 * self.args.state_size + cache.conv_states[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim)) + + if cache.ssm_states[0] is None: + cache.ssm_states[0] = mx.zeros(( batch_size, self.args.num_heads, self.args.head_dim, self.args.state_size )) - + # Project input zxbcdt = self.in_proj(u) # Split projections + n_heads = self.args.num_heads z = zxbcdt[:, :, :self.args.intermediate_size] xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size] - dt = zxbcdt[:, :, -(self.args.num_heads):] - - # Process delta time - dt = mx.reshape(dt, (batch_size, seq_len, self.args.num_heads)) - dt = mx.squeeze(dt, axis=0) - dt = mx.clip( - nn.softplus(dt + self.dt_bias), - self.args.time_step_min, - self.args.time_step_max - ) + dt = zxbcdt[:, :, -(n_heads):] + + # Time steps + dt = mx.reshape(dt, (batch_size, n_heads)) + dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max) dt = mx.maximum(dt, self.args.time_step_floor) - - # Handle convolution caching and padding - conv_state = cache[0] - if conv_state is not None: - xBC = mx.concatenate([conv_state, xBC], axis=1) - - # Prepare input for conv1d: [B, C, L] - xBC = mx.transpose(xBC, [0, 2, 1]) - - # Apply convolution - xBC = self.conv1d(xBC) - - # Update cache state - cache[0] = mx.transpose(xBC, [0, 2, 1])[:, -self.args.conv_kernel+1:, :] - - # Return to [B, L, C] format - xBC = mx.transpose(xBC, [0, 2, 1]) + + # Convolution + xBC = self.conv1d(xBC, cache=cache) xBC = silu(xBC) - - # Split conv output + + # Split states x = xBC[:, :, :self.args.intermediate_size] B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size] C = xBC[:, :, -self.args.state_size:] - + # Reshape for SSM - x = mx.reshape(x, (batch_size, seq_len, self.args.num_heads, self.args.head_dim)) - - B = mx.reshape(B, (batch_size, seq_len, self.args.state_size)) - B = mx.broadcast_to(B, (batch_size, self.args.num_heads, self.args.state_size)) - - C = mx.reshape(C, (batch_size, seq_len, self.args.state_size)) - C = mx.broadcast_to(C, (batch_size, self.args.num_heads, self.args.state_size)) - - # SSM state update - ssm_state = cache[1] + x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim)) + x = mx.squeeze(x, axis=1) + B = mx.reshape(B, (batch_size, 1, self.args.state_size)) + B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size)) + B = mx.expand_dims(B, axis=2) + C = mx.reshape(C, (batch_size, 1, self.args.state_size)) + C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size)) + C = mx.expand_dims(C, axis=3) + + # SSM updates A = -mx.exp(self.A_log) dA = mx.exp(dt * mx.expand_dims(A, 0)) - - x = mx.expand_dims(x, axis=-1) - dBx = mx.matmul(x, mx.expand_dims(B, axis=-2)) - - new_ssm_state = ssm_state * mx.expand_dims(dA, -1) + dBx - cache[1] = new_ssm_state - - # Output computation - y = mx.matmul(new_ssm_state, mx.expand_dims(C, axis=-1)) - y = mx.squeeze(y, axis=-1) - - if pad_size > 0: - y = y[:, :seq_len, :, :] - - # Final reshape and projections - y = mx.reshape(y, (batch_size, seq_len, -1)) - y = self.norm(y + z) - - return self.out_proj(y) + dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) + # Update state + x = mx.expand_dims(x, axis=3) + dBx = mx.matmul(x, B) + cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx + + # Compute output + y = mx.matmul(cache.ssm_states[0], C) + y = mx.squeeze(y, axis=-1) 
+ y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1) + y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim)) + y = self.norm(y + z) + + return self.out_proj(y) class ResidualBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.residual_in_fp32 = args.residual_in_fp32 - self.mixer = Mamba2Block(args) self.norm = nn.RMSNorm(args.hidden_size) def __call__(self, x: mx.array, cache): if self.residual_in_fp32: x = x.astype(mx.float32) - return self.mixer(self.norm(x), cache) + x - + normed = self.norm(x) + output = self.mixer(normed, cache) + return output + x class Mamba2(nn.Module): def __init__(self, args: ModelArgs): @@ -249,9 +280,11 @@ class Mamba2(nn.Module): x = self.embeddings(x) if cache is None: cache = [None] * len(self.layers) + + hidden = x for layer, c in zip(self.layers, cache): - x = layer(x, c) - return self.norm_f(x) + hidden = layer(hidden, c) + return self.norm_f(hidden) class Model(nn.Module): @@ -259,33 +292,24 @@ class Model(nn.Module): super().__init__() self.args = args self.model_type = args.model_type - self.backbone = Mamba2(args) if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) def __call__(self, inputs: mx.array, cache=None): - B, T = inputs.shape - - x = self.backbone(inputs, cache) - + hidden = self.backbone(inputs, cache) + if self.args.tie_word_embeddings: - logits = self.backbone.embeddings.as_linear(x) + logits = self.backbone.embeddings.as_linear(hidden) else: - logits = self.lm_head(x) - + logits = self.lm_head(hidden) + return logits - def make_cache(self, batch_size=1): - return [MambaCache() for _ in range(len(self.backbone.layers))] - - def sanitize(self, weights): - for k, v in weights.items(): - if "conv1d.weight" in k and v.shape[-1] != 1: - weights[k] = v.moveaxis(2, 1) - return weights + def make_cache(self): + return [Mamba2Cache() for _ in range(len(self.layers))] @property def layers(self): - return self.backbone.layers + return self.backbone.layers \ No newline at end of file
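
As a rough illustration of the state that the new Mamba2Cache carries between calls, the sketch below mirrors the single-token update path in Mamba2Block.__call__: cache.conv_states[0] keeps the last conv_kernel - 1 input frames for the causal depthwise convolution, and cache.ssm_states[0] keeps one SSM state tensor per layer. All sizes are made up for the example and the projected inputs are random stand-ins; nothing here is taken from ModelArgs or the diff itself.

import mlx.core as mx

# Hypothetical sizes, for illustration only
batch, n_heads, head_dim, state_size = 1, 2, 4, 16
conv_channels, kernel = 8, 4

# What Mamba2Cache holds for one layer once it has been initialized
conv_state = mx.zeros((batch, kernel - 1, conv_channels))      # cache.conv_states[0]
ssm_state = mx.zeros((batch, n_heads, head_dim, state_size))   # cache.ssm_states[0]

# Random stand-ins for the per-token quantities derived from in_proj / conv1d
x_t = mx.random.normal((batch, 1, conv_channels))              # new conv input frame
x = mx.random.normal((batch, n_heads, head_dim))               # SSM input
B = mx.random.normal((batch, n_heads, 1, state_size))          # input matrix
C = mx.random.normal((batch, n_heads, state_size, 1))          # output matrix
dt = mx.random.uniform(shape=(batch, n_heads))                 # per-head time step
A = -mx.arange(1, n_heads + 1).astype(mx.float32)              # negative decay rates

# 1) Convolution cache: keep the last kernel-1 frames so the next token can be
#    convolved causally without re-running the prompt.
window = mx.concatenate([conv_state, x_t], axis=1)             # [batch, kernel, conv_channels]
conv_state = window[:, -(kernel - 1):, :]

# 2) SSM recurrence: state <- state * exp(dt * A) + outer(x, B); y = state @ C
dA = mx.exp(dt * mx.expand_dims(A, 0))                         # [batch, n_heads]
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)                # [batch, n_heads, 1, 1]
dBx = mx.matmul(mx.expand_dims(x, axis=3), B)                  # [batch, n_heads, head_dim, state_size]
ssm_state = ssm_state * dA + dBx
y = mx.squeeze(mx.matmul(ssm_state, C), axis=-1)               # [batch, n_heads, head_dim]

Because both states are mutated in place on every decoded token, make_cache() only needs to return one empty Mamba2Cache per layer; the shapes are filled in lazily on the first call.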