From 8f8f9b699136730f646d5e5a273edffc1dde965b Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Tue, 8 Apr 2025 23:26:35 +0530 Subject: [PATCH] Added fixes to speculative-decoding --- llms/speculative_decoding/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llms/speculative_decoding/model.py b/llms/speculative_decoding/model.py index c310b943..d30daa97 100644 --- a/llms/speculative_decoding/model.py +++ b/llms/speculative_decoding/model.py @@ -213,10 +213,10 @@ class TransformerDecoderLayer(nn.Module): memory: mx.array, mask: mx.array, memory_mask: mx.array, - cache: Optional[List[Tuple[mx.array, mx.array]]] = None, - ): + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: y = self.ln1(x) - y, cache = self.self_attention(y, y, y, mask, cache) + y, new_cache = self.self_attention(y, y, y, mask, cache) x = x + y y = self.ln2(x) @@ -227,7 +227,7 @@ class TransformerDecoderLayer(nn.Module): y = self.dense(y) x = x + y - return x, cache + return x, new_cache def create_additive_causal_mask(N: int, offset: int = 0):