diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index 739cb400..f5a0e18a 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -310,6 +310,7 @@ class Mamba2Block(nn.Module): )) for pos in range(seq_len): + # NOTE(debug): generation stalls at the last position, and the cache produced at pos 0 is identical to later ones — investigate. # Get single token u_t = u[:, pos:pos+1, :] @@ -443,7 +444,7 @@ class Model(nn.Module): return logits def make_cache(self, batch_size=1): - return [Mamba2Cache(batch_size, self.args.num_heads, self.args.head_dim, self.args.state_size) for _ in range(len(self.layers))] + return [Mamba2Cache() for _ in range(len(self.layers))] def sanitize(self, weights): sanitized = {}