mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 10:41:18 +08:00
notes
This commit is contained in:
parent
e43a2ab229
commit
9ab581d678
@ -310,6 +310,7 @@ class Mamba2Block(nn.Module):
|
||||
))
|
||||
|
||||
for pos in range(seq_len):
|
||||
# Getting stuck here in last position, also cache from pos 0 is the same.
|
||||
# Get single token
|
||||
u_t = u[:, pos:pos+1, :]
|
||||
|
||||
@ -443,7 +444,7 @@ class Model(nn.Module):
|
||||
return logits
|
||||
|
||||
def make_cache(self, batch_size=1):
|
||||
return [Mamba2Cache(batch_size, self.args.num_heads, self.args.head_dim, self.args.state_size) for _ in range(len(self.layers))]
|
||||
return [Mamba2Cache() for _ in range(len(self.layers))]
|
||||
|
||||
def sanitize(self, weights):
|
||||
sanitized = {}
|
||||
|
Loading…
Reference in New Issue
Block a user