diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index d1c93eba..b38d0203 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -342,7 +342,7 @@ class MambaCache(_BaseCache):
 
 class Mamba2Cache(_BaseCache):
     conv_states: Optional[mx.array] = None
-    ssm_states: Optional[mx.array] = None
+    ssm_state: Optional[mx.array] = None
 
     def __getitem__(self, idx: int) -> Optional[mx.array]:
         if idx == 0:
diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index f5a0e18a..01f9485b 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -103,89 +103,55 @@ class DepthWiseConv1d(nn.Module):
         assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
         assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
 
-        # Weight shape: (channels, 1, kernel_size) to match pretrained weights
         self.weight = mx.random.normal((in_channels, 1, kernel_size))
         self.bias = mx.zeros((out_channels,)) if bias else None
 
-    def __call__(self, x: mx.array, cache=None, cache_idx: int = 0) -> mx.array:
+    def __call__(self, x: mx.array, cache=None) -> mx.array:
         B, L, C = x.shape
         K = self.kernel_size
 
-        # Validate input dimensions
        assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
 
-        # Handle padding and caching
-        if cache is not None:
-            conv_states = cache[cache_idx]
+        if cache is not None and 'conv_states' in cache:
+            conv_states = cache['conv_states']
             if conv_states is not None:
-                # Validate cache shape
                 assert conv_states.shape[0] == B, "Cache batch size mismatch"
                 assert conv_states.shape[2] == C, "Cache channel count mismatch"
                 x = mx.concatenate([conv_states, x], axis=1)
-                L = x.shape[1]
-        else:
-            # Add left padding of size (kernel_size - 1)
-            pad_left = K - 1
-            x = mx.pad(x, [(0, 0), (pad_left, 0), (0, 0)])
-            L = x.shape[1]
-
-        # Pre-allocate output array if possible
-        outputs = []
-
-        # Process each channel independently
-        for c in range(C):
-            # Extract and prepare channel data
-            x_c = x[:, :, c]  # Shape: [B, L]
-            x_c = mx.expand_dims(x_c, axis=1)  # Shape: [B, 1, L]
-
-            # Prepare filter weights
-            w_c = self.weight[c]  # Get channel weights
-            # Ensure filter is 3D: [depth(1), in_channels(1), kernel_size]
+
+        # Process each channel independently
+        outputs = []
+        for c in range(C):
+            x_c = x[:, :, c]
+            x_c = mx.expand_dims(x_c, axis=1)
+
+            w_c = self.weight[c]
             if w_c.ndim == 2:
                 w_c = mx.expand_dims(w_c, axis=0)
             elif w_c.ndim == 1:
                 w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
 
-            # Handle inference mode (single token)
-            if L < K:
-                pad_size = K - L
-                x_c = mx.pad(x_c, [(0, 0), (0, 0), (pad_size, 0)])
+            # Apply convolution
+            y_c = mx.conv_general(
+                x_c,
+                w_c,
+                stride=1,
+                padding=0
+            )
 
-            # Apply 1D convolution
-            try:
-                y_c = mx.conv_general(
-                    x_c,
-                    w_c,
-                    stride=1,
-                    padding=0  # Padding already handled
-                )
-
-                if self.bias is not None:
-                    y_c = y_c + self.bias[c]
-
-                # Remove singleton dimension and add to outputs
-                outputs.append(mx.squeeze(y_c, axis=1))
-
-            except Exception as e:
-                raise RuntimeError(f"Convolution failed for channel {c}. Shapes: input={x_c.shape}, weight={w_c.shape}") from e
+            if self.bias is not None:
+                y_c = y_c + self.bias[c]
+
+            outputs.append(mx.squeeze(y_c, axis=1))
 
-        # Stack channel outputs along last dimension
-        y = mx.stack(outputs, axis=-1)  # Shape: [B, L', C]
+        y = mx.stack(outputs, axis=-1)
 
-        # Update cache if needed
+        # Update cache
         if cache is not None:
-            # Store last (kernel_size - 1) tokens or entire input if shorter
-            new_cache = x[:, -(K-1):, :] if L >= K else x
-            cache[cache_idx] = new_cache
-
-            if new_cache.shape != cache[cache_idx].shape:
-                cache[cache_idx] = new_cache
-                print(f"Cache updated at index {cache_idx}")
-            else:
-                print(f"Skipping cache update at index {cache_idx}, shapes are identical.")
-
+            cache['conv_states'] = x[:, -K+1:, :] if x.shape[1] >= K else x
+
         return y
-
+
 
 class Mamba2Block(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -217,34 +183,22 @@ class Mamba2Block(nn.Module):
             self.out_proj.weight = self.out_proj.weight * layer_scale
 
     def __call__(self, x: mx.array, cache=None):
-        # if cache is not None and self.args.use_cache:
         if cache is not None:
            return self.step(x, cache)
 
-        # Calculate sizes
+        # Regular forward pass code remains the same...
         d_model = self.args.intermediate_size
         d_state = self.args.state_size
         n_heads = self.args.num_heads
 
-        # Compute A
         A = -mx.exp(self.A_log)
-
-        # Project input
         zxbcdt = self.in_proj(x)
 
-        # Correct splits for z, xBC, dt
-        splits = [
-            d_model,  # z
-            d_model + 2 * d_state,  # xBC (delta, B, C concatenated)
-            n_heads  # dt
-        ]
-
-        # Split using cumulative indices
+        splits = [d_model, d_model + 2 * d_state, n_heads]
         z = zxbcdt[:, :, :splits[0]]
         xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]]
         dt = zxbcdt[:, :, -splits[2]:]
 
-        # Process dt
         dt = mx.clip(
             nn.softplus(dt + self.dt_bias),
             self.args.time_step_min,
@@ -252,46 +206,26 @@ class Mamba2Block(nn.Module):
         )
         dt = mx.maximum(dt, self.args.time_step_floor)
 
-        # Process convolution
         xBC = silu(self.conv1d(xBC))
 
-        # Split convolved xBC into x, B, C
         x = xBC[:, :, :d_model]
         B = xBC[:, :, d_model:d_model + d_state]
         C = xBC[:, :, -d_state:]
 
-        # Reshape for SSM computation
         b, l, hp = x.shape
         h = self.args.num_heads
         p = hp // h
         x = mx.reshape(x, (b, l, h, p))
 
-        # Compute SSM
-        y, ssm_state = ssd(
-            x * mx.expand_dims(dt, -1),
-            A * dt,
-            B,
-            C,
-            self.args.chunk_size
-        )
-
-        # Add skip connection
+        y, ssm_state = ssd(x * mx.expand_dims(dt, -1), A * dt, B, C, self.args.chunk_size)
         y = y + x * mx.expand_dims(self.D, -1)
-
-        # Reshape back
         y = mx.reshape(y, (b, l, h * p))
 
-        # Apply norm and projection
         y = self.norm(y + z)
         y = self.out_proj(y)
 
-        # Update cache if needed
-        if cache is not None and self.args.use_cache:
-            cache[1] = ssm_state
-
-        # Cast if needed
         if self.args.residual_in_fp32:
-            y.astype(mx.float32)
+            y = y.astype(mx.float32)
 
         return y
 
@@ -300,9 +234,17 @@ class Mamba2Block(nn.Module):
         seq_len = u.shape[1]
         outputs = []
 
-        # Initialize SSM state if needed
-        if cache[1] is None:
-            cache[1] = mx.zeros((
+        # Initialize cache if needed
+        if cache.conv_states is None:
+            conv_dim = self.args.intermediate_size + 2 * self.args.state_size
+            cache.conv_states = mx.zeros((
+                batch_size,
+                self.args.conv_kernel - 1,
+                conv_dim
+            ))
+
+        if cache.ssm_state is None:
+            cache.ssm_state = mx.zeros((
                 batch_size,
                 self.args.num_heads,
                 self.args.head_dim,
@@ -310,26 +252,17 @@ class Mamba2Block(nn.Module):
             ))
 
         for pos in range(seq_len):
-            # Getting stuck here in last position, also cache from pos 0 is the same.
-            # Get single token
             u_t = u[:, pos:pos+1, :]
-
-            # Project input
             zxbcdt = self.in_proj(u_t)
 
-            # Calculate sizes
             d_model = self.args.intermediate_size
             d_state = self.args.state_size
             n_heads = self.args.num_heads
-            d_head = self.args.head_dim
 
-            # Split projected input
-            # conv_dim = d_model + 2 * d_state (this should match self.conv1d.in_channels)
             z = zxbcdt[:, :, :d_model]
-            xBC = zxbcdt[:, :, d_model:d_model + 2*d_state + d_model]  # Include the full conv dimension
+            xBC = zxbcdt[:, :, d_model:d_model + 2*d_state + d_model]
            dt = zxbcdt[:, :, -(n_heads):]
 
-            # Process dt
             dt = mx.reshape(dt, (batch_size, n_heads))
             dt = mx.clip(
                 nn.softplus(dt + self.dt_bias),
@@ -338,49 +271,43 @@ class Mamba2Block(nn.Module):
             )
             dt = mx.maximum(dt, self.args.time_step_floor)
 
-            # Process convolution with correct dimensions
-            xBC = self.conv1d(xBC, cache=cache, cache_idx=0)
+            # Create a temporary cache dictionary for the convolution
+            conv_cache = {'conv_states': cache.conv_states}
+            xBC = self.conv1d(xBC, cache=conv_cache)
+            cache.conv_states = conv_cache['conv_states']
+
             xBC = silu(xBC)
 
-            # Split convolved xBC into x, B, C with correct dimensions
             x = xBC[:, :, :d_model]
             B = xBC[:, :, d_model:d_model + d_state]
             C = xBC[:, :, -d_state:]
 
-            # Reshape tensors for SSM computation
-            x = mx.reshape(x, (batch_size, 1, n_heads, d_head))
-            x = mx.squeeze(x, axis=1)  # (batch, heads, dim)
-
+            x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
+            x = mx.squeeze(x, axis=1)
+
             B = mx.reshape(B, (batch_size, 1, d_state))
             B = mx.broadcast_to(B, (batch_size, n_heads, d_state))
-            B = mx.expand_dims(B, axis=2)  # (batch, heads, 1, state)
-
+            B = mx.expand_dims(B, axis=2)
+
             C = mx.reshape(C, (batch_size, 1, d_state))
             C = mx.broadcast_to(C, (batch_size, n_heads, d_state))
-            C = mx.expand_dims(C, axis=3)  # (batch, heads, state, 1)
+            C = mx.expand_dims(C, axis=3)
 
-            # Compute SSM updates
             A = -mx.exp(self.A_log)
             dA = mx.exp(dt * mx.expand_dims(A, 0))
-            dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)  # (batch, heads, 1, 1)
+            dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
 
-            # Update state with proper shapes
-            x = mx.expand_dims(x, axis=3)  # (batch, heads, dim, 1)
-            dBx = mx.matmul(x, B)  # (batch, heads, dim, state)
+            x = mx.expand_dims(x, axis=3)
+            dBx = mx.matmul(x, B)
 
-            ssm_state = cache[1]
-            ssm_state = ssm_state * dA + dBx
-            cache[1] = ssm_state
+            cache.ssm_state = cache.ssm_state * dA + dBx
 
-            # Compute output
-            y = mx.matmul(ssm_state, C)  # (batch, heads, dim, 1)
-            y = mx.squeeze(y, axis=-1)  # (batch, heads, dim)
+            y = mx.matmul(cache.ssm_state, C)
+            y = mx.squeeze(y, axis=-1)
 
-            # Add skip connection
             y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
 
-            # Reshape and process output
-            y = mx.reshape(y, (batch_size, 1, n_heads * d_head))
+            y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
             y = self.norm(y + z)
             y = self.out_proj(y)
 
@@ -440,7 +367,6 @@ class Model(nn.Module):
         else:
             logits = self.lm_head(x)
 
-        print('ouput')
         return logits
 
     def make_cache(self, batch_size=1):