Allow a cache obj defined externally

This commit is contained in:
Shunta Saito 2025-02-14 21:11:30 +09:00
parent 00d13ebd40
commit 81917d41d5

View File

@ -285,6 +285,14 @@ class PlamoCache(nn.Module):
assert layer_idx < len(self.cache)
layer_cache = self.cache[layer_idx]
return layer_cache # type: ignore
@property
def state(self):
return self.cache
@state.setter
def state(self, v):
self.cache = v
def __len__(self) -> int:
return len(self.cache)
@ -295,7 +303,7 @@ class PlamoCache(nn.Module):
assert isinstance(c, PlamoAttentionCache)
return c.key.shape[2] # type: ignore
sequence_length: int | None = None
sequence_length: int = 0
for layer_cache in self.cache:
if isinstance(layer_cache, PlamoAttentionCache):
sequence_length = (
@ -303,7 +311,6 @@ class PlamoCache(nn.Module):
if sequence_length is not None
else layer_cache.key.shape[2]
)
assert sequence_length is not None
return sequence_length
def get_max_length(self) -> int | None:
@ -1244,7 +1251,7 @@ class PlamoDecoder(nn.Module):
all_self_attns: Optional[tuple[mx.array, ...]] = () if x.output_attentions else None
hidden_states = x.hidden_states
for layer_i, decoder_layer in enumerate(self.layers):
for decoder_layer in self.layers:
if x.output_hidden_states:
assert all_hidden_states is not None
all_hidden_states += (hidden_states,)
@ -1590,8 +1597,6 @@ class Model(PlamoPreTrainedModel):
if not config.tie_word_embeddings:
self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
self._cache: Optional[PlamoCache] = None
# Initialize weights and apply final processing
# self.post_init()
@ -1623,16 +1628,12 @@ class Model(PlamoPreTrainedModel):
return PlamoCache(self.config)
def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array:
if self._cache is not None:
plamo_cache = self._cache
else:
plamo_cache = None
output = self.forward(
model_inputs = self.prepare_inputs_for_generation(
input_ids=inputs,
past_key_values=plamo_cache,
past_key_values=cache,
use_cache=self.config.use_cache,
return_dict=True,
)
output = self.forward(**model_inputs)
if not isinstance(output, CausalLMOutputWithPast):
raise ValueError(
f"Unexpected output type for causal language model: {type(output)} != CausalLMOutputWithPast"