mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 06:22:46 +08:00
Allow a cache obj defined externally
This commit is contained in:
parent
00d13ebd40
commit
81917d41d5
@ -286,6 +286,14 @@ class PlamoCache(nn.Module):
|
||||
layer_cache = self.cache[layer_idx]
|
||||
return layer_cache # type: ignore
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.cache
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.cache = v
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.cache)
|
||||
|
||||
@ -295,7 +303,7 @@ class PlamoCache(nn.Module):
|
||||
assert isinstance(c, PlamoAttentionCache)
|
||||
return c.key.shape[2] # type: ignore
|
||||
|
||||
sequence_length: int | None = None
|
||||
sequence_length: int = 0
|
||||
for layer_cache in self.cache:
|
||||
if isinstance(layer_cache, PlamoAttentionCache):
|
||||
sequence_length = (
|
||||
@ -303,7 +311,6 @@ class PlamoCache(nn.Module):
|
||||
if sequence_length is not None
|
||||
else layer_cache.key.shape[2]
|
||||
)
|
||||
assert sequence_length is not None
|
||||
return sequence_length
|
||||
|
||||
def get_max_length(self) -> int | None:
|
||||
@ -1244,7 +1251,7 @@ class PlamoDecoder(nn.Module):
|
||||
all_self_attns: Optional[tuple[mx.array, ...]] = () if x.output_attentions else None
|
||||
hidden_states = x.hidden_states
|
||||
|
||||
for layer_i, decoder_layer in enumerate(self.layers):
|
||||
for decoder_layer in self.layers:
|
||||
if x.output_hidden_states:
|
||||
assert all_hidden_states is not None
|
||||
all_hidden_states += (hidden_states,)
|
||||
@ -1590,8 +1597,6 @@ class Model(PlamoPreTrainedModel):
|
||||
if not config.tie_word_embeddings:
|
||||
self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
|
||||
|
||||
self._cache: Optional[PlamoCache] = None
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
# self.post_init()
|
||||
|
||||
@ -1623,16 +1628,12 @@ class Model(PlamoPreTrainedModel):
|
||||
return PlamoCache(self.config)
|
||||
|
||||
def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array:
|
||||
if self._cache is not None:
|
||||
plamo_cache = self._cache
|
||||
else:
|
||||
plamo_cache = None
|
||||
output = self.forward(
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids=inputs,
|
||||
past_key_values=plamo_cache,
|
||||
past_key_values=cache,
|
||||
use_cache=self.config.use_cache,
|
||||
return_dict=True,
|
||||
)
|
||||
output = self.forward(**model_inputs)
|
||||
if not isinstance(output, CausalLMOutputWithPast):
|
||||
raise ValueError(
|
||||
f"Unexpected output type for causal language model: {type(output)} != CausalLMOutputWithPast"
|
||||
|
Loading…
Reference in New Issue
Block a user