Allow a cache obj defined externally

This commit is contained in:
Shunta Saito 2025-02-14 21:11:30 +09:00
parent 00d13ebd40
commit 81917d41d5

View File

@ -286,6 +286,14 @@ class PlamoCache(nn.Module):
layer_cache = self.cache[layer_idx] layer_cache = self.cache[layer_idx]
return layer_cache # type: ignore return layer_cache # type: ignore
@property
def state(self):
return self.cache
@state.setter
def state(self, v):
self.cache = v
def __len__(self) -> int: def __len__(self) -> int:
return len(self.cache) return len(self.cache)
@ -295,7 +303,7 @@ class PlamoCache(nn.Module):
assert isinstance(c, PlamoAttentionCache) assert isinstance(c, PlamoAttentionCache)
return c.key.shape[2] # type: ignore return c.key.shape[2] # type: ignore
sequence_length: int | None = None sequence_length: int = 0
for layer_cache in self.cache: for layer_cache in self.cache:
if isinstance(layer_cache, PlamoAttentionCache): if isinstance(layer_cache, PlamoAttentionCache):
sequence_length = ( sequence_length = (
@ -303,7 +311,6 @@ class PlamoCache(nn.Module):
if sequence_length is not None if sequence_length is not None
else layer_cache.key.shape[2] else layer_cache.key.shape[2]
) )
assert sequence_length is not None
return sequence_length return sequence_length
def get_max_length(self) -> int | None: def get_max_length(self) -> int | None:
@ -1244,7 +1251,7 @@ class PlamoDecoder(nn.Module):
all_self_attns: Optional[tuple[mx.array, ...]] = () if x.output_attentions else None all_self_attns: Optional[tuple[mx.array, ...]] = () if x.output_attentions else None
hidden_states = x.hidden_states hidden_states = x.hidden_states
for layer_i, decoder_layer in enumerate(self.layers): for decoder_layer in self.layers:
if x.output_hidden_states: if x.output_hidden_states:
assert all_hidden_states is not None assert all_hidden_states is not None
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
@ -1590,8 +1597,6 @@ class Model(PlamoPreTrainedModel):
if not config.tie_word_embeddings: if not config.tie_word_embeddings:
self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False) self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
self._cache: Optional[PlamoCache] = None
# Initialize weights and apply final processing # Initialize weights and apply final processing
# self.post_init() # self.post_init()
@ -1623,16 +1628,12 @@ class Model(PlamoPreTrainedModel):
return PlamoCache(self.config) return PlamoCache(self.config)
def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array: def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array:
if self._cache is not None: model_inputs = self.prepare_inputs_for_generation(
plamo_cache = self._cache
else:
plamo_cache = None
output = self.forward(
input_ids=inputs, input_ids=inputs,
past_key_values=plamo_cache, past_key_values=cache,
use_cache=self.config.use_cache, use_cache=self.config.use_cache,
return_dict=True,
) )
output = self.forward(**model_inputs)
if not isinstance(output, CausalLMOutputWithPast): if not isinstance(output, CausalLMOutputWithPast):
raise ValueError( raise ValueError(
f"Unexpected output type for causal language model: {type(output)} != CausalLMOutputWithPast" f"Unexpected output type for causal language model: {type(output)} != CausalLMOutputWithPast"