diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py
index e8b222f1..7d4a439e 100644
--- a/llms/mlx_lm/models/plamo2.py
+++ b/llms/mlx_lm/models/plamo2.py
@@ -194,11 +194,16 @@ class PlamoAttentionCache(nn.Module):
         assert value.shape[2] == L
         self.key = key
         self.value = value
-
+
     @property
     def state(self) -> tuple[mx.array, mx.array]:
         return self.key, self.value
 
+    @state.setter
+    def state(self, new_state: tuple[mx.array, mx.array]) -> None:
+        # Property setters receive a single value; unpack the (key, value) pair.
+        self.key, self.value = new_state
+
 
 class PlamoMambaCache(nn.Module):
     def __init__(self, conv_state: mx.array, ssm_state: mx.array) -> None:
@@ -215,6 +220,12 @@ class PlamoMambaCache(nn.Module):
     def state(self) -> tuple[mx.array, mx.array]:
         return self.conv_state, self.ssm_state
 
+    @state.setter
+    def state(self, new_state: tuple[mx.array, mx.array]) -> None:
+        # Unpack the (conv_state, ssm_state) pair assigned via `cache.state = ...`.
+        self.conv_state, self.ssm_state = new_state
+
+
 PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache
 
 
@@ -293,6 +304,18 @@ class PlamoCache(nn.Module):
         layer_cache = self.cache[layer_idx]
         return layer_cache  # type: ignore
 
+    def __iter__(self):
+        self._counter = 0
+        return self
+
+    def __next__(self):
+        if self._counter < len(self.cache):
+            layer_cache = self.cache[self._counter]
+            self._counter += 1
+            return layer_cache
+        else:
+            raise StopIteration
+
     @property
     def state(self):
         return self.cache
@@ -1597,7 +1620,7 @@ class Model(PlamoPreTrainedModel):
 
         if not config.tie_word_embeddings:
             self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
-
+
         # Initialize weights and apply final processing
         # self.post_init()
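
A minimal usage sketch of the new cache API (the import path, shapes, and variable names below are illustrative assumptions, not taken from this diff): the tuple-valued `state` setter mirrors the existing getter, and `PlamoCache` can now be traversed with a plain `for` loop.

```python
import mlx.core as mx

from mlx_lm.models.plamo2 import PlamoMambaCache  # assumed import path

# Hypothetical state shapes, purely for illustration.
conv_state = mx.zeros((1, 4, 64))
ssm_state = mx.zeros((1, 16, 64))

cache = PlamoMambaCache(conv_state, ssm_state)

# Getter and setter are now symmetric: both work with a (conv_state, ssm_state) tuple.
cache.state = (mx.ones_like(conv_state), mx.ones_like(ssm_state))
conv, ssm = cache.state

# With __iter__/__next__ in place, a PlamoCache instance (here `plamo_cache`,
# construction omitted) can be iterated layer by layer:
#     for layer_cache in plamo_cache:
#         ...
```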