diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py
index 53d3a4e5..e7ffe488 100644
--- a/llms/mlx_lm/models/plamo2.py
+++ b/llms/mlx_lm/models/plamo2.py
@@ -1917,7 +1917,7 @@ class Model(PlamoPreTrainedModel):
         self.lm_head: nn.Module = nn.Linear(
             config.hidden_size, vocab_size, bias=False
         )
-
+
         self._past_key_values: Optional[tuple[tuple[mx.array]]] = None

         # Initialize weights and apply final processing
@@ -1940,19 +1940,14 @@ class Model(PlamoPreTrainedModel):

     def get_decoder(self) -> PlamoModel:
         return self.model
-
+
     def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
         for k, v in weights.items():
             if "conv1d.weight" in k and v.shape[-1] != 1:
                 weights[k] = v.moveaxis(2, 1)
         return weights

-    def make_cache(self) -> PlamoCache:
-        print("make_cache")
-        return "a"
-
     def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array:
-        print(cache)
         output = self.forward(
             input_ids=inputs,
             use_cache=self.config.use_cache,
@@ -2117,4 +2112,4 @@ class Model(PlamoPreTrainedModel):

     @property
     def layers(self):
-        return self.model.layers
\ No newline at end of file
+        return self.model.layers
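
Side note, not part of the patch: a minimal sketch of what the sanitize hook's transpose does to a depthwise conv1d weight, assuming the checkpoint stores it in PyTorch layout (channels, in_channels_per_group, kernel_size) while mlx.nn.Conv1d expects (channels, kernel_size, in_channels_per_group). The shapes below are illustrative only.

    import mlx.core as mx

    # Hypothetical depthwise conv1d weight in PyTorch layout: (channels, 1, kernel_size)
    torch_style = mx.zeros((4096, 1, 4))

    # The same transpose sanitize applies: move the kernel axis up next to the channels
    mlx_style = torch_style.moveaxis(2, 1)

    assert mlx_style.shape == (4096, 4, 1)  # (channels, kernel_size, 1)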