diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py
index b78c6530..c7624811 100644
--- a/llms/mlx_lm/models/plamo2.py
+++ b/llms/mlx_lm/models/plamo2.py
@@ -285,6 +285,14 @@ class PlamoCache(nn.Module):
         assert layer_idx < len(self.cache)
         layer_cache = self.cache[layer_idx]
         return layer_cache  # type: ignore
+
+    @property
+    def state(self):
+        return self.cache
+
+    @state.setter
+    def state(self, v):
+        self.cache = v
 
     def __len__(self) -> int:
         return len(self.cache)
@@ -295,7 +303,7 @@ class PlamoCache(nn.Module):
             assert isinstance(c, PlamoAttentionCache)
             return c.key.shape[2]  # type: ignore
 
-        sequence_length: int | None = None
+        sequence_length: int = 0
         for layer_cache in self.cache:
             if isinstance(layer_cache, PlamoAttentionCache):
                 sequence_length = (
@@ -303,7 +311,6 @@ class PlamoCache(nn.Module):
                     if sequence_length is not None
                     else layer_cache.key.shape[2]
                 )
-        assert sequence_length is not None
         return sequence_length
 
     def get_max_length(self) -> int | None:
@@ -1244,7 +1251,7 @@ class PlamoDecoder(nn.Module):
         all_self_attns: Optional[tuple[mx.array, ...]] = () if x.output_attentions else None
         hidden_states = x.hidden_states
 
-        for layer_i, decoder_layer in enumerate(self.layers):
+        for decoder_layer in self.layers:
            if x.output_hidden_states:
                 assert all_hidden_states is not None
                 all_hidden_states += (hidden_states,)
@@ -1590,8 +1597,6 @@ class Model(PlamoPreTrainedModel):
         if not config.tie_word_embeddings:
             self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
 
-        self._cache: Optional[PlamoCache] = None
-
         # Initialize weights and apply final processing
         # self.post_init()
@@ -1623,16 +1628,12 @@ class Model(PlamoPreTrainedModel):
         return PlamoCache(self.config)
 
     def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array:
-        if self._cache is not None:
-            plamo_cache = self._cache
-        else:
-            plamo_cache = None
-        output = self.forward(
+        model_inputs = self.prepare_inputs_for_generation(
             input_ids=inputs,
-            past_key_values=plamo_cache,
+            past_key_values=cache,
             use_cache=self.config.use_cache,
-            return_dict=True,
         )
+        output = self.forward(**model_inputs)
         if not isinstance(output, CausalLMOutputWithPast):
             raise ValueError(
                 f"Unexpected output type for causal language model: {type(output)} != CausalLMOutputWithPast"