Made llama and mistral files mypy compatible (#1359)

* Made mypy compatible

* reformatted

* Added more fixes

* Added fixes to speculative-decoding

* Fixes

* fix circle

* revert some stuff

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Param Thakkar
Date: 2025-04-24 02:53:46 +05:30
Committed by: GitHub
Parent: c52cc748f8
Commit: 4c9f9f9be7
10 changed files with 32 additions and 29 deletions


@@ -160,12 +160,12 @@ class SpeculativeDecoder:
             )
             n_accepted += num_to_accept
-            n_draft += draft_tokens.size
+            n_draft += len(draft_tokens)

             # Rewind the cache for unaccepted tokens:
-            if (n := draft_tokens.size) > num_to_accept:
-                self.draft_model.truncate_cache(n - new_tokens.size)
-                self.model.truncate_cache(n - new_tokens.size + 1)
+            if (n := len(draft_tokens)) > num_to_accept:
+                self.draft_model.truncate_cache(n - len(new_tokens))
+                self.model.truncate_cache(n - len(new_tokens) + 1)

             n_steps += 1
@@ -181,7 +181,7 @@ class SpeculativeDecoder:
             if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
                 break

-            draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
+            draft_inputs = new_tokens[max(len(new_tokens) - 2, 0) :]
             inputs = draft_inputs[-1:]
             print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
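
Note on the len() change: .size is an attribute of mx.array but not of list, so when a token buffer is annotated (or inferred) as a plain list, the attribute access fails type checking. A minimal sketch of the failure mode, not code from this commit (the List[int] annotation is assumed for illustration):

from typing import List


def rewind_amount(draft_tokens: List[int], new_tokens: List[int]) -> int:
    # draft_tokens.size would be a mypy error here: "List[int]" has no
    # attribute "size". len() type-checks for any Sized value, including
    # both list and mx.array, and for a 1-D mx.array len() and .size agree.
    return len(draft_tokens) - len(new_tokens)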


@@ -213,10 +213,10 @@ class TransformerDecoderLayer(nn.Module):
         memory: mx.array,
         mask: mx.array,
         memory_mask: mx.array,
-        cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
-    ):
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
         y = self.ln1(x)
-        y, cache = self.self_attention(y, y, y, mask, cache)
+        y, new_cache = self.self_attention(y, y, y, mask, cache)
         x = x + y
         y = self.ln2(x)
@@ -227,7 +227,7 @@ class TransformerDecoderLayer(nn.Module):
         y = self.dense(y)
         x = x + y
-        return x, cache
+        return x, new_cache

 def create_additive_causal_mask(N: int, offset: int = 0):
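
Note on the annotation change: a single decoder layer carries one (keys, values) pair, so its cache parameter is Tuple[mx.array, mx.array]; the List of pairs belongs to the decoder stack that owns the layers. Rebinding the attention output to new_cache instead of reusing the Optional cache parameter also keeps each variable at a single type, which is what mypy checks for. A hedged sketch of how the narrowed per-layer type composes at the stack level (run_decoder and its signature are assumed for illustration, not taken from the repo):

from typing import List, Optional, Tuple

import mlx.core as mx

LayerCache = Tuple[mx.array, mx.array]  # one layer's (keys, values)


def run_decoder(
    layers,  # sequence of TransformerDecoderLayer-like callables
    x: mx.array,
    memory: mx.array,
    mask: mx.array,
    memory_mask: mx.array,
    cache: Optional[List[LayerCache]] = None,
) -> Tuple[mx.array, List[LayerCache]]:
    new_caches: List[LayerCache] = []
    for i, layer in enumerate(layers):
        # Each layer receives Optional[LayerCache] and returns a concrete
        # LayerCache bound to a fresh name, so neither variable's type
        # changes from mypy's point of view.
        x, new_cache = layer(
            x, memory, mask, memory_mask,
            cache[i] if cache is not None else None,
        )
        new_caches.append(new_cache)
    return x, new_caches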