mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-04 15:54:34 +08:00
pre_commit formatting
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
from mistral import load_model
|
||||
from mlx.utils import tree_map
|
||||
|
||||
|
||||
class PromptLookupDecoder:
|
||||
def __init__(self, model: str) -> None:
|
||||
@@ -62,10 +64,18 @@ class PromptLookupDecoder:
|
||||
of only drafting a few tokens.
|
||||
- We exit upon the first match. This avoids the need to rank matches.
|
||||
"""
|
||||
def prompt_lookup(self, prompt: str, max_tokens: int, n_draft: int,
|
||||
ngram_max: int, ngram_min: int, temp: float, seed: int,
|
||||
color: bool):
|
||||
|
||||
def prompt_lookup(
|
||||
self,
|
||||
prompt: str,
|
||||
max_tokens: int,
|
||||
n_draft: int,
|
||||
ngram_max: int,
|
||||
ngram_min: int,
|
||||
temp: float,
|
||||
seed: int,
|
||||
color: bool,
|
||||
):
|
||||
def sample(logits):
|
||||
if temp == 0:
|
||||
return mx.argmax(logits, axis=-1)
|
||||
@@ -127,9 +137,7 @@ class PromptLookupDecoder:
|
||||
# Rewind the cache for unaccepted tokens:
|
||||
if (num_to_truncate := draft_tokens.size - num_to_accept) > 0:
|
||||
if num_to_truncate < cache[0][0].shape[2]:
|
||||
cache = tree_map(
|
||||
lambda x: x[:, :, :-num_to_truncate, :], cache
|
||||
)
|
||||
cache = tree_map(lambda x: x[:, :, :-num_to_truncate, :], cache)
|
||||
else:
|
||||
cache = [None] * len(self.model.layers)
|
||||
|
||||
@@ -139,8 +147,12 @@ class PromptLookupDecoder:
|
||||
for t in new_tokens.tolist()[:-1]:
|
||||
if t == self.tokenizer.eos_id:
|
||||
break
|
||||
if (color):
|
||||
print("\033[34m" + self.tokenizer.decode([t]) + "\033[30m", end="", flush=True)
|
||||
if color:
|
||||
print(
|
||||
"\033[34m" + self.tokenizer.decode([t]) + "\033[30m",
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
print(self.tokenizer.decode([t]), end="", flush=True)
|
||||
|
||||
@@ -155,7 +167,13 @@ class PromptLookupDecoder:
|
||||
end = time.time()
|
||||
print()
|
||||
print("=== PROMPT EVAL IN", round(prompt_time, 2), "SECONDS ===")
|
||||
print("=== GENERATED", n_generated, "TOKENS IN", round(end - start, 2), "SECONDS ===")
|
||||
print(
|
||||
"=== GENERATED",
|
||||
n_generated,
|
||||
"TOKENS IN",
|
||||
round(end - start, 2),
|
||||
"SECONDS ===",
|
||||
)
|
||||
print("=== ACCEPTED", n_accepted, "DRAFT TOKENS ===")
|
||||
print("=== ACCEPT", round(n_accepted / n_generated * 100, 2), "% ===")
|
||||
print("=== DECODING STEPS", n_decoding_steps, "===")
|
||||
@@ -205,17 +223,9 @@ if __name__ == "__main__":
|
||||
type=float,
|
||||
default=0.0,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help="The PRNG seed"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--color",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Color the accepted draft tokens"
|
||||
"--color", type=bool, default=False, help="Color the accepted draft tokens"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -233,7 +243,5 @@ if __name__ == "__main__":
|
||||
args.ngram_min,
|
||||
args.temp,
|
||||
args.seed,
|
||||
args.color
|
||||
args.color,
|
||||
)
|
||||
|
||||
|
||||
|
@@ -19,6 +19,7 @@ def create_additive_causal_mask(N: int, offset: int = 0, dtype: mx.Dtype = mx.fl
|
||||
mask = mask.astype(dtype) * -1e9
|
||||
return mask
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
dim: int
|
||||
|
@@ -1,9 +0,0 @@
|
||||
from decoder import PromptLookupDecoder
|
||||
|
||||
prompt = "[INST] Repeat the following phrase 10 times: 'The quick brown fox jumps over the lazy dog.'. Don't say antyhing else. [/INST] "
|
||||
|
||||
engine = PromptLookupDecoder("mlx_model")
|
||||
|
||||
engine.generate(prompt, 250)
|
||||
|
||||
engine.prompt_lookup(prompt, 250, 10, 3, 1, 0.0, 0, True)
|
Reference in New Issue
Block a user