This commit is contained in:
Goekdeniz-Guelmez 2024-10-22 22:10:53 +02:00
parent e43a2ab229
commit 9ab581d678

View File

@@ -310,6 +310,7 @@ class Mamba2Block(nn.Module):
)) ))
for pos in range(seq_len): for pos in range(seq_len):
# Getting stuck here in last position, also cache from pos 0 is the same.
# Get single token # Get single token
u_t = u[:, pos:pos+1, :] u_t = u[:, pos:pos+1, :]
@@ -443,7 +444,7 @@ class Model(nn.Module):
return logits return logits
def make_cache(self, batch_size=1): def make_cache(self, batch_size=1):
return [Mamba2Cache(batch_size, self.args.num_heads, self.args.head_dim, self.args.state_size) for _ in range(len(self.layers))] return [Mamba2Cache() for _ in range(len(self.layers))]
def sanitize(self, weights): def sanitize(self, weights):
sanitized = {} sanitized = {}