pre_commit formatting

Leon Ericsson
2023-12-28 22:23:29 +01:00
parent cb4464bb7b
commit c73bec5598
3 changed files with 58 additions and 58 deletions

View File

@@ -1,15 +1,17 @@
import argparse
import time
import mlx.core as mx
from mistral import load_model
from mlx.utils import tree_map
class PromptLookupDecoder:
def __init__(self, model: str) -> None:
model, tokenizer = load_model(model)
self.model = model
self.tokenizer = tokenizer
def _generate(
self,
x: mx.array,
@@ -50,7 +52,7 @@ class PromptLookupDecoder:
run_time = time.time() - start
print()
print(f"=== GENERATED {n + 1} TOKENS IN {run_time} SECONDS ===")
"""
Considerations:
- If a match is found but we can't draft n_draft tokens, do we draft as
@@ -62,16 +64,24 @@ class PromptLookupDecoder:
of only drafting a few tokens.
- We exit upon the first match. This avoids the need to rank matches.
"""
def prompt_lookup(
self,
prompt: str,
max_tokens: int,
n_draft: int,
ngram_max: int,
ngram_min: int,
temp: float,
seed: int,
color: bool,
):
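# n-gram sizes are tried from ngram_max downward (longest first); n_draft
# caps how many continuation tokens are speculated per matched n-gram.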
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
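# sample() is greedy argmax at temp == 0 and temperature-scaled categorical
# sampling otherwise.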
mx.random.seed(seed)
print("[INFO] Starting generation...")
@@ -82,7 +92,7 @@ class PromptLookupDecoder:
# prefill model
logit, cache = self.model(prompt[None])
token = sample(logit[:, -1, :])
tokens = mx.concatenate([tokens, token])
prompt_time = time.time() - start
print(self.tokenizer.decode(token.tolist()), end="", flush=True)
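# The prefill pass above runs the whole prompt through the model in a single
# forward call; cache holds the per-layer (keys, values) that later steps
# extend or rewind.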
@@ -97,18 +107,18 @@ class PromptLookupDecoder:
def generate_draft(input_ids):
input_length = input_ids.size
for ngram_size in range(ngram_max, ngram_min, -1):
ngram = input_ids[-ngram_size:]
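# The trailing ngram_size tokens form the search key; the loop below scans
# earlier positions in the context for a matching window.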
for i in range(input_length - ngram_size):
if mx.all(input_ids[i : i + ngram_size] == ngram):
start_idx = i + ngram_size
end_idx = start_idx + n_draft
if start_idx < input_length - ngram_size:
return input_ids[start_idx:end_idx]
return mx.array([], dtype=mx.uint32)
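# generate_draft returns an empty array when no n-gram recurs; the step then
# degenerates to ordinary one-token decoding.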
draft_tokens = generate_draft(tokens)
n_drafted += draft_tokens.size
@@ -117,7 +127,7 @@ class PromptLookupDecoder:
logits, cache = self.model(verify_tokens[None], cache=cache)
sampled = sample(logits).squeeze(0)
# Only keep samples that match the draft.
equal_toks = sampled[:-1] == draft_tokens
num_to_accept = (equal_toks.tolist() + [False]).index(False)
new_tokens = sampled[: max(1, num_to_accept + 1)]
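# Accept the matching prefix plus the model's own next token; max(1, ...)
# guarantees at least one token per step even if the whole draft is rejected.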
@@ -127,20 +137,22 @@ class PromptLookupDecoder:
# Rewind the cache for unaccepted tokens:
if (num_to_truncate := draft_tokens.size - num_to_accept) > 0:
if num_to_truncate < cache[0][0].shape[2]:
cache = tree_map(lambda x: x[:, :, :-num_to_truncate, :], cache)
else:
cache = [None] * len(self.model.layers)
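# Rewinding trims the rejected positions from each layer's K/V cache along
# the sequence axis (axis 2); if the trim would consume the entire cache,
# reset it instead.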
n_decoding_steps += 1
# Check stop decoding criteria:
for t in new_tokens.tolist()[:-1]:
if t == self.tokenizer.eos_id:
break
if color:
print(
"\033[34m" + self.tokenizer.decode([t]) + "\033[30m",
end="",
flush=True,
)
else:
print(self.tokenizer.decode([t]), end="", flush=True)
@@ -155,11 +167,17 @@ class PromptLookupDecoder:
end = time.time()
print()
print("=== PROMPT EVAL IN", round(prompt_time, 2), "SECONDS ===")
print("=== GENERATED", n_generated, "TOKENS IN", round(end - start, 2), "SECONDS ===")
print(
"=== GENERATED",
n_generated,
"TOKENS IN",
round(end - start, 2),
"SECONDS ===",
)
print("=== ACCEPTED", n_accepted, "DRAFT TOKENS ===")
print("=== ACCEPT", round(n_accepted/n_generated * 100, 2), "% ===")
print("=== ACCEPT", round(n_accepted / n_generated * 100, 2), "% ===")
print("=== DECODING STEPS", n_decoding_steps, "===")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Mistral inference script")
@@ -199,23 +217,15 @@ if __name__ == "__main__":
default=1,
help="Minimum ngrams to match against input during prompt lookup",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.0,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
parser.add_argument(
"--color",
type=bool,
default=False,
help="Color the accepted draft tokens"
"--color", type=bool, default=False, help="Color the accepted draft tokens"
)
args = parser.parse_args()
@@ -226,14 +236,12 @@ if __name__ == "__main__":
engine = PromptLookupDecoder(args.model_path)
engine.prompt_lookup(
args.prompt,
args.max_tokens,
args.n_draft,
args.ngram_max,
args.ngram_min,
args.temp,
args.seed,
args.color,
)

View File

@@ -19,6 +19,7 @@ def create_additive_causal_mask(N: int, offset: int = 0, dtype: mx.Dtype = mx.fl
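# The additive mask uses -1e9 as a stand-in for -inf so masked (future)
# positions get ~zero attention weight after the softmax.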
mask = mask.astype(dtype) * -1e9
return mask
@dataclass
class ModelArgs:
dim: int
@@ -99,7 +100,7 @@ class Attention(nn.Module):
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.wo(output), (keys, values)
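# Returning (keys, values) alongside the projected output lets the caller
# thread the KV cache through successive decoding steps.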
@@ -169,7 +170,7 @@ class Mistral(nn.Module):
mask = mask.astype(self.tok_embeddings.weight.dtype)
else:
mask = None
x = self.tok_embeddings(x)
if cache is None:
@@ -207,7 +208,7 @@ class Tokenizer:
if t and self._model.id_to_piece(t[0])[0] == self._sep:
return " " + out
return out
def load_model(folder: str):
model_path = Path(folder)
@@ -225,4 +226,4 @@ def load_model(folder: str):
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(weights)
mx.eval(model.parameters())
return model, tokenizer

View File

@@ -1,9 +0,0 @@
from decoder import PromptLookupDecoder
prompt = "[INST] Repeat the following phrase 10 times: 'The quick brown fox jumps over the lazy dog.'. Don't say antyhing else. [/INST] "
engine = PromptLookupDecoder("mlx_model")
engine.generate(prompt, 250)
engine.prompt_lookup(prompt, 250, 10, 3, 1, 0.0, 0, True)
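# Positional arguments map to (prompt, max_tokens, n_draft, ngram_max,
# ngram_min, temp, seed, color) in prompt_lookup's signature.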