From e2d9d619c4fc5827ccae41aec542f0c95cfda52d Mon Sep 17 00:00:00 2001 From: Shunta Saito Date: Sun, 23 Feb 2025 16:06:38 +0900 Subject: [PATCH] Add state property to PlamoCache --- llms/mlx_lm/models/plamo2.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py index 773a6839..e8b222f1 100644 --- a/llms/mlx_lm/models/plamo2.py +++ b/llms/mlx_lm/models/plamo2.py @@ -194,6 +194,10 @@ class PlamoAttentionCache(nn.Module): assert value.shape[2] == L self.key = key self.value = value + + @property + def state(self) -> tuple[mx.array, mx.array]: + return self.key, self.value class PlamoMambaCache(nn.Module): @@ -207,6 +211,9 @@ class PlamoMambaCache(nn.Module): self.conv_state = conv_state self.ssm_state = ssm_state + @property + def state(self) -> tuple[mx.array, mx.array]: + return self.conv_state, self.ssm_state PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache