diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py index 773a6839..e8b222f1 100644 --- a/llms/mlx_lm/models/plamo2.py +++ b/llms/mlx_lm/models/plamo2.py @@ -194,6 +194,10 @@ class PlamoAttentionCache(nn.Module): assert value.shape[2] == L self.key = key self.value = value + + @property + def state(self) -> tuple[mx.array, mx.array]: + return self.key, self.value class PlamoMambaCache(nn.Module): @@ -207,6 +211,9 @@ class PlamoMambaCache(nn.Module): self.conv_state = conv_state self.ssm_state = ssm_state + @property + def state(self) -> tuple[mx.array, mx.array]: + return self.conv_state, self.ssm_state PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache