mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Added fixes to speculative-decoding
This commit is contained in:
parent
4304f5aaf5
commit
8f8f9b6991
@ -213,10 +213,10 @@ class TransformerDecoderLayer(nn.Module):
|
||||
memory: mx.array,
|
||||
mask: mx.array,
|
||||
memory_mask: mx.array,
|
||||
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
||||
):
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
|
||||
y = self.ln1(x)
|
||||
y, cache = self.self_attention(y, y, y, mask, cache)
|
||||
y, new_cache = self.self_attention(y, y, y, mask, cache)
|
||||
x = x + y
|
||||
|
||||
y = self.ln2(x)
|
||||
@ -227,7 +227,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
y = self.dense(y)
|
||||
x = x + y
|
||||
|
||||
return x, cache
|
||||
return x, new_cache
|
||||
|
||||
|
||||
def create_additive_causal_mask(N: int, offset: int = 0):
|
||||
|
Loading…
Reference in New Issue
Block a user