mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 10:41:18 +08:00
notes
This commit is contained in:
parent
e43a2ab229
commit
9ab581d678
@ -310,6 +310,7 @@ class Mamba2Block(nn.Module):
|
|||||||
))
|
))
|
||||||
|
|
||||||
for pos in range(seq_len):
|
for pos in range(seq_len):
|
||||||
|
# Getting stuck here in last position, also cache from pos 0 is the same.
|
||||||
# Get single token
|
# Get single token
|
||||||
u_t = u[:, pos:pos+1, :]
|
u_t = u[:, pos:pos+1, :]
|
||||||
|
|
||||||
@ -443,7 +444,7 @@ class Model(nn.Module):
|
|||||||
return logits
|
return logits
|
||||||
|
|
||||||
def make_cache(self, batch_size=1):
|
def make_cache(self, batch_size=1):
|
||||||
return [Mamba2Cache(batch_size, self.args.num_heads, self.args.head_dim, self.args.state_size) for _ in range(len(self.layers))]
|
return [Mamba2Cache() for _ in range(len(self.layers))]
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
sanitized = {}
|
sanitized = {}
|
||||||
|
Loading…
Reference in New Issue
Block a user