mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 15:08:37 +08:00
Add __iter__ and __next__ methods to PlamoCache
This commit is contained in:
parent
e2d9d619c4
commit
fb1559e1f3
@ -194,11 +194,16 @@ class PlamoAttentionCache(nn.Module):
|
||||
assert value.shape[2] == L
|
||||
self.key = key
|
||||
self.value = value
|
||||
|
||||
|
||||
@property
def state(self) -> tuple[mx.array, mx.array]:
    """Return the cached tensors as a ``(key, value)`` pair."""
    return self.key, self.value

@state.setter
def state(
    self,
    key: mx.array | tuple[mx.array, mx.array],
    value: mx.array | None = None,
) -> None:
    """Replace the cached key and value tensors.

    Attribute assignment hands a property setter exactly one object, so
    ``cache.state = (key, value)`` arrives with the whole pair in ``key``
    and ``value`` left at its default. The explicit two-argument form is
    still accepted for callers that invoke the setter function directly.

    Args:
        key: Either the new key tensor, or a ``(key, value)`` pair when
            ``value`` is omitted.
        value: The new value tensor, or ``None`` when ``key`` carries the pair.
    """
    if value is None:
        # Assignment path: unpack the pair delivered as a single object.
        key, value = key
    self.key = key
    self.value = value
|
||||
|
||||
|
||||
class PlamoMambaCache(nn.Module):
|
||||
def __init__(self, conv_state: mx.array, ssm_state: mx.array) -> None:
|
||||
@ -215,6 +220,12 @@ class PlamoMambaCache(nn.Module):
|
||||
def state(self) -> tuple[mx.array, mx.array]:
|
||||
return self.conv_state, self.ssm_state
|
||||
|
||||
@state.setter
def state(
    self,
    conv_state: mx.array | tuple[mx.array, mx.array],
    ssm_state: mx.array | None = None,
) -> None:
    """Replace the stored convolution and SSM state arrays.

    Attribute assignment hands a property setter exactly one object, so
    ``cache.state = (conv_state, ssm_state)`` arrives with the whole pair in
    ``conv_state`` and ``ssm_state`` left at its default. The explicit
    two-argument form is still accepted for callers that invoke the setter
    function directly.

    Args:
        conv_state: Either the new convolution state, or a
            ``(conv_state, ssm_state)`` pair when ``ssm_state`` is omitted.
        ssm_state: The new SSM state, or ``None`` when ``conv_state`` carries
            the pair.
    """
    if ssm_state is None:
        # Assignment path: unpack the pair delivered as a single object.
        conv_state, ssm_state = conv_state
    self.conv_state = conv_state
    self.ssm_state = ssm_state
|
||||
|
||||
|
||||
# Type alias: a layer cache is either a PlamoAttentionCache or a PlamoMambaCache.
PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache
|
||||
|
||||
|
||||
@ -293,6 +304,18 @@ class PlamoCache(nn.Module):
|
||||
layer_cache = self.cache[layer_idx]
|
||||
return layer_cache # type: ignore
|
||||
|
||||
def __iter__(self):
    """Return a fresh iterator over the entries of ``self.cache``.

    Each call yields an independent iterator, so nested or overlapping loops
    over the same cache object no longer share (and corrupt) one cursor.
    Assumes ``self.cache`` is an integer-indexed sequence, as its use with
    ``len()`` and positional indexing in ``__next__`` suggests.
    """
    # Still reset the legacy cursor so existing ``iter(c); next(c)`` callers
    # keep starting from the first entry.
    self._counter = 0
    return iter(self.cache)

def __next__(self):
    """Return the next entry using the cursor stored on the object.

    Kept for callers that call ``next(cache)`` directly; ``for`` loops use the
    independent iterator returned by ``__iter__`` instead.

    Raises:
        StopIteration: When the cursor has moved past the last entry.
    """
    # Start from the beginning if ``iter()`` was never called on this object.
    index = getattr(self, "_counter", 0)
    if index >= len(self.cache):
        raise StopIteration
    self._counter = index + 1
    return self.cache[index]
|
||||
|
||||
@property
def state(self):
    """Expose the container of per-layer caches held in ``self.cache``."""
    layer_caches = self.cache
    return layer_caches
|
||||
@ -1597,7 +1620,7 @@ class Model(PlamoPreTrainedModel):
|
||||
|
||||
if not config.tie_word_embeddings:
|
||||
self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
|
||||
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
# self.post_init()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user