Mirror of https://github.com/ml-explore/mlx-examples.git
Add __iter__ and __next__ methods to PlamoCache
commit fb1559e1f3 (parent e2d9d619c4)
@@ -194,11 +194,16 @@ class PlamoAttentionCache(nn.Module):
         assert value.shape[2] == L
         self.key = key
         self.value = value

     @property
     def state(self) -> tuple[mx.array, mx.array]:
         return self.key, self.value

+    @state.setter
+    def state(self, key: mx.array, value: mx.array) -> None:
+        self.key = key
+        self.value = value
+

 class PlamoMambaCache(nn.Module):
     def __init__(self, conv_state: mx.array, ssm_state: mx.array) -> None:
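For background on the pattern above: a Python @property setter is triggered by attribute assignment and receives exactly one object, so a state setter that updates both tensors conventionally unpacks a (key, value) tuple. A minimal runnable sketch of that convention, assuming only mlx is installed (the MinimalAttentionCache name is illustrative, not part of this commit):

    import mlx.core as mx


    class MinimalAttentionCache:
        """Illustrative cache holding a key tensor and a value tensor."""

        def __init__(self, key: mx.array, value: mx.array) -> None:
            self.key = key
            self.value = value

        @property
        def state(self) -> tuple[mx.array, mx.array]:
            return self.key, self.value

        @state.setter
        def state(self, kv: tuple[mx.array, mx.array]) -> None:
            # Attribute assignment passes one object; unpack the pair here.
            self.key, self.value = kv


    cache = MinimalAttentionCache(mx.zeros((1, 4)), mx.zeros((1, 4)))
    cache.state = (mx.ones((1, 4)), mx.ones((1, 4)))  # routed through the setter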
@@ -215,6 +220,12 @@ class PlamoMambaCache(nn.Module):
     def state(self) -> tuple[mx.array, mx.array]:
         return self.conv_state, self.ssm_state

+    @state.setter
+    def state(self, conv_state: mx.array, ssm_state: mx.array) -> None:
+        self.conv_state = conv_state
+        self.ssm_state = ssm_state
+
+
 PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache

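PlamoLayerCache is a PEP 604 union alias (Python 3.10+), so code that walks a mixed cache list can branch on the concrete type with isinstance, which accepts X | Y unions directly. A hedged sketch of such dispatch (describe_layer is a hypothetical helper, not in the commit):

    def describe_layer(layer_cache: PlamoLayerCache) -> str:
        # isinstance accepts the X | Y union alias directly on Python 3.10+.
        if isinstance(layer_cache, PlamoAttentionCache):
            return f"attention cache, key shape {layer_cache.key.shape}"
        return f"mamba cache, conv_state shape {layer_cache.conv_state.shape}"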
@@ -293,6 +304,18 @@ class PlamoCache(nn.Module):
         layer_cache = self.cache[layer_idx]
         return layer_cache  # type: ignore

+    def __iter__(self):
+        self._counter = 0
+        return self
+
+    def __next__(self):
+        if self._counter < len(self.cache):
+            layer_cache = self.cache[self._counter]
+            self._counter += 1
+            return layer_cache
+        else:
+            raise StopIteration
+
     @property
     def state(self):
         return self.cache
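The __iter__/__next__ pair above is what this commit is named for: it makes PlamoCache satisfy the iterator protocol, so per-layer caches can be consumed by a plain for loop instead of manual indexing. A usage sketch (plamo_cache stands for any constructed PlamoCache instance):

    # Iterate per-layer caches directly instead of indexing with layer_idx.
    for layer_cache in plamo_cache:
        print(type(layer_cache).__name__)

    # What the for loop does under the hood:
    it = iter(plamo_cache)   # __iter__ resets _counter to 0 and returns self
    first = next(it)         # __next__ returns cache[0], then advances _counter

Because _counter lives on the instance and __iter__ returns self, two interleaved iterations over the same cache would share one counter; returning iter(self.cache) or a generator from __iter__ is a common way to avoid that.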
@@ -1597,7 +1620,7 @@ class Model(PlamoPreTrainedModel):

         if not config.tie_word_embeddings:
             self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)

         # Initialize weights and apply final processing
         # self.post_init()

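On the tie_word_embeddings branch: a separate lm_head is only allocated when weights are untied; in the tied case the usual pattern is to reuse the token embedding matrix as the output projection. A minimal sketch of that pattern, assuming mlx.nn (TiedOutput is an illustrative class, not the repository's Model):

    import mlx.core as mx
    import mlx.nn as nn


    class TiedOutput(nn.Module):
        """Output head that optionally shares weights with the token embedding."""

        def __init__(self, hidden_size: int, vocab_size: int, tie: bool) -> None:
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden_size)
            self.tie = tie
            if not tie:
                self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        def __call__(self, h: mx.array) -> mx.array:
            if self.tie:
                # Reuse the embedding matrix as the output projection.
                return self.embed.as_linear(h)
            return self.lm_head(h)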