From 10adfa76bfae8dfb3fd1372ec25f09b085eb5461 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 12 Mar 2025 17:33:53 +0100 Subject: [PATCH] multible ssd step frunctions --- llms/mlx_lm/models/mamba2.py | 215 +++++++++++++++++++++++++++++++---- 1 file changed, 194 insertions(+), 21 deletions(-) diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index 521d3e8d..419bb2db 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -1,5 +1,5 @@ import math -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -42,10 +42,6 @@ class ModelArgs(BaseModelArgs): self.time_step_rank = math.ceil(self.hidden_size / 16) -def segsum(x): - return mx.cumsum(x, axis=-1).reshape(*x.shape[:-1], 1, x.shape[-1]) - - class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias=True, padding=0): super().__init__() @@ -69,6 +65,10 @@ class DepthWiseConv1d(nn.Module): return y, x[:, -K + 1:, :] +def segsum(x): + return mx.cumsum(x, axis=-1).reshape(*x.shape[:-1], 1, x.shape[-1]) + + def ssd_forward_attn( x: mx.array, dt: mx.array, @@ -137,6 +137,166 @@ def ssd_forward_attn( return y, next_state +def ssd(x, A, B, C, chunk_size, initial_states=None): + """Structured State Space Duality (SSD) - the core of Mamba-2 + + Arguments + x: (batch, seqlen, n_heads, d_head) + A: (batch, seqlen, n_heads) + B: (batch, seqlen, n_heads, d_state) + C: (batch, seqlen, n_heads, d_state) + + Return (y, final_state) + y: (batch, seqlen, n_heads, d_head) + final_state: final state for next inference step + """ + # Verify sequence length is divisible by chunk_size + b, seqlen, h, dh = x.shape + assert seqlen % chunk_size == 0 + + # Rearrange into chunks + num_chunks = seqlen // chunk_size + x_chunks = x.reshape(b, num_chunks, chunk_size, h, dh) + A_chunks = A.reshape(b, num_chunks, chunk_size, h) + B_chunks = B.reshape(b, num_chunks, chunk_size, -1, B.shape[-1]) # Account for groups + C_chunks = C.reshape(b, num_chunks, chunk_size, -1, C.shape[-1]) + + # Transpose A for correct cumsum operation + A_chunks = mx.transpose(A_chunks, (0, 3, 1, 2)) # b h c l + A_cumsum = mx.cumsum(A_chunks, axis=-1) + + # 1. Compute output for each intra-chunk (diagonal blocks) + L = mx.exp(segsum(A_chunks)) + + # Handle the dimensions for einsum + # "bclhn, bcshn, bhcls, bcshp -> bclhp" + C_expanded = mx.expand_dims(C_chunks, axis=3) # b c l 1 h n + B_expanded = mx.expand_dims(B_chunks, axis=2) # b c 1 s h n + L_reshaped = mx.transpose(L, (0, 2, 3, 1, 4)) # b h c l s -> b c l h s + x_reshaped = mx.transpose(x_chunks, (0, 1, 2, 3, 4)) # b c l h p + + # Perform the computation using manual broadcasting and reductions + # This is a manual implementation of the einsum from PyTorch + BC = mx.matmul(mx.transpose(C_expanded, (0, 1, 2, 4, 3)), + mx.transpose(B_expanded, (0, 1, 3, 4, 2))) # b c l n n + L_x = mx.matmul(mx.transpose(L_reshaped, (0, 1, 2, 4, 3)), + mx.reshape(x_reshaped, (b, num_chunks, chunk_size, dh, 1))) # b c l s 1 + Y_diag = mx.matmul(BC, L_x) # b c l h dh + Y_diag = mx.reshape(Y_diag, (b, num_chunks, chunk_size, h, dh)) + + # 2. Compute state for each intra-chunk + decay_states = mx.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + + # Compute states using matrix multiplications (replacing einsum) + # "bclhn, bhcl, bclhp -> bchpn" + B_decay = mx.matmul(B_chunks, + mx.reshape(decay_states, (b, h, num_chunks, chunk_size, 1))) + states = mx.matmul(B_decay, + mx.reshape(x_chunks, (b, num_chunks, chunk_size, h, dh, 1))) + states = mx.reshape(states, (b, num_chunks, h, dh, -1)) # b c h p n + + # 3. Compute inter-chunk recurrence + if initial_states is None: + initial_states = mx.zeros((b, 1, h, dh, B.shape[-1])) + + states = mx.concatenate([initial_states, states], axis=1) + + # Create padded A_cumsum for decay calculation + A_cumsum_last = A_cumsum[:, :, :, -1] + padded_A_cumsum = mx.pad(A_cumsum_last, [(0, 0), (0, 0), (1, 0)]) + decay_chunk = mx.exp(segsum(padded_A_cumsum)) + + # Compute new states (replacing einsum "bhzc, bchpn -> bzhpn") + decay_chunk_expanded = mx.reshape(decay_chunk, (b, h, -1, num_chunks+1, 1, 1)) + states_expanded = mx.reshape(states, (b, 1, num_chunks+1, h, dh, -1)) + new_states = decay_chunk_expanded * states_expanded + new_states = mx.sum(new_states, axis=2) + + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + state_decay_out = mx.exp(A_cumsum) + + # Compute Y_off (replacing einsum "bclhn, bchpn, bhcl -> bclhp") + state_decay_expanded = mx.reshape(state_decay_out, (b, h, num_chunks, chunk_size, 1)) + states_reshaped = mx.reshape(states, (b, num_chunks, h, dh, -1)) + C_states = mx.matmul(mx.transpose(C_chunks, (0, 1, 2, 4, 3)), + mx.transpose(states_reshaped, (0, 1, 3, 2, 4))) + Y_off = C_states * state_decay_expanded + Y_off = mx.sum(Y_off, axis=-1) + Y_off = mx.reshape(Y_off, (b, num_chunks, chunk_size, h, dh)) + + # Add diagonal and off-diagonal contributions + Y = Y_diag + Y_off + Y = mx.reshape(Y, (b, seqlen, h, dh)) + + return Y, final_state + +def ssd_inference_step(x, A, B, C, prev_state=None): + """Simple inference step for Mamba-2 + + Works with: + - x: (batch, seqlen, n_heads, d_head) + - A: (n_heads,) - scalar values + - B: (batch, seqlen, n_groups, d_state) + - C: (batch, seqlen, n_groups, d_state) + """ + # Extract dimensions + b, seqlen, h, dh = x.shape + _, _, g, d_state = B.shape + + # Compute decay factor + dA = mx.exp(A) # (n_heads,) + + # Output container + outputs = [] + + # Final state to return + final_state = prev_state + + # For each position in the sequence + for t in range(seqlen): + # Get current values + xt = x[:, t] # (batch, n_heads, d_head) + Bt = B[:, t] # (batch, n_groups, d_state) + Ct = C[:, t] # (batch, n_groups, d_state) + + # Handle groups vs heads if they differ + if g < h: + repeat_factor = h // g + Bt = mx.repeat(Bt, repeat_factor, axis=1) # (batch, n_heads, d_state) + Ct = mx.repeat(Ct, repeat_factor, axis=1) # (batch, n_heads, d_state) + + # Reshape for matrix operations + xt = mx.reshape(xt, (b, h, dh, 1)) + Bt = mx.reshape(Bt, (b, h, 1, d_state)) + + # Compute B·x + dBx = mx.matmul(xt, Bt) # (batch, n_heads, d_head, d_state) + + # Update state + if final_state is not None: + dA_expanded = mx.reshape(dA, (1, h, 1, 1)) + new_state = final_state * dA_expanded + dBx + else: + new_state = dBx + + # Compute output + Ct = mx.reshape(Ct, (b, h, d_state, 1)) + yt = mx.matmul(new_state, Ct) # (batch, n_heads, d_head, 1) + yt = mx.reshape(yt, (b, h, dh)) + + # Add to outputs + outputs.append(mx.expand_dims(yt, 1)) + + # Update state for next position + final_state = new_state + + # Combine all outputs + y = mx.concatenate(outputs, axis=1) + + return y, final_state + class Mamba2Block(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -170,57 +330,70 @@ class Mamba2Block(nn.Module): def __call__(self, u: mx.array, cache=None): batch_size, seq_len, _ = u.shape - if cache is None: cache = [None, None] else: conv_state, ssm_state = cache - + zxBCdt = self.in_proj(u) - + + # Split the projection into components z, xBC, dt = mx.split( zxBCdt, - [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state], + [self.d_inner, 2*self.d_inner + 2*self.n_groups*self.d_state], axis=-1 ) - + + # Apply convolution and gating xBC, conv_state = self.conv1d(xBC, conv_state) xBC = xBC * mx.sigmoid(xBC) xBC = xBC[:, :seq_len, :] - + + # Split into the various components x, B, C = mx.split( - xBC, - [self.d_inner, self.d_inner + self.d_state * self.n_groups], + xBC, + [self.d_inner, self.d_inner + self.d_state*self.n_groups], axis=-1 ) - + + # Reshape for SSM computation x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head)) B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1)) C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1)) - - A = -mx.exp(self.A_log) + + # Process dt - similar to your ssd_forward_attn function + dt = mx.reshape(dt, (batch_size, seq_len, self.n_heads)) + dt = dt + self.dt_bias.reshape(1, 1, -1) # Apply bias + dt = nn.softplus(dt) # Ensure positive time steps + dt = mx.clip(dt, self.args.time_step_min, self.args.time_step_max) + + # For inference, we use ssd_forward_attn which you already know works y, next_ssm_state = ssd_forward_attn( x=x, dt=dt, - A=-mx.exp(self.A_log), + A=self.A_log, # Use A_log directly, the function will process it B=B, C=C, D=self.D, - dt_bias=self.dt_bias, + dt_bias=None, # We already applied dt_bias above dt_min=self.args.time_step_min, dt_max=self.args.time_step_max, prev_state=ssm_state ) - + + # Reshape output + y = mx.reshape(y, (batch_size, seq_len, self.d_inner)) + + # Apply normalization and gating if self.args.norm_before_gate: y = self.norm(y) y = y * nn.silu(z) else: y = y * nn.silu(z) y = self.norm(y) - + y = self.out_proj(y) - + cache[0] = conv_state cache[1] = next_ssm_state return y