pre_commit formatting

This commit is contained in:
Leon Ericsson
2023-12-28 22:23:29 +01:00
parent cb4464bb7b
commit c73bec5598
3 changed files with 58 additions and 58 deletions

View File

@@ -1,8 +1,10 @@
import mlx.core as mx
from mlx.utils import tree_map
import argparse
import time
import mlx.core as mx
from mistral import load_model
from mlx.utils import tree_map
class PromptLookupDecoder:
def __init__(self, model: str) -> None:
@@ -62,10 +64,18 @@ class PromptLookupDecoder:
of only drafting a few tokens.
- We exit upon the first match. This avoids the need to rank matches.
"""
def prompt_lookup(self, prompt: str, max_tokens: int, n_draft: int,
ngram_max: int, ngram_min: int, temp: float, seed: int,
color: bool):
def prompt_lookup(
self,
prompt: str,
max_tokens: int,
n_draft: int,
ngram_max: int,
ngram_min: int,
temp: float,
seed: int,
color: bool,
):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
@@ -127,9 +137,7 @@ class PromptLookupDecoder:
# Rewind the cache for unaccepted tokens:
if (num_to_truncate := draft_tokens.size - num_to_accept) > 0:
if num_to_truncate < cache[0][0].shape[2]:
cache = tree_map(
lambda x: x[:, :, :-num_to_truncate, :], cache
)
cache = tree_map(lambda x: x[:, :, :-num_to_truncate, :], cache)
else:
cache = [None] * len(self.model.layers)
@@ -139,8 +147,12 @@ class PromptLookupDecoder:
for t in new_tokens.tolist()[:-1]:
if t == self.tokenizer.eos_id:
break
if (color):
print("\033[34m" + self.tokenizer.decode([t]) + "\033[30m", end="", flush=True)
if color:
print(
"\033[34m" + self.tokenizer.decode([t]) + "\033[30m",
end="",
flush=True,
)
else:
print(self.tokenizer.decode([t]), end="", flush=True)
@@ -155,7 +167,13 @@ class PromptLookupDecoder:
end = time.time()
print()
print("=== PROMPT EVAL IN", round(prompt_time, 2), "SECONDS ===")
print("=== GENERATED", n_generated, "TOKENS IN", round(end - start, 2), "SECONDS ===")
print(
"=== GENERATED",
n_generated,
"TOKENS IN",
round(end - start, 2),
"SECONDS ===",
)
print("=== ACCEPTED", n_accepted, "DRAFT TOKENS ===")
print("=== ACCEPT", round(n_accepted / n_generated * 100, 2), "% ===")
print("=== DECODING STEPS", n_decoding_steps, "===")
@@ -205,17 +223,9 @@ if __name__ == "__main__":
type=float,
default=0.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
parser.add_argument(
"--seed",
type=int,
default=0,
help="The PRNG seed"
)
parser.add_argument(
"--color",
type=bool,
default=False,
help="Color the accepted draft tokens"
"--color", type=bool, default=False, help="Color the accepted draft tokens"
)
args = parser.parse_args()
@@ -233,7 +243,5 @@ if __name__ == "__main__":
args.ngram_min,
args.temp,
args.seed,
args.color
args.color,
)

View File

@@ -19,6 +19,7 @@ def create_additive_causal_mask(N: int, offset: int = 0, dtype: mx.Dtype = mx.fl
mask = mask.astype(dtype) * -1e9
return mask
@dataclass
class ModelArgs:
dim: int

View File

@@ -1,9 +0,0 @@
from decoder import PromptLookupDecoder
prompt = "[INST] Repeat the following phrase 10 times: 'The quick brown fox jumps over the lazy dog.'. Don't say anything else. [/INST] "
engine = PromptLookupDecoder("mlx_model")
engine.generate(prompt, 250)
engine.prompt_lookup(prompt, 250, 10, 3, 1, 0.0, 0, True)