From 9ab581d6783416b3a66cd283b6fa4f1c606d0fcc Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 22 Oct 2024 22:10:53 +0200
Subject: [PATCH] notes

---
 llms/mlx_lm/models/mamba2.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 739cb400..f5a0e18a 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -310,6 +310,7 @@ class Mamba2Block(nn.Module):
             ))
 
         for pos in range(seq_len):
+            # Getting stuck here at the last position; the cache from pos 0 is also unchanged.
             # Get single token
             u_t = u[:, pos:pos+1, :]
 
@@ -443,7 +444,7 @@ class Model(nn.Module):
         return logits
 
     def make_cache(self, batch_size=1):
-        return [Mamba2Cache(batch_size, self.args.num_heads, self.args.head_dim, self.args.state_size) for _ in range(len(self.layers))]
+        return [Mamba2Cache() for _ in range(len(self.layers))]
 
     def sanitize(self, weights):
         sanitized = {}
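
Note (not part of the patch): the make_cache change drops every constructor argument, which only works if Mamba2Cache allocates its state lazily on first use rather than from batch/head/state sizes passed in up front. The sketch below shows that pattern under that assumption; the class name LazyStateCache, the attribute names conv_state/ssm_state, and the update_ssm method are hypothetical and are not taken from mamba2.py.

import mlx.core as mx

class LazyStateCache:
    # Hypothetical sketch of an argument-free cache: state tensors are
    # created on the first token instead of in __init__, so
    # make_cache() can call the constructor with no sizes.
    def __init__(self):
        self.conv_state = None  # rolling buffer for the causal conv (assumed name)
        self.ssm_state = None   # per-head recurrent SSM state (assumed name)

    def update_ssm(self, new_state: mx.array) -> mx.array:
        # Nothing to pre-allocate: the state takes the shape of whatever
        # the first scan step produces, then is overwritten at every
        # subsequent position of the per-token loop.
        self.ssm_state = new_state
        return self.ssm_state

If the per-position loop computes a new state but never writes it back into the cache object like this, the cache would indeed still hold the pos 0 state at the end of the sequence, which matches the symptom described in the added comment.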