diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index c26e2925..eb4ae8b3 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -33,8 +33,7 @@ class ModelArgs(BaseModelArgs):
     time_step_min: float
     time_step_max: float
     time_step_floor: float
-    A_init_min: float = 1.0
-    A_init_max: float = 16.0
+    norm_before_gate: bool = True
 
     def __post_init__(self):
         if not hasattr(self, "intermediate_size"):
@@ -46,17 +45,29 @@ class ModelArgs(BaseModelArgs):
 
 
 class MambaRMSNormGated(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps=1e-6, norm_before_gate=False):
         super().__init__()
         self.weight = mx.ones((hidden_size,))
         self.variance_epsilon = eps
-
-    def __call__(self, hidden_states, gate=None):
-        if gate is not None:
-            hidden_states = hidden_states * nn.silu(gate)
-        variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
-        hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states
+        self.norm_before_gate = norm_before_gate
+
+    def rms_norm(self, x):
+        variance = mx.mean(x ** 2, axis=-1, keepdims=True)
+        x = x * mx.rsqrt(variance + self.variance_epsilon)
+        return self.weight * x
+
+    def __call__(self, x, z=None):
+        if z is None:
+            return self.rms_norm(x)
+
+        if self.norm_before_gate:
+            x = self.rms_norm(x)
+            x = x * nn.silu(z)
+        else:
+            x = x * nn.silu(z)
+            x = self.rms_norm(x)
+
+        return x
 
 
 def silu(x):
@@ -86,12 +97,71 @@ class DepthWiseConv1d(nn.Module):
         return y, x[:, -K + 1:, :]
 
 
+def ssd_forward_attn(
+    x: mx.array,
+    dt: mx.array,
+    A: mx.array,
+    B: mx.array,
+    C: mx.array,
+    D: mx.array,
+    dt_bias: mx.array,
+    dt_min: float,
+    dt_max: float,
+) -> Tuple[mx.array, mx.array]:
+    b, l, h, dh = x.shape
+    _, _, g, _ = B.shape
+
+    if dt_bias is not None:
+        dt = dt + dt_bias.reshape(1, 1, -1)
+
+    dt = nn.softplus(dt)
+    dt = mx.clip(dt, a_min=dt_min, a_max=dt_max)
+
+    B = mx.swapaxes(mx.swapaxes(B, 1, 3), 1, 2)
+    C = mx.swapaxes(C, 1, 2)
+
+    CB = C @ B
+    CB = mx.repeat(CB, repeats=h // g, axis=1)
+
+    dtA = dt * A.reshape(1, 1, -1)
+    dtA = mx.swapaxes(dtA, 1, 2)
+
+    decay = mx.exp(segsum(dtA))
+
+    surrogate_attention_matrix = mx.tril(CB * decay, 0)
+
+    dtx = dt.reshape(b, l, h, 1) * x
+    y = surrogate_attention_matrix @ dtx.swapaxes(1, 2)
+    y = mx.swapaxes(y, 1, 2)
+
+    decay = decay[:, :, -1, :].reshape(b, h, l).swapaxes(1, 2).reshape(b, l, h, 1)
+    B = mx.repeat(B, h // g, axis=1).swapaxes(2, 3)
+    dtxdecay = dtx * decay
+    dtxdecay = dtxdecay.swapaxes(1, 2).swapaxes(2, 3)
+    next_state = dtxdecay @ B
+
+    if D is not None:
+        y += x * D.reshape(1, 1, h, 1)
+
+    y = y.reshape(b, l, h * dh)
+
+    return y, next_state
+
+
+def segsum(x):
+    l = x.shape[-1]
+    x = mx.repeat(x[..., None], l, axis=-1)
+    x = mx.tril(x, -1)
+    x_segsum = mx.cumsum(x, axis=-2)
+    return x_segsum
+
+
 class Mamba2Block(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.args = args
 
-        # Same dimensions as before
+        # Dimensions
         self.d_model = args.hidden_size
         self.d_state = args.state_size
         self.d_conv = args.conv_kernel
@@ -106,14 +176,12 @@ class Mamba2Block(nn.Module):
         d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
         self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
 
+        # Parameters
         self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
         self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
         self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
-
-        # Same D initialization
-        self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
 
-        # Convolution with proper initialization
+        # Convolution
         self.conv1d = DepthWiseConv1d(
             channels=self.d_inner + 2 * self.n_groups * self.d_state,
             kernel_size=self.d_conv,
@@ -122,7 +190,11 @@ class Mamba2Block(nn.Module):
         )
 
         # Output projections
-        self.norm = MambaRMSNormGated(self.d_inner, eps=args.layer_norm_epsilon)
+        self.norm = MambaRMSNormGated(
+            self.d_inner,
+            eps=args.layer_norm_epsilon,
+            norm_before_gate=args.norm_before_gate
+        )
         self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)
 
     def __call__(self, u: mx.array, cache=None):
@@ -131,103 +203,59 @@ class Mamba2Block(nn.Module):
             cache = [None, None]
 
         # Project input
-        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)
-        A = -mx.exp(self.A_log)  # (nheads) or (d_inner, d_state)
-
+        zxBCdt = self.in_proj(u)
+
+        # Split projections
         z, xBC, dt = mx.split(
-            zxbcdt,
-            indices_or_sections=[
-                self.d_inner,
-                self.d_inner + (2 * self.n_groups * self.d_state + self.d_inner)
-            ],
+            zxBCdt,
+            [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
             axis=-1
         )
 
-        # Process dt
-        dt = nn.softplus(dt + self.dt_bias)  # (B, L, nheads)
-
-        # Conv1d and activation
-        xBC, conv_state = self.conv1d(xBC, cache[0] if cache else None)
+        # Process convolution
+        xBC, conv_state = self.conv1d(xBC, cache[0])
         xBC = silu(xBC)
-
         if cache is not None:
             cache[0] = conv_state
-        xBC = xBC[:, :seq_len, :]
 
-        # Split conv output and reshape
+        # Split and reshape conv output
         x, B, C = mx.split(
-            xBC,
-            indices_or_sections=[
-                self.d_inner,
-                self.d_inner + self.n_groups * self.d_state
-            ],
+            xBC,
+            [self.d_inner, self.d_inner + self.d_state * self.n_groups],
             axis=-1
         )
 
-        # Reshape tensors
+        # Reshape for SSM processing
+        x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
         B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1))
         C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))
-        x = mx.reshape(x, (batch_size, seq_len, self.n_heads, -1))
 
-        # Initialize state
-        if cache and cache[1] is not None:
-            prev_state = cache[1]
-        else:
-            prev_state = mx.zeros((batch_size, self.n_heads, self.d_head, self.d_state))
+        # Get parameters for attention computation
+        A = -mx.exp(self.A_log)
 
-        # Compute dA
-        dt = mx.reshape(dt, (batch_size, seq_len, self.n_heads))
-        dA = mx.exp(dt * mx.expand_dims(A, axis=(0, 1)))
+        # Compute parallel attention
+        y, next_state = ssd_forward_attn(
+            x=x,
+            dt=dt,
+            A=A,
+            B=B,
+            C=C,
+            D=self.D,
+            dt_bias=self.dt_bias,
+            dt_min=self.args.time_step_min,
+            dt_max=self.args.time_step_max,
+        )
 
-        # Process sequence in chunks
-        chunk_size = self.chunk_size
-        outputs = []
-        next_state = prev_state
-
-        # Process in chunks
-        for chunk_start in range(0, seq_len, chunk_size):
-            chunk_end = min(chunk_start + chunk_size, seq_len)
-
-            # Get current chunk
-            x_chunk = x[:, chunk_start:chunk_end]
-            B_chunk = B[:, chunk_start:chunk_end]
-            C_chunk = C[:, chunk_start:chunk_end]
-            dA_chunk = dA[:, chunk_start:chunk_end]
-            z_chunk = z[:, chunk_start:chunk_end]
-
-            # Process the chunk in batches
-            chunk_outputs = []
-            chunk_state = next_state
-
-            for t in range(chunk_end - chunk_start):
-                xt = x_chunk[:, t]
-                Bt = B_chunk[:, t]
-                Ct = C_chunk[:, t]
-                dAt = dA_chunk[:, t]
-
-                # Update state
-                dBx = mx.einsum('bh,bgd,bhp->bhpd', dAt, Bt, xt)
-                chunk_state = chunk_state * mx.expand_dims(dAt, axis=(-1, -2)) + dBx
-
-                # Compute output
-                yt = mx.einsum('bhpd,bgd->bhp', chunk_state, Ct)
-                yt = yt + xt * mx.expand_dims(self.D, -1)
-
-                # Reshape and normalize
-                yt = mx.reshape(yt, (batch_size, 1, self.d_inner))
-                yt = self.norm(yt, z_chunk[:, t:t+1])
-                chunk_outputs.append(self.out_proj(yt))
-
-            # Update state for next chunk
-            next_state = chunk_state
-            outputs.extend(chunk_outputs)
-
-        # Update cache with final state
+        # Update cache
         if cache is not None:
             cache[1] = next_state
+
+        # Apply normalization and output projection
+        y = self.norm(y, z)
+        y = self.out_proj(y)
 
-        return mx.concatenate(outputs, axis=1)
+        return y
 
 
 class ResidualBlock(nn.Module):
@@ -238,8 +266,8 @@ class ResidualBlock(nn.Module):
         self.norm = nn.RMSNorm(args.hidden_size)
 
     def __call__(self, x: mx.array, cache):
-        if self.residual_in_fp32:
-            x = x.astype(mx.float32)
+        # if self.residual_in_fp32:
+        #     x = x.astype(mx.float32)
         normed = self.norm(x)
         output = self.mixer(normed, cache)
         return output + x
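
Note on the segsum-based decay (not part of the patch): `exp(segsum(.))` is what lets `ssd_forward_attn` replace the old per-token loop with a single masked matmul. `segsum(a)[..., i, j]` equals `a[j+1] + ... + a[i]`, so exponentiating it gives the product of per-step decay factors between positions j and i, i.e. the lower-triangular "decay" matrix of the SSD formulation. A minimal sketch checking this against a sequential scan on a toy scalar SSM, assuming only `mlx` is installed:

import mlx.core as mx

def segsum(x):
    # Mirrors the patch: out[..., i, j] = x[..., j+1] + ... + x[..., i] for i >= j.
    l = x.shape[-1]
    x = mx.repeat(x[..., None], l, axis=-1)
    x = mx.tril(x, -1)
    return mx.cumsum(x, axis=-2)

# Toy scalar SSM: h_t = exp(a_t) * h_{t-1} + b_t.
a = mx.array([-0.5, -1.0, -0.2, -0.8])
b = mx.array([1.0, 2.0, 3.0, 4.0])

L = mx.tril(mx.exp(segsum(a)))  # L[i, j] = prod_{k=j+1..i} exp(a_k)
h_parallel = L @ b

h, h_seq = 0.0, []
for t in range(4):
    h = mx.exp(a[t]).item() * h + b[t].item()
    h_seq.append(h)

print(h_parallel, h_seq)  # the two agree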
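The `norm_before_gate` flag controls whether the SiLU gate from the z branch is applied before or after the RMS normalization; the two orderings produce different numerics, which is presumably why the patch exposes it as a config field rather than hard-coding one. A usage sketch with hypothetical shapes:

import mlx.core as mx

x = mx.random.normal((2, 4, 8))  # (batch, length, d_inner)
z = mx.random.normal((2, 4, 8))  # gate branch of in_proj

pre = MambaRMSNormGated(8, norm_before_gate=True)    # rms_norm(x) * silu(z)
post = MambaRMSNormGated(8, norm_before_gate=False)  # rms_norm(x * silu(z))

print(pre(x, z).shape, post(x, z).shape)  # both (2, 4, 8)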
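For reference, the shape contract of `ssd_forward_attn` as traced from the code: a sketch with made-up sizes, where the head count must be a multiple of the group count so B and C can be broadcast across heads.

import mlx.core as mx

b, l, h, g, dh, n = 2, 16, 4, 2, 8, 16  # batch, length, heads, groups, head dim, state size

x = mx.random.normal((b, l, h, dh))
dt = mx.random.normal((b, l, h))
A = -mx.exp(mx.random.normal((h,)))     # negative per head, as built from A_log
B = mx.random.normal((b, l, g, n))
C = mx.random.normal((b, l, g, n))
D = mx.random.normal((h,))
dt_bias = mx.zeros((h,))

y, state = ssd_forward_attn(x, dt, A, B, C, D, dt_bias, dt_min=1e-3, dt_max=1e-1)
print(y.shape)      # (2, 16, 32)   -> (b, l, h * dh)
print(state.shape)  # (2, 4, 8, 16) -> (b, h, dh, n), cached for the next step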