Add state property to PlamoCache

This commit is contained in:
Shunta Saito 2025-02-23 16:06:38 +09:00
parent 21c0abaf23
commit e2d9d619c4

View File

@ -195,6 +195,10 @@ class PlamoAttentionCache(nn.Module):
self.key = key
self.value = value
@property
def state(self) -> tuple[mx.array, mx.array]:
return self.key, self.value
class PlamoMambaCache(nn.Module):
def __init__(self, conv_state: mx.array, ssm_state: mx.array) -> None:
@ -207,6 +211,9 @@ class PlamoMambaCache(nn.Module):
self.conv_state = conv_state
self.ssm_state = ssm_state
@property
def state(self) -> tuple[mx.array, mx.array]:
return self.conv_state, self.ssm_state
PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache