Add __iter__ and __next__ methods to PlamoCache

This commit is contained in:
Shunta Saito 2025-02-23 17:33:19 +09:00
parent e2d9d619c4
commit fb1559e1f3

View File

@ -199,6 +199,11 @@ class PlamoAttentionCache(nn.Module):
def state(self) -> tuple[mx.array, mx.array]:
return self.key, self.value
@state.setter
def state(self, key: mx.array, value: mx.array) -> None:
self.key = key
self.value = value
class PlamoMambaCache(nn.Module):
def __init__(self, conv_state: mx.array, ssm_state: mx.array) -> None:
@ -215,6 +220,12 @@ class PlamoMambaCache(nn.Module):
def state(self) -> tuple[mx.array, mx.array]:
return self.conv_state, self.ssm_state
@state.setter
def state(self, conv_satte: mx.array, ssm_state: mx.array) -> None:
self.conv_state = conv_satte
self.ssm_state = ssm_state
PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache
@ -293,6 +304,18 @@ class PlamoCache(nn.Module):
layer_cache = self.cache[layer_idx]
return layer_cache # type: ignore
def __iter__(self):
self._counter = 0
return self
def __next__(self):
if self._counter < len(self.cache):
layer_cache = self.cache[self._counter]
self._counter += 1
return layer_cache
else:
raise StopIteration
@property
def state(self):
return self.cache