mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 04:25:06 +08:00
Add __iter__ and __next__ methods to PlamoCache
This commit is contained in:
parent
e2d9d619c4
commit
fb1559e1f3
@ -199,6 +199,11 @@ class PlamoAttentionCache(nn.Module):
|
||||
def state(self) -> tuple[mx.array, mx.array]:
|
||||
return self.key, self.value
|
||||
|
||||
@state.setter
|
||||
def state(self, key: mx.array, value: mx.array) -> None:
|
||||
self.key = key
|
||||
self.value = value
|
||||
|
||||
|
||||
class PlamoMambaCache(nn.Module):
|
||||
def __init__(self, conv_state: mx.array, ssm_state: mx.array) -> None:
|
||||
@ -215,6 +220,12 @@ class PlamoMambaCache(nn.Module):
|
||||
def state(self) -> tuple[mx.array, mx.array]:
|
||||
return self.conv_state, self.ssm_state
|
||||
|
||||
@state.setter
|
||||
def state(self, conv_satte: mx.array, ssm_state: mx.array) -> None:
|
||||
self.conv_state = conv_satte
|
||||
self.ssm_state = ssm_state
|
||||
|
||||
|
||||
PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache
|
||||
|
||||
|
||||
@ -293,6 +304,18 @@ class PlamoCache(nn.Module):
|
||||
layer_cache = self.cache[layer_idx]
|
||||
return layer_cache # type: ignore
|
||||
|
||||
def __iter__(self):
|
||||
self._counter = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self._counter < len(self.cache):
|
||||
layer_cache = self.cache[self._counter]
|
||||
self._counter += 1
|
||||
return layer_cache
|
||||
else:
|
||||
raise StopIteration
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.cache
|
||||
|
Loading…
Reference in New Issue
Block a user