Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-29 18:26:37 +08:00)

Commit 81917d41d5: Allow a cache obj defined externally
Parent: 00d13ebd40
@@ -286,6 +286,14 @@ class PlamoCache(nn.Module):
         layer_cache = self.cache[layer_idx]
         return layer_cache  # type: ignore
 
+    @property
+    def state(self):
+        return self.cache
+
+    @state.setter
+    def state(self, v):
+        self.cache = v
+
     def __len__(self) -> int:
         return len(self.cache)
 
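The new state getter/setter is what lets a cache object live outside the model: callers can read the per-layer cache list or install their own. A minimal standalone sketch of the pattern (a toy class, not the repo's actual PlamoCache, which subclasses nn.Module):

    class TinyCache:
        def __init__(self) -> None:
            self.cache: list = []

        @property
        def state(self):
            # Expose the raw per-layer cache list.
            return self.cache

        @state.setter
        def state(self, v):
            # Let externally constructed cache contents be swapped in.
            self.cache = v

    c = TinyCache()
    c.state = ["layer0-kv", "layer1-kv"]  # external code installs its own state
    assert len(c.state) == 2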
@@ -295,7 +303,7 @@ class PlamoCache(nn.Module):
         assert isinstance(c, PlamoAttentionCache)
         return c.key.shape[2]  # type: ignore
 
-        sequence_length: int | None = None
+        sequence_length: int = 0
         for layer_cache in self.cache:
             if isinstance(layer_cache, PlamoAttentionCache):
                 sequence_length = (
@@ -303,7 +311,6 @@ class PlamoCache(nn.Module):
                     if sequence_length is not None
                     else layer_cache.key.shape[2]
                 )
-        assert sequence_length is not None
         return sequence_length
 
     def get_max_length(self) -> int | None:
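Initializing sequence_length to 0 instead of None means the function returns an int even when no layer holds a PlamoAttentionCache yet, so the trailing assert becomes dead code and is dropped. A small stand-in (not the repo's exact code) illustrating the before/after behavior on an empty cache:

    from typing import Optional

    def seq_len_old(key_lengths: list[int]) -> int:
        sequence_length: Optional[int] = None
        for n in key_lengths:
            sequence_length = max(sequence_length, n) if sequence_length is not None else n
        assert sequence_length is not None  # raises when the cache is empty
        return sequence_length

    def seq_len_new(key_lengths: list[int]) -> int:
        sequence_length: int = 0
        for n in key_lengths:
            sequence_length = max(sequence_length, n) if sequence_length is not None else n
        return sequence_length  # 0 for an empty cache; no assert needed

    assert seq_len_new([]) == 0
    assert seq_len_new([5, 7]) == 7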
@@ -1244,7 +1251,7 @@ class PlamoDecoder(nn.Module):
         all_self_attns: Optional[tuple[mx.array, ...]] = () if x.output_attentions else None
         hidden_states = x.hidden_states
 
-        for layer_i, decoder_layer in enumerate(self.layers):
+        for decoder_layer in self.layers:
             if x.output_hidden_states:
                 assert all_hidden_states is not None
                 all_hidden_states += (hidden_states,)
@@ -1590,8 +1597,6 @@ class Model(PlamoPreTrainedModel):
         if not config.tie_word_embeddings:
             self.lm_head: nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
 
-        self._cache: Optional[PlamoCache] = None
-
         # Initialize weights and apply final processing
         # self.post_init()
 
@@ -1623,16 +1628,12 @@ class Model(PlamoPreTrainedModel):
         return PlamoCache(self.config)
 
     def __call__(self, inputs: mx.array, cache: PlamoCache | None = None) -> mx.array:
-        if self._cache is not None:
-            plamo_cache = self._cache
-        else:
-            plamo_cache = None
-        output = self.forward(
+        model_inputs = self.prepare_inputs_for_generation(
             input_ids=inputs,
-            past_key_values=plamo_cache,
+            past_key_values=cache,
             use_cache=self.config.use_cache,
-            return_dict=True,
         )
+        output = self.forward(**model_inputs)
         if not isinstance(output, CausalLMOutputWithPast):
             raise ValueError(
                 f"Unexpected output type for causal language model: {type(output)} != CausalLMOutputWithPast"
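Together with the removal of self._cache in the constructor, this is the heart of the change: __call__ no longer consults private model state but forwards whatever cache the caller passes, via prepare_inputs_for_generation. A hedged usage sketch (not part of the commit; the import path and config class name are assumptions, since the diff only shows `config` / `self.config`):

    import mlx.core as mx
    from plamo import Model, PlamoCache, PlamoConfig  # hypothetical import path

    config = PlamoConfig()      # assumption: defaults are constructible
    model = Model(config)

    cache = PlamoCache(config)  # the externally defined cache object
    prompt = mx.array([[1, 2, 3]])

    logits = model(prompt, cache=cache)        # prefill: cache is filled in place
    step = model(prompt[:, -1:], cache=cache)  # decode: the same cache is reused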