Made llama and mistral files mypy compatible (#1359)

* Made mypy compatible

* reformatted

* Added more fixes

* Added fixes to speculative-decoding

* Fixes

* fix circle

* revert some stuff

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit was authored by Param Thakkar on 2025-04-24 02:53:46 +05:30 and committed via GitHub.
parent c52cc748f8
commit 4c9f9f9be7
10 changed files with 32 additions and 29 deletions

View File

@@ -213,10 +213,10 @@ class TransformerDecoderLayer(nn.Module):
memory: mx.array,
mask: mx.array,
memory_mask: mx.array,
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
):
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
y = self.ln1(x)
y, cache = self.self_attention(y, y, y, mask, cache)
y, new_cache = self.self_attention(y, y, y, mask, cache)
x = x + y
y = self.ln2(x)
@@ -227,7 +227,7 @@ class TransformerDecoderLayer(nn.Module):
y = self.dense(y)
x = x + y
return x, cache
return x, new_cache
def create_additive_causal_mask(N: int, offset: int = 0):