Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-12-16 02:08:55 +08:00
More cache improvements (#1015)
* fix rotating kv cache for chat use case
* reorg + fixes to caching, unify prompt caching across types and use cases for e.g. caching during a chat
* nit in chat
* fix tests
* fix tests
* fix tests
* docs
* chat command
* comments + docs
* Define meta_state on all Cache implementations
* fixes + trim_prompt_cache api
* fix default model

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
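The diff below replaces per-model (keys, values) tuple caches with opaque cache objects, which is what makes the unified chat caching in the commit message possible. For context, here is a minimal sketch of that flow. make_prompt_cache and trim_prompt_cache are mlx_lm helpers (trim_prompt_cache is named in the commit message); whether generate() forwards a prompt_cache keyword varies by version, and the model path is a placeholder.

    # Hypothetical usage sketch, not part of this diff.
    from mlx_lm import load, generate
    from mlx_lm.models.cache import make_prompt_cache, trim_prompt_cache

    model, tokenizer = load("some/quantized-model")  # placeholder model path

    # One cache object per layer, reused across every chat turn.
    prompt_cache = make_prompt_cache(model)

    for turn in ["Hello!", "What did I just say?"]:
        # The cache already holds earlier turns, so only new tokens are prefilled.
        print(generate(model, tokenizer, prompt=turn, prompt_cache=prompt_cache))

    # The new trim_prompt_cache API rewinds the cache by n tokens, e.g. to
    # discard a rejected turn without recomputing the whole history.
    trim_prompt_cache(prompt_cache, 16)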
@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 from dataclasses import dataclass
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Optional
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -62,8 +62,8 @@ class Attention(nn.Module):
         self,
         hidden_states: mx.array,
         attention_mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
-    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
+        cache: Optional[Any] = None,
+    ) -> mx.array:
         bsz, q_len, _ = hidden_states.shape
 
         queries = self.q_proj(hidden_states)
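The cache parameter is now typed Optional[Any] because it carries an object rather than a raw (keys, values) tuple. A minimal sketch of the interface such an object provides, modeled on mlx_lm's KVCache; meta_state comes from the commit message, while shapes and other details are assumptions:

    import mlx.core as mx

    class KVCacheSketch:
        def __init__(self):
            self.keys = None
            self.values = None
            self.offset = 0  # number of tokens already cached

        def update_and_fetch(self, keys, values):
            # Append this step's keys/values (B, n_kv_heads, T, head_dim)
            # and return the full history for attention to consume.
            if self.keys is None:
                self.keys, self.values = keys, values
            else:
                self.keys = mx.concatenate([self.keys, keys], axis=2)
                self.values = mx.concatenate([self.values, values], axis=2)
            self.offset += keys.shape[2]
            return self.keys, self.values

        @property
        def meta_state(self):
            # Per-cache metadata, defined on every cache type by this commit
            # so prompt caches of any kind can be saved/restored uniformly.
            return ""

Inside the attention body this replaces manual tuple bookkeeping with a single call, typically keys, values = cache.update_and_fetch(keys, values), with cache.offset feeding the RoPE position offset.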
@@ -127,8 +127,8 @@ class PlamoDecoderLayer(nn.Module):
         self,
         hidden_states: mx.array,
         attention_mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
-    ) -> Tuple[Any, ...]:
+        cache: Optional[Any] = None,
+    ):
         # from LlamaDecoder
         residual = hidden_states
 
@@ -169,8 +169,8 @@ class PlamoModel(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
-        cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None,
-    ) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
+        cache: Optional[Any] = None,
+    ) -> mx.array:
         h = self.embed_tokens(inputs)
 
         mask = create_attention_mask(h, cache)
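Note that create_attention_mask(h, cache) now derives the causal mask from the cache object rather than from tuple shapes. A sketch of what that entails, assuming mlx_lm's convention that each per-layer cache exposes an offset count of already-cached tokens; the helper name and details here are assumptions:

    import mlx.core as mx

    def causal_mask_sketch(h, cache=None):
        # New queries sit at absolute positions offset..offset+T-1 and may
        # attend to every cached key plus earlier keys in the current chunk.
        T = h.shape[1]
        if T <= 1:
            return None  # single-token decode steps need no mask
        offset = cache[0].offset if cache is not None else 0
        q_pos = mx.arange(offset, offset + T)  # absolute query positions
        k_pos = mx.arange(offset + T)          # cached + current key positions
        mask = (q_pos[:, None] < k_pos[None]) * -1e9
        return mask.astype(h.dtype)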
@@ -197,19 +197,11 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
-        cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
-    ) -> Tuple[mx.array, mx.array]:
+        cache: Optional[Any] = None,
+    ) -> mx.array:
         out = self.model(inputs, cache)
         return self.lm_head(out)
 
     @property
     def layers(self):
         return self.model.layers.layers
-
-    @property
-    def head_dim(self):
-        return self.args.hidden_size // self.args.num_attention_heads
-
-    @property
-    def n_kv_heads(self):
-        return self.args.num_attention_heads // self.args.n_shared_head
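The removed head_dim and n_kv_heads properties existed only so cache construction could pre-size key/value buffers. Once caches allocate lazily from the first keys/values they receive (as in KVCacheSketch above), construction reduces to roughly the following; this is a sketch in the spirit of mlx_lm's cache module, and the make_cache hook is an assumption:

    def make_prompt_cache_sketch(model):
        # Hypothetical helper mirroring mlx_lm's make_prompt_cache: one cache
        # per layer, with no need for head_dim / n_kv_heads up front.
        if hasattr(model, "make_cache"):  # assumed hook for custom cache types
            return model.make_cache()
        return [KVCacheSketch() for _ in range(len(model.layers))]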