mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Added fixes to speculative-decoding
This commit is contained in:
parent
4304f5aaf5
commit
8f8f9b6991
@ -213,10 +213,10 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
memory: mx.array,
|
memory: mx.array,
|
||||||
mask: mx.array,
|
mask: mx.array,
|
||||||
memory_mask: mx.array,
|
memory_mask: mx.array,
|
||||||
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
):
|
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
|
||||||
y = self.ln1(x)
|
y = self.ln1(x)
|
||||||
y, cache = self.self_attention(y, y, y, mask, cache)
|
y, new_cache = self.self_attention(y, y, y, mask, cache)
|
||||||
x = x + y
|
x = x + y
|
||||||
|
|
||||||
y = self.ln2(x)
|
y = self.ln2(x)
|
||||||
@ -227,7 +227,7 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
y = self.dense(y)
|
y = self.dense(y)
|
||||||
x = x + y
|
x = x + y
|
||||||
|
|
||||||
return x, cache
|
return x, new_cache
|
||||||
|
|
||||||
|
|
||||||
def create_additive_causal_mask(N: int, offset: int = 0):
|
def create_additive_causal_mask(N: int, offset: int = 0):
|
||||||
|
Loading…
Reference in New Issue
Block a user