From dfd51f16d6f0883f74cfbb1c49950d90903dd1c2 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Mon, 20 Jan 2025 18:59:16 +0100 Subject: [PATCH] Update MambaBlock, Batched Input Processing, Improved Cache Handling, Pre-computed Constants, Cleaner State Management, Explicit Return Values:. Before: 82.442 tokens-per-sec, after: 129.130 tokens-per-sec. --- llms/mlx_lm/models/mamba.py | 52 +++++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 5c09c999..f8db469c 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -147,28 +147,48 @@ class MambaBlock(nn.Module): y = y + D * x return y, new_state - def __call__(self, x, cache): + def _process_sequence(self, x, conv_cache, state_cache): + """Process a sequence of inputs with cached states""" B, T, D = x.shape + # Project all tokens at once + xz = self.in_proj(x.reshape(-1, D)).reshape(B, T, -1) + x_t, z_t = xz.split(indices_or_sections=2, axis=-1) # Fixed: using split instead of chunk + # Handle convolution with cache + conv_out, new_conv_cache = self.conv1d(x_t, conv_cache) + x_t = nn.silu(conv_out) + + # Pre-compute A matrix A = -mx.exp(self.A_log) - - if cache is None: - cache = [None, None] - + + # Process sequence with state outputs = [] + current_state = state_cache for t in range(T): - xt = x[:, t, :] - xz = self.in_proj(xt) - x_t, z_t = xz.split(indices_or_sections=2, axis=1) - conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0]) - x_t = conv_out.squeeze(1) - x_t = nn.silu(x_t) - y_t, cache[1] = self.ssm_step(x_t, A, cache[1]) - z_t = nn.silu(z_t) - output_t = y_t * z_t - output_t = self.out_proj(output_t) + y_t, current_state = self.ssm_step(x_t[:, t], A, current_state) + z_curr = nn.silu(z_t[:, t]) + output_t = self.out_proj(y_t * z_curr) outputs.append(output_t) - output = mx.stack(outputs, axis=1) + + return mx.stack(outputs, axis=1), (new_conv_cache, current_state) + + def __call__(self, x, cache): + if cache is None or isinstance(cache, list): + # Handle legacy cache format + conv_cache, state_cache = cache if cache is not None else (None, None) + else: + # Handle MambaCache object + conv_cache, state_cache = cache.state + + output, (new_conv_cache, new_state_cache) = self._process_sequence( + x, conv_cache, state_cache + ) + + # Update cache + if isinstance(cache, MambaCache): + cache[0] = new_conv_cache + cache[1] = new_state_cache + return output