diff --git a/llms/mlx_lm/models/helium.py b/llms/mlx_lm/models/helium.py
index 6ca46a72..993ce9d5 100644
--- a/llms/mlx_lm/models/helium.py
+++ b/llms/mlx_lm/models/helium.py
@@ -1,3 +1,5 @@
+# Copyright © 2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple
 
diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index f8db469c..e3877e19 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -148,20 +148,15 @@ class MambaBlock(nn.Module):
         return y, new_state
 
     def _process_sequence(self, x, conv_cache, state_cache):
-        """Process a sequence of inputs with cached states"""
         B, T, D = x.shape
-        # Project all tokens at once
         xz = self.in_proj(x.reshape(-1, D)).reshape(B, T, -1)
-        x_t, z_t = xz.split(indices_or_sections=2, axis=-1)  # Fixed: using split instead of chunk
+        x_t, z_t = xz.split(indices_or_sections=2, axis=-1)
 
-        # Handle convolution with cache
         conv_out, new_conv_cache = self.conv1d(x_t, conv_cache)
         x_t = nn.silu(conv_out)
 
-        # Pre-compute A matrix
         A = -mx.exp(self.A_log)
 
-        # Process sequence with state
         outputs = []
         current_state = state_cache
         for t in range(T):
@@ -174,17 +169,14 @@ class MambaBlock(nn.Module):
 
     def __call__(self, x, cache):
         if cache is None or isinstance(cache, list):
-            # Handle legacy cache format
             conv_cache, state_cache = cache if cache is not None else (None, None)
         else:
-            # Handle MambaCache object
             conv_cache, state_cache = cache.state
 
         output, (new_conv_cache, new_state_cache) = self._process_sequence(
             x, conv_cache, state_cache
        )
 
-        # Update cache
         if isinstance(cache, MambaCache):
             cache[0] = new_conv_cache
             cache[1] = new_state_cache
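
Note (not part of the diff): the sketch below illustrates the cache dispatch that `MambaBlock.__call__` performs in the second hunk, where a legacy `list` (or `None`) is unpacked directly while a cache object is read through its `.state` attribute and written back by index. `ToyCache` is a hypothetical stand-in for `MambaCache`, mirroring only the `.state` property and item assignment actually visible in the hunk.

# Illustrative sketch, assuming only the cache interface visible in the diff.
# ToyCache is a hypothetical stand-in for mlx_lm's MambaCache.

class ToyCache:
    def __init__(self):
        self._state = [None, None]  # [conv_cache, state_cache]

    @property
    def state(self):
        return self._state

    def __setitem__(self, idx, value):
        self._state[idx] = value


def unpack_cache(cache):
    # Mirrors the branch at the top of MambaBlock.__call__ above.
    if cache is None or isinstance(cache, list):
        return cache if cache is not None else (None, None)
    return cache.state


print(unpack_cache(None))        # (None, None)
print(unpack_cache([0, 1]))      # [0, 1]  (legacy list format)
print(unpack_cache(ToyCache()))  # [None, None]  (read through .state)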