Add fixes to speculative decoding

This commit is contained in:
paramthakkar123 2025-04-08 23:26:35 +05:30
parent 4304f5aaf5
commit 8f8f9b6991

View File

@ -213,10 +213,10 @@ class TransformerDecoderLayer(nn.Module):
memory: mx.array,
mask: mx.array,
memory_mask: mx.array,
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
):
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
y = self.ln1(x)
y, cache = self.self_attention(y, y, y, mask, cache)
y, new_cache = self.self_attention(y, y, y, mask, cache)
x = x + y
y = self.ln2(x)
@ -227,7 +227,7 @@ class TransformerDecoderLayer(nn.Module):
y = self.dense(y)
x = x + y
return x, cache
return x, new_cache
def create_additive_causal_mask(N: int, offset: int = 0):