mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-19 11:28:07 +08:00
Made llama and mistral files mypy compatible (#1359)
* Made mypy compatible
* reformatted
* Added more fixes
* Added fixes to speculative-decoding
* Fixes
* fix circle
* revert some stuff

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -160,12 +160,12 @@ class SpeculativeDecoder:
|
||||
)
|
||||
|
||||
n_accepted += num_to_accept
|
||||
n_draft += draft_tokens.size
|
||||
n_draft += len(draft_tokens)
|
||||
|
||||
# Rewind the cache for unaccepted tokens:
|
||||
if (n := draft_tokens.size) > num_to_accept:
|
||||
self.draft_model.truncate_cache(n - new_tokens.size)
|
||||
self.model.truncate_cache(n - new_tokens.size + 1)
|
||||
if (n := len(draft_tokens)) > num_to_accept:
|
||||
self.draft_model.truncate_cache(n - len(new_tokens))
|
||||
self.model.truncate_cache(n - len(new_tokens) + 1)
|
||||
|
||||
n_steps += 1
|
||||
|
||||
@@ -181,7 +181,7 @@ class SpeculativeDecoder:
|
||||
|
||||
if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
|
||||
break
|
||||
draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
|
||||
draft_inputs = new_tokens[max(len(new_tokens) - 2, 0) :]
|
||||
inputs = draft_inputs[-1:]
|
||||
|
||||
print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)
|
||||
|
@@ -213,10 +213,10 @@ class TransformerDecoderLayer(nn.Module):
|
||||
memory: mx.array,
|
||||
mask: mx.array,
|
||||
memory_mask: mx.array,
|
||||
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
||||
):
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
|
||||
y = self.ln1(x)
|
||||
y, cache = self.self_attention(y, y, y, mask, cache)
|
||||
y, new_cache = self.self_attention(y, y, y, mask, cache)
|
||||
x = x + y
|
||||
|
||||
y = self.ln2(x)
|
||||
@@ -227,7 +227,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
y = self.dense(y)
|
||||
x = x + y
|
||||
|
||||
return x, cache
|
||||
return x, new_cache
|
||||
|
||||
|
||||
def create_additive_causal_mask(N: int, offset: int = 0):
|
||||
|
Reference in New Issue
Block a user