Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 04:14:38 +08:00.
pre_commit formatting
@@ -1,15 +1,17 @@
-import mlx.core as mx
-from mlx.utils import tree_map
 import argparse
 import time
+
+import mlx.core as mx
+from mistral import load_model
+from mlx.utils import tree_map


 class PromptLookupDecoder:
     def __init__(self, model: str) -> None:
         model, tokenizer = load_model(model)
         self.model = model
         self.tokenizer = tokenizer

     def _generate(
         self,
         x: mx.array,
@@ -50,7 +52,7 @@ class PromptLookupDecoder:
         run_time = time.time() - start
         print()
         print(f"=== GENERATED {n + 1} TOKENS IN {run_time} SECONDS ===")

     """
     Considerations:
     - If a match is found but we can't draft n_draft tokens, do we draft as
@@ -62,16 +64,24 @@ class PromptLookupDecoder:
        of only drafting a few tokens.
     - We exit upon the first match. This avoids the need to rank matches.
     """
-    def prompt_lookup(self, prompt: str, max_tokens: int, n_draft: int,
-                      ngram_max: int, ngram_min: int, temp: float, seed: int,
-                      color: bool):
+
+    def prompt_lookup(
+        self,
+        prompt: str,
+        max_tokens: int,
+        n_draft: int,
+        ngram_max: int,
+        ngram_min: int,
+        temp: float,
+        seed: int,
+        color: bool,
+    ):
         def sample(logits):
             if temp == 0:
                 return mx.argmax(logits, axis=-1)
             else:
                 return mx.random.categorical(logits * (1 / temp))

         mx.random.seed(seed)

         print("[INFO] Starting generation...")
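The nested sample helper above is standard temperature sampling: temp == 0 reduces to greedy argmax, and otherwise the logits are scaled by 1 / temp before a categorical draw, which is the same as sampling from softmax(logits / temp). A small pure-Python illustration of the temperature's effect (the values are made up, not from the repository):

import math

def softmax_with_temperature(logits, temp):
    # Dividing by the temperature sharpens (temp < 1) or flattens (temp > 1)
    # the distribution; temp -> 0 approaches argmax.
    scaled = [l / temp for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [e / z for e in exps]

logits = [2.0, 1.0, 0.1]
print(softmax_with_temperature(logits, 1.0))  # ~[0.66, 0.24, 0.10]
print(softmax_with_temperature(logits, 0.5))  # sharper: ~[0.86, 0.12, 0.02]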
@@ -82,7 +92,7 @@ class PromptLookupDecoder:

         # prefill model
         logit, cache = self.model(prompt[None])
         token = sample(logit[:, -1, :])
         tokens = mx.concatenate([tokens, token])
         prompt_time = time.time() - start
         print(self.tokenizer.decode(token.tolist()), end="", flush=True)
@@ -97,18 +107,18 @@ class PromptLookupDecoder:
            def generate_draft(input_ids):
                input_length = input_ids.size

                for ngram_size in range(ngram_max, ngram_min, -1):
                    ngram = input_ids[-ngram_size:]

                    for i in range(input_length - ngram_size):
-                       if mx.all(input_ids[i:i+ngram_size] == ngram):
+                       if mx.all(input_ids[i : i + ngram_size] == ngram):
                            start_idx = i + ngram_size
                            end_idx = start_idx + n_draft
                            if start_idx < input_length - ngram_size:
                                return input_ids[start_idx:end_idx]

                return mx.array([], dtype=mx.uint32)

            draft_tokens = generate_draft(tokens)
            n_drafted += draft_tokens.size

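generate_draft above is the core of prompt lookup decoding: take the last ngram_size tokens, find an earlier occurrence of the same n-gram in the context, and propose the tokens that followed it as a free draft. A minimal pure-Python sketch of the same search over plain lists (names and values are illustrative):

def lookup_draft(ids, n_draft, ngram_max, ngram_min):
    # Try longer n-grams first; as in the code above, ngram_min is an
    # exclusive lower bound of the search.
    for n in range(ngram_max, ngram_min, -1):
        ngram = ids[-n:]
        # Look for an earlier occurrence of the trailing n-gram.
        for i in range(len(ids) - n):
            if ids[i : i + n] == ngram:
                # Draft the tokens that followed the match last time.
                return ids[i + n : i + n + n_draft]
    return []  # no match: nothing to draft this step

# The trailing bigram (1, 2) occurred earlier, so the tokens after it are drafted.
ids = [1, 2, 3, 4, 5, 1, 2]
print(lookup_draft(ids, n_draft=2, ngram_max=3, ngram_min=1))  # [3, 4]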
@@ -117,7 +127,7 @@ class PromptLookupDecoder:
            logits, cache = self.model(verify_tokens[None], cache=cache)
            sampled = sample(logits).squeeze(0)

            # Only keep samples that match the draft.
            equal_toks = sampled[:-1] == draft_tokens
            num_to_accept = (equal_toks.tolist() + [False]).index(False)
            new_tokens = sampled[: max(1, num_to_accept + 1)]
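The three lines above implement draft verification: the model scores the last accepted token plus all draft tokens in one forward pass, so sampled holds the model's own choice at each position. The draft is kept up to the first disagreement, and sampled[num_to_accept] is kept as a bonus token, since it is exactly what the model would have produced next anyway. A worked example with made-up token ids:

draft_tokens = [7, 9, 4]      # proposal from prompt lookup
sampled = [7, 9, 2, 8]        # model's sample after each verified position

# Positions where the model agrees with the draft.
equal_toks = [s == d for s, d in zip(sampled[:-1], draft_tokens)]
print(equal_toks)             # [True, True, False]

# First disagreement; the appended False covers a fully accepted draft.
num_to_accept = (equal_toks + [False]).index(False)
print(num_to_accept)          # 2

# Keep the agreed prefix plus one bonus token from the model itself.
new_tokens = sampled[: max(1, num_to_accept + 1)]
print(new_tokens)             # [7, 9, 2]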
@@ -127,20 +137,22 @@ class PromptLookupDecoder:

            # Rewind the cache for unaccepted tokens:
            if (num_to_truncate := draft_tokens.size - num_to_accept) > 0:
                if num_to_truncate < cache[0][0].shape[2]:
-                   cache = tree_map(
-                       lambda x: x[:, :, :-num_to_truncate, :], cache
-                   )
+                   cache = tree_map(lambda x: x[:, :, :-num_to_truncate, :], cache)
                else:
                    cache = [None] * len(self.model.layers)

            n_decoding_steps += 1

            # Check stop decoding criteria:
            for t in new_tokens.tolist()[:-1]:
                if t == self.tokenizer.eos_id:
                    break
-               if (color):
-                   print("\033[34m" + self.tokenizer.decode([t]) + "\033[30m", end="", flush=True)
+               if color:
+                   print(
+                       "\033[34m" + self.tokenizer.decode([t]) + "\033[30m",
+                       end="",
+                       flush=True,
+                   )
                else:
                    print(self.tokenizer.decode([t]), end="", flush=True)
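tree_map applies a function to every array in a nested structure, so the rewind above truncates each layer's cached keys and values in one call. The indexing implies cached arrays of shape (batch, n_heads, seq_len, head_dim), with axis 2 as the sequence axis (cache[0][0].shape[2] is read as the cached length). A minimal sketch with dummy arrays (the shapes are illustrative):

import mlx.core as mx
from mlx.utils import tree_map

# A toy two-layer KV cache: one (keys, values) pair per layer,
# each of shape (batch=1, heads=2, seq=5, head_dim=4).
cache = [
    (mx.zeros((1, 2, 5, 4)), mx.zeros((1, 2, 5, 4))),
    (mx.zeros((1, 2, 5, 4)), mx.zeros((1, 2, 5, 4))),
]

num_to_truncate = 3  # e.g. three rejected draft tokens

# Drop the last num_to_truncate positions from every cached array.
cache = tree_map(lambda x: x[:, :, :-num_to_truncate, :], cache)

print(cache[0][0].shape)  # (1, 2, 2, 4): the seq axis shrank from 5 to 2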
@@ -155,11 +167,17 @@ class PromptLookupDecoder:
        end = time.time()
        print()
        print("=== PROMPT EVAL IN", round(prompt_time, 2), "SECONDS ===")
-       print("=== GENERATED", n_generated, "TOKENS IN", round(end - start, 2), "SECONDS ===")
+       print(
+           "=== GENERATED",
+           n_generated,
+           "TOKENS IN",
+           round(end - start, 2),
+           "SECONDS ===",
+       )
        print("=== ACCEPTED", n_accepted, "DRAFT TOKENS ===")
-       print("=== ACCEPT", round(n_accepted/n_generated * 100, 2), "% ===")
+       print("=== ACCEPT", round(n_accepted / n_generated * 100, 2), "% ===")
        print("=== DECODING STEPS", n_decoding_steps, "===")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Mistral inference script")
@@ -199,23 +217,15 @@ if __name__ == "__main__":
        default=1,
        help="Minimum ngrams to match against input during prompt lookup",
    )
    parser.add_argument(
        "--temp",
        help="The sampling temperature.",
        type=float,
        default=0.0,
    )
-   parser.add_argument(
-       "--seed",
-       type=int,
-       default=0,
-       help="The PRNG seed"
-   )
+   parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
    parser.add_argument(
-       "--color",
-       type=bool,
-       default=False,
-       help="Color the accepted draft tokens"
+       "--color", type=bool, default=False, help="Color the accepted draft tokens"
    )

    args = parser.parse_args()
@@ -226,14 +236,12 @@ if __name__ == "__main__":
    engine = PromptLookupDecoder(args.model_path)

    engine.prompt_lookup(
        args.prompt,
        args.max_tokens,
        args.n_draft,
        args.ngram_max,
        args.ngram_min,
        args.temp,
        args.seed,
-       args.color
-   )
-
-
+       args.color,
+   )

@@ -19,6 +19,7 @@ def create_additive_causal_mask(N: int, offset: int = 0, dtype: mx.Dtype = mx.fl
    mask = mask.astype(dtype) * -1e9
    return mask

+
@dataclass
class ModelArgs:
    dim: int
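The hunk above touches create_additive_causal_mask, which builds the standard additive attention mask: disallowed (future) positions get a large negative value so they vanish after the softmax. A minimal sketch of one common way to build such a mask (it omits the offset handling and is not the repository's exact implementation):

import mlx.core as mx

def additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32) -> mx.array:
    idx = mx.arange(N)
    # True above the diagonal: position i must not attend to j > i.
    mask = idx[:, None] < idx[None]
    # Additive form: 0 where attention is allowed, -1e9 where it is not,
    # so softmax(scores + mask) zeroes out future positions.
    return mask.astype(dtype) * -1e9

print(additive_causal_mask(3))
# schematically:
# [[0, -1e9, -1e9],
#  [0,    0, -1e9],
#  [0,    0,    0]]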
@@ -99,7 +100,7 @@ class Attention(nn.Module):
        if mask is not None:
            scores += mask

        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.wo(output), (keys, values)
@@ -169,7 +170,7 @@ class Mistral(nn.Module):
            mask = mask.astype(self.tok_embeddings.weight.dtype)
        else:
            mask = None

        x = self.tok_embeddings(x)

        if cache is None:
@@ -207,7 +208,7 @@ class Tokenizer:
        if t and self._model.id_to_piece(t[0])[0] == self._sep:
            return " " + out
        return out


def load_model(folder: str):
    model_path = Path(folder)
@@ -225,4 +226,4 @@ def load_model(folder: str):
    nn.QuantizedLinear.quantize_module(model, **quantization)
    model.update(weights)
    mx.eval(model.parameters())
-   return model, tokenizer
\ No newline at end of file
+   return model, tokenizer
@@ -1,9 +0,0 @@
-from decoder import PromptLookupDecoder
-
-prompt = "[INST] Repeat the following phrase 10 times: 'The quick brown fox jumps over the lazy dog.'. Don't say antyhing else. [/INST] "
-
-engine = PromptLookupDecoder("mlx_model")
-
-engine.generate(prompt, 250)
-
-engine.prompt_lookup(prompt, 250, 10, 3, 1, 0.0, 0, True)