updated implementation to match #149

Leon Ericsson 2023-12-28 22:13:47 +01:00
parent d2d3b1fbf0
commit 3c7e28dd61
9 changed files with 277 additions and 328 deletions

@@ -1,209 +0,0 @@
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten, tree_map
import argparse
import time
import json
from typing import List, Optional, Tuple
from mistral import Mistral, Tokenizer, ModelArgs
from pathlib import Path
class MistralEngine:
def __init__(self, model: str) -> None:
model, tokenizer = self.load_model(model)
self.model = model
self.tokenizer = tokenizer
def load_model(self, folder: str):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
config.pop("sliding_window", None)
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = ModelArgs(**config)
weights = mx.load(str(model_path / "weights.npz"))
weights = tree_unflatten(list(weights.items()))
model = Mistral(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(weights)
mx.eval(model.parameters())
return model, tokenizer
# note to self: do something about 'self'
"""
Considerations:
- If a match is found but we can't draft n_draft tokens, do we draft as
many as we can or check for a match with a smaller ngram size?
- How do we choose if there are multiple matches?
This implementation:
- Ignores a match if we can't draft n_draft tokens. This avoids the risk
of only drafting a few tokens.
- We exit upon the first match. This avoids the need to rank matches.
"""
def prompt_lookup(self, input_ids, ngram_max, ngram_min, n_draft):
input_length = input_ids.size
for ngram_size in range(ngram_max, ngram_min, -1):
ngram = input_ids[-ngram_size:]
for i in range(input_length - ngram_size):
if mx.all(input_ids[i:i + ngram_size] == ngram):
start_idx = i + ngram_size
end_idx = start_idx + n_draft
if start_idx < input_length - ngram_size:
return input_ids[start_idx:end_idx]
return mx.array([], dtype=input_ids.dtype)
def generate(self, prompt: str, max_tokens: int, n_draft: int, ngram_max: int, ngram_min: int, temp: float, seed: int):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
mx.random.seed(seed)
print("[INFO] Starting generation...")
tic = time.time()
print(prompt, end="", flush=True)
prompt = mx.array(self.tokenizer.encode(prompt))
tokens = prompt # will store all tokens generated (for prompt lookup)
# prefill model
logit, cache = self.model(prompt[None])
token = sample(logit[:, -1, :])
tokens = mx.concatenate([tokens, token])
n_drafted = 0
n_accepted = 0
n_generated = 1
n_past = prompt.size
while True:
# For each decoding step: generate n_draft tokens by searching the prompt
draft_tokens = self.prompt_lookup(tokens, ngram_max, ngram_min, n_draft)
n_drafted += draft_tokens.size
verify_tokens = mx.concatenate([tokens[-1:], draft_tokens])
logits, cache = self.model(verify_tokens[None], cache=cache)
sampled = sample(logits).squeeze(0)
num_to_accept = 0
for i in range(draft_tokens.size):
if mx.all(sampled[i] == draft_tokens[i]):
num_to_accept += 1
else:
break
n_past += num_to_accept
n_accepted += num_to_accept
n_generated += (1 + num_to_accept)
accepted_tokens = sampled[:num_to_accept + 1]
tokens = mx.concatenate([tokens, accepted_tokens])
mx.eval(accepted_tokens)
s = self.tokenizer.decode([t.item() for t in accepted_tokens])
print(s, end="", flush=True)
# truncate kv cache to drop the rejected draft tokens
num_to_truncate = draft_tokens.size - num_to_accept
cache_length = cache[0][0].shape[2]
num_to_truncate = min(num_to_truncate, cache_length)
if num_to_truncate > 0:
cache = tree_map(lambda x: x[:, :, :-num_to_truncate, :], cache)
if n_generated >= max_tokens or mx.any(accepted_tokens == self.tokenizer.eos_id):
break
mx.eval(tokens)
s = self.tokenizer.decode([t.item() for t in tokens])
print(s, flush=True)
print("------")
generation_tps = n_generated / (time.time() - tic)
print(f"Tokens per second: generation {generation_tps:.3f}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Mistral inference script")
parser.add_argument(
"--model-path",
type=str,
default="mlx_model",
help="The path to the model weights and tokenizer",
)
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="This is a test. This is a test. This is a",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--n-draft",
type=int,
default=10,
help="Number of draft tokens to generate upon prompt lookup match",
)
parser.add_argument(
"--ngram-max",
type=int,
default=3,
help="Maximum ngrams to match against input during prompt lookup",
)
parser.add_argument(
"--ngram-min",
type=int,
default=1,
help="Minimum ngrams to match against input during prompt lookup",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.0,
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="The PRNG seed"
)
args = parser.parse_args()
mx.random.seed(args.seed)
print("[INFO] Loading model from disk.")
engine = MistralEngine(args.model_path)
engine.generate(
args.prompt,
args.max_tokens,
args.n_draft,
args.ngram_max,
args.ngram_min,
args.temp,
args.seed
)
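The prompt_lookup search above is easier to follow on a toy sequence. Below is a minimal, self-contained sketch of the same n-gram matching (plain Python lists and made-up token ids rather than mx.array; the helper name is hypothetical): the trailing n-gram is matched against earlier positions, and the tokens that followed the first match become the draft.

# Toy sketch of the n-gram lookup above (hypothetical token ids, plain lists).
def toy_prompt_lookup(input_ids, ngram_max=3, ngram_min=1, n_draft=4):
    input_length = len(input_ids)
    for ngram_size in range(ngram_max, ngram_min, -1):
        ngram = input_ids[-ngram_size:]
        for i in range(input_length - ngram_size):
            if input_ids[i:i + ngram_size] == ngram:
                start_idx = i + ngram_size
                if start_idx < input_length - ngram_size:
                    return input_ids[start_idx:start_idx + n_draft]
    return []

# A repetitive context: the trailing n-gram reappears earlier, so the tokens
# that followed it are proposed as the draft.
tokens = [10, 11, 12, 13, 14, 10, 11, 12, 13, 14, 10, 11]
print(toy_prompt_lookup(tokens))  # [12, 13, 14, 10]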

@@ -1,117 +0,0 @@
# Copyright © 2023 Apple Inc.
import unittest
import mistral
import mlx.core as mx
from mlx.utils import tree_map
class TestMistral(unittest.TestCase):
def test_model(self):
vocab_size = 100
L = 32
args = mistral.ModelArgs(
dim=128,
n_layers=2,
head_dim=32,
hidden_dim=256,
n_heads=4,
n_kv_heads=4,
norm_eps=1e-3,
vocab_size=vocab_size,
)
model = mistral.Mistral(args)
inputs = mx.random.randint(0, vocab_size, (L,))
logits, cache = model(inputs[None])
self.assertEqual(logits.shape, [1, L, vocab_size])
self.assertEqual(logits.dtype, mx.float32)
self.assertEqual(len(cache), args.n_layers)
params = tree_map(lambda p: p.astype(mx.float16), model.parameters())
model.update(params)
logits, _ = model(inputs[None])
self.assertEqual(logits.dtype, mx.float16)
def test_generate(self):
model, tokenizer = mistral.load_model("mistral-7B-v0.1")
prompt = mx.array(tokenizer.encode("This is a test"))
tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(30))]
mx.eval(tokens)
tokens = [t.item() for t in tokens]
expected = [
302,
272,
11843,
11837,
1587,
28723,
851,
349,
865,
264,
1369,
28723,
13,
13,
3381,
456,
654,
264,
1353,
11843,
28725,
368,
682,
347,
2240,
767,
298,
511,
28723,
13,
]
self.assertEqual(tokens, expected)
def benchmark(self):
import time
model, tokenizer = mistral.load_model("mistral-7B-v0.1")
prompt = mx.random.randint(0, model.vocab_size, (128,))
# warmup
for _ in range(2):
generator = mistral.generate(prompt, model)
mx.eval(next(generator))
tic = time.time()
its = 5
for _ in range(its):
generator = mistral.generate(prompt, model)
mx.eval(next(generator))
toc = time.time()
tps = its * prompt.size / (toc - tic)
print(f"Prompt processing: {tps:.2f} tokens per second")
# warmup
for _ in range(2):
tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(101))]
mx.eval(tokens)
time_total = 0.0
its = 2
for _ in range(its):
generator = mistral.generate(prompt, model)
mx.eval(next(generator))
tic = time.time()
tokens = [t for t, _ in zip(generator, range(100))]
mx.eval(tokens)
time_total += time.time() - tic
tps = len(tokens) * its / time_total
print(f"Token generation: {tps:.3f} tokens per second")
if __name__ == "__main__":
unittest.main()

lookup_decoding/decoder.py Normal file (263 lines added)
@@ -0,0 +1,263 @@
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten, tree_map
import argparse
import time
import json
from mistral import Mistral, Tokenizer, ModelArgs
from pathlib import Path
class PromptLookupDecoder:
def __init__(self, model: str) -> None:
model, tokenizer = self.load_model(model)
self.model = model
self.tokenizer = tokenizer
def load_model(self, folder: str):
model_path = Path(folder)
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
with open(model_path / "config.json", "r") as f:
config = json.loads(f.read())
config.pop("sliding_window", None)
config.pop("model_type", None)
quantization = config.pop("quantization", None)
model_args = ModelArgs(**config)
weights = mx.load(str(model_path / "weights.npz"))
weights = tree_unflatten(list(weights.items()))
model = Mistral(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
model.update(weights)
mx.eval(model.parameters())
return model, tokenizer
def _generate(
self,
x: mx.array,
temp: float = 0.0,
):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
logits, cache = self.model(x[None])
y = sample(logits[:, -1, :])
yield y
x = y
while True:
logits, cache = self.model(x[None, :], cache, next_token_only=True)
x = sample(logits)
yield x
# Generate without prompt lookup decoding (for testing)
def generate(
self,
prompt,
max_tokens: int = 100,
temp: float = 0.0,
):
print("[INFO] Starting generation...")
print(prompt, end="", flush=True)
x = mx.array(self.tokenizer.encode(prompt), mx.uint32)
start = time.time()
for token, n in zip(self._generate(x, temp), range(max_tokens)):
token = token.item()
if token == self.tokenizer.eos_id:
break
print(self.tokenizer.decode([token]), end="", flush=True)
run_time = time.time() - start
print()
print(f"=== GENERATED {n + 1} TOKENS IN {run_time} SECONDS ===")
"""
Considerations:
- If a match is found but we can't draft n_draft tokens, do we draft as
many as we can or check for a match with a smaller ngram size?
- How do we choose if there are multiple matches?
This implementation:
- Ignores a match if we can't draft n_draft tokens. This avoids the risk
of only drafting a few tokens.
- We exit upon the first match. This avoids the need to rank matches.
"""
def prompt_lookup(self, prompt: str, max_tokens: int, n_draft: int,
ngram_max: int, ngram_min: int, temp: float, seed: int,
color: bool):
def sample(logits):
if temp == 0:
return mx.argmax(logits, axis=-1)
else:
return mx.random.categorical(logits * (1 / temp))
mx.random.seed(seed)
print("[INFO] Starting generation...")
start = time.time()
print(prompt, end="", flush=True)
prompt = mx.array(self.tokenizer.encode(prompt), mx.uint32)
tokens = prompt
# prefill model
logit, cache = self.model(prompt[None])
token = sample(logit[:, -1, :])
tokens = mx.concatenate([tokens, token])
prompt_time = time.time() - start
print(self.tokenizer.decode(token.tolist()), end="", flush=True)
n_drafted = 0
n_accepted = 0
n_generated = 1
n_decoding_steps = 0
while True:
# For each decoding step: generate n_draft tokens by searching the prompt
def generate_draft(input_ids):
input_length = input_ids.size
for ngram_size in range(ngram_max, ngram_min, -1):
ngram = input_ids[-ngram_size:]
for i in range(input_length - ngram_size):
if mx.all(input_ids[i:i+ngram_size] == ngram):
start_idx = i + ngram_size
end_idx = start_idx + n_draft
if start_idx < input_length - ngram_size:
return input_ids[start_idx:end_idx]
return mx.array([], dtype=mx.uint32)
draft_tokens = generate_draft(tokens)
n_drafted += draft_tokens.size
# Verify draft tokens with the last verified token
verify_tokens = mx.concatenate([tokens[-1:], draft_tokens])
logits, cache = self.model(verify_tokens[None], cache=cache)
sampled = sample(logits).squeeze(0)
# Only keep samples that match the draft.
equal_toks = sampled[:-1] == draft_tokens
num_to_accept = (equal_toks.tolist() + [False]).index(False)
new_tokens = sampled[: max(1, num_to_accept + 1)]
n_accepted += num_to_accept
# Rewind the cache for unaccepted tokens:
if (num_to_truncate := draft_tokens.size - num_to_accept) > 0:
if num_to_truncate < cache[0][0].shape[2]:
cache = tree_map(
lambda x: x[:, :, :-num_to_truncate, :], cache
)
else:
cache = [None] * len(self.model.layers)
n_decoding_steps += 1
# Check stop decoding criteria:
for t in new_tokens.tolist()[:-1]:
if t == self.tokenizer.eos_id:
break
if color:
print("\033[34m" + self.tokenizer.decode([t]) + "\033[0m", end="", flush=True)
else:
print(self.tokenizer.decode([t]), end="", flush=True)
print(self.tokenizer.decode(new_tokens[-1:].tolist()), end="", flush=True)
n_generated += new_tokens.size
if n_generated >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
break
tokens = mx.concatenate([tokens, new_tokens])
end = time.time()
print()
print("=== PROMPT EVAL IN", round(prompt_time, 2), "SECONDS ===")
print("=== GENERATED", n_generated, "TOKENS IN", round(end - start, 2), "SECONDS ===")
print("=== ACCEPTED", n_accepted, "DRAFT TOKENS ===")
print("=== ACCEPT", round(n_accepted/n_generated * 100, 2), "% ===")
print("=== DECODING STEPS", n_decoding_steps, "===")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Mistral inference script")
parser.add_argument(
"--model-path",
type=str,
default="mlx_model",
help="The path to the model weights and tokenizer",
)
parser.add_argument(
"--prompt",
help="The message to be processed by the model",
default="This is a test. This is a test. This is a",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=100,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--n-draft",
type=int,
default=10,
help="Number of draft tokens to generate upon prompt lookup match",
)
parser.add_argument(
"--ngram-max",
type=int,
default=3,
help="Maximum ngrams to match against input during prompt lookup",
)
parser.add_argument(
"--ngram-min",
type=int,
default=1,
help="Minimum ngrams to match against input during prompt lookup",
)
parser.add_argument(
"--temp",
help="The sampling temperature.",
type=float,
default=0.0,
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="The PRNG seed"
)
parser.add_argument(
"--color",
action=argparse.BooleanOptionalAction,
default=True,
help="Color the accepted draft tokens (pass --no-color to disable)"
)
args = parser.parse_args()
mx.random.seed(args.seed)
print("[INFO] Loading model from disk.")
engine = PromptLookupDecoder(args.model_path)
engine.generate(args.prompt, args.max_tokens, args.temp)
""" engine.prompt_lookup(
args.prompt,
args.max_tokens,
args.n_draft,
args.ngram_max,
args.ngram_min,
args.temp,
args.seed,
args.color
) """
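The accept-and-rewind step in prompt_lookup is the core of the method: sampled tokens are compared position by position with the draft, acceptance stops at the first mismatch, and the key/value cache is trimmed back by the number of rejected draft positions. Below is a minimal sketch of just that step, with made-up token values and an illustrative single-layer cache; the shapes are assumptions for the example only.

import mlx.core as mx
from mlx.utils import tree_map

# Hypothetical outcome of one verification step: 4 draft tokens, and the model's
# samples for the last verified token plus those 4 draft positions.
draft_tokens = mx.array([5, 6, 7, 8], dtype=mx.uint32)
sampled = mx.array([5, 6, 9, 1, 2], dtype=mx.uint32)

# Accept until the first disagreement, then keep one extra "free" token.
equal_toks = sampled[:-1] == draft_tokens
num_to_accept = (equal_toks.tolist() + [False]).index(False)
new_tokens = sampled[: max(1, num_to_accept + 1)]

# Illustrative single-layer cache: (keys, values), each (batch, n_kv_heads, seq_len, head_dim).
cache = [(mx.zeros((1, 4, 10, 8)), mx.zeros((1, 4, 10, 8)))]
num_to_truncate = draft_tokens.size - num_to_accept
if num_to_truncate > 0:
    cache = tree_map(lambda x: x[:, :, :-num_to_truncate, :], cache)

print(num_to_accept, new_tokens.tolist())  # 2 [5, 6, 9]
print(cache[0][0].shape)                   # seq_len trimmed from 10 to 8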

@@ -1,5 +1,4 @@
# Copyright © 2023 Apple Inc.
import argparse
import json
import time
@@ -158,6 +157,7 @@ class Mistral(nn.Module):
self,
x: mx.array,
cache=None,
next_token_only: bool = False,
):
if cache is not None:
offset = cache[0][0].shape[-2]
@@ -178,6 +178,9 @@
for e, layer in enumerate(self.layers):
x, cache[e] = layer(x, mask, cache[e])
if next_token_only:
x = x[:, -1]
return self.output(self.norm(x)), cache
@@ -272,7 +275,7 @@ if __name__ == "__main__":
"--tokens_per_eval",
help="The batch size of tokens to generate.",
type=int,
default=10,
default=1,
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
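The next_token_only flag added to Mistral.__call__ above only changes what the forward pass returns: the hidden states are sliced to the last position just before the final norm and output projection, so the caller gets logits of shape (batch, vocab_size) instead of (batch, seq_len, vocab_size). A minimal sketch of a greedy single-token decode loop using it follows; the model path, the 20-step limit, and greedy sampling are illustrative choices, not part of the commit.

import mlx.core as mx
from decoder import PromptLookupDecoder  # lookup_decoding/decoder.py above

engine = PromptLookupDecoder("mlx_model")   # same default path as the scripts
model, tokenizer = engine.model, engine.tokenizer

prompt = mx.array(tokenizer.encode("This is a test"), mx.uint32)

logits, cache = model(prompt[None])          # prefill: logits are (1, seq_len, vocab_size)
y = mx.argmax(logits[:, -1, :], axis=-1)     # first generated token, shape (1,)

for _ in range(20):                          # fixed step count, illustration only
    logits, cache = model(y[None, :], cache, next_token_only=True)
    y = mx.argmax(logits, axis=-1)           # logits are already (1, vocab_size)
    print(tokenizer.decode(y.tolist()), end="", flush=True)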

lookup_decoding/test.py Normal file (9 lines added)
@@ -0,0 +1,9 @@
from decoder import PromptLookupDecoder
prompt = "[INST] Repeat the following phrase 10 times: 'The quick brown fox jumps over the lazy dog.'. Don't say anything else. [/INST] "
engine = PromptLookupDecoder("mlx_model")
engine.generate(prompt, 250)
engine.prompt_lookup(prompt, 250, 10, 3, 1, 0.0, 0, True)
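For readability, the positional arguments in the prompt_lookup call above map onto the parameters defined in decoder.py as follows (same call, keywords spelled out):

engine.prompt_lookup(
    prompt,
    max_tokens=250,
    n_draft=10,
    ngram_max=3,
    ngram_min=1,
    temp=0.0,
    seed=0,
    color=True,
)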