diff --git a/llms/speculative_decoding/model.py b/llms/speculative_decoding/model.py index c310b943..d30daa97 100644 --- a/llms/speculative_decoding/model.py +++ b/llms/speculative_decoding/model.py @@ -213,10 +213,10 @@ class TransformerDecoderLayer(nn.Module): memory: mx.array, mask: mx.array, memory_mask: mx.array, - cache: Optional[List[Tuple[mx.array, mx.array]]] = None, - ): + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: y = self.ln1(x) - y, cache = self.self_attention(y, y, y, mask, cache) + y, new_cache = self.self_attention(y, y, y, mask, cache) x = x + y y = self.ln2(x) @@ -227,7 +227,7 @@ class TransformerDecoderLayer(nn.Module): y = self.dense(y) x = x + y - return x, cache + return x, new_cache def create_additive_causal_mask(N: int, offset: int = 0):