From 3c7e28dd618a3cf18129ffc9ced605b3bb17d08a Mon Sep 17 00:00:00 2001 From: Leon Ericsson Date: Thu, 28 Dec 2023 22:13:47 +0100 Subject: [PATCH] updated implementation to match #149 --- llms/lookup_decoding/decoder.py | 209 -------------- llms/lookup_decoding/test.py | 117 -------- .../.gitignore | 0 .../README.md | 0 .../convert.py | 0 lookup_decoding/decoder.py | 263 ++++++++++++++++++ .../mistral.py | 7 +- .../requirements.txt | 0 lookup_decoding/test.py | 9 + 9 files changed, 277 insertions(+), 328 deletions(-) delete mode 100644 llms/lookup_decoding/decoder.py delete mode 100644 llms/lookup_decoding/test.py rename {llms/lookup_decoding => lookup_decoding}/.gitignore (100%) rename {llms/lookup_decoding => lookup_decoding}/README.md (100%) rename {llms/lookup_decoding => lookup_decoding}/convert.py (100%) create mode 100644 lookup_decoding/decoder.py rename {llms/lookup_decoding => lookup_decoding}/mistral.py (98%) rename {llms/lookup_decoding => lookup_decoding}/requirements.txt (100%) create mode 100644 lookup_decoding/test.py diff --git a/llms/lookup_decoding/decoder.py b/llms/lookup_decoding/decoder.py deleted file mode 100644 index d7f2509e..00000000 --- a/llms/lookup_decoding/decoder.py +++ /dev/null @@ -1,209 +0,0 @@ -import mlx.core as mx -import mlx.nn as nn -from mlx.utils import tree_unflatten, tree_map -import argparse -import time -import json -from typing import List, Optional, Tuple -from mistral import Mistral, Tokenizer, ModelArgs -from pathlib import Path - -class MistralEngine: - - def __init__(self, model: str) -> None: - model, tokenizer = self.load_model(model) - self.model = model - self.tokenizer = tokenizer - - def load_model(self, folder: str): - model_path = Path(folder) - tokenizer = Tokenizer(str(model_path / "tokenizer.model")) - with open(model_path / "config.json", "r") as f: - config = json.loads(f.read()) - config.pop("sliding_window", None) - config.pop("model_type", None) - quantization = config.pop("quantization", None) - model_args = ModelArgs(**config) - weights = mx.load(str(model_path / "weights.npz")) - weights = tree_unflatten(list(weights.items())) - model = Mistral(model_args) - if quantization is not None: - nn.QuantizedLinear.quantize_module(model, **quantization) - model.update(weights) - mx.eval(model.parameters()) - return model, tokenizer - - # note to self: do something about 'self' - """ - Considerations: - - If a match is found but we can't draft n_draft tokens, do we draft as - many as we can or check for a match with a smaller ngram size? - - How do we choose if there are multiple matches? - - This implementation: - - Ignores a match if we can't draft n_draft tokens. This avoids the risk - of only drafting a few tokens. - - We exit upon the first match. This avoids the need to rank matches. 
- """ - def prompt_lookup(self, input_ids, ngram_max, ngram_min, n_draft): - input_length = input_ids.size - - for ngram_size in range(ngram_max, ngram_min, -1): - ngram = input_ids[0, -ngram_size:] - - for i in range(input_length - ngram_size): - if mx.all(input_ids[0, i:i+ngram_size] == ngram): - start_idx = i + ngram_size - end_idx = start_idx + n_draft - if start_idx < input_length - ngram_size: - return input_ids[0, start_idx:end_idx] - - return mx.array([]) - - - def generate(self, prompt: str, max_tokens: int, n_draft: int, ngram_max: int, ngram_min: int, temp: float, seed: int): - def sample(logits): - if temp == 0: - return mx.argmax(logits, axis=-1) - else: - return mx.random.categorical(logits * (1 / temp)) - - mx.random.seed(seed) - - print("[INFO] Starting generation...") - tic = time.time() - print(prompt, end="", flush=True) - prompt = mx.array(self.tokenizer.encode(prompt)) - tokens = prompt # will store all tokens generated (for prompt lookup) - - # prefill model - logit, cache = self.model(prompt[None]) - token = sample(logit[:, -1, :]) - tokens = mx.concatenate([tokens, token]) - - n_drafted = 0 - n_accepted = 0 - n_generated = 1 - n_past = prompt.size - - while True: - # For each decoding step: generate n_draft tokens by searching the prompt - draft_tokens = self.prompt_lookup(tokens, ngram_max, ngram_min, n_draft) - n_drafted += draft_tokens.size - - - verify_tokens = mx.concatenate([tokens[-1], draft_tokens]) - logits, cache = self.model(verify_tokens[None, :-1], cache=cache) - - - logits = logits[:, :-1, :] - sampled = sample(logits) - - num_to_accept = 0 - for i in range(n_draft): - if mx.all(sampled[:, i] == draft_tokens[:, i]): - num_to_accept += 1 - else: - break - - n_past += num_to_accept - n_accepted += num_to_accept - n_generated += (1 + num_to_accept) - - accepted_tokens = sampled[:, :num_to_accept + 1] - tokens = mx.concatenate([tokens, accepted_tokens], axis=1) - - mx.eval(accepted_tokens) - s = self.tokenizer.decode([t.item() for t in accepted_tokens]) - print(s, end="", flush=True) - - # truncate kv cache to keep only accepted tokens - # self.model.truncate_kv_cache(n - num_to_accept) - cache_length = cache[0][0].shape[2] - num_to_truncate = min(num_to_truncate, cache_length) - if num_to_truncate == 0: - pass - else: - cache = tree_map(lambda x: x[:, :, :-num_to_truncate, :], cache) - - if n_accepted >= max_tokens or mx.any(accepted_tokens == self.tokenizer.eos_token_id): - break - - mx.eval(tokens) - s = self.tokenizer.decode([t.item() for t in tokens]) - print(s, flush=True) - print("------") - generation_tps = ntoks / (time.time() - tic) - print( - f"Tokens per second: prompt {prompt_tps:.3f}, " - f"generation {generation_tps:.3f}" - ) - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Mistral inference script") - parser.add_argument( - "--model-path", - type=str, - default="mlx_model", - help="The path to the model weights and tokenizer", - ) - parser.add_argument( - "--prompt", - help="The message to be processed by the model", - default="This is a test. This is a test. 
This is a", - ) - parser.add_argument( - "--max-tokens", - "-m", - type=int, - default=100, - help="Maximum number of tokens to generate", - ) - parser.add_argument( - "--n-draft", - type=int, - default=10, - help="Number of draft tokens to generate upon prompt lookup match", - ) - parser.add_argument( - "--ngram-max", - type=int, - default=3, - help="Maximum ngrams to match against input during prompt lookup", - ) - parser.add_argument( - "--ngram-min", - type=int, - default=1, - help="Minimum ngrams to match against input during prompt lookup", - ) - parser.add_argument( - "--temp", - help="The sampling temperature.", - type=float, - default=0.0, - ) - parser.add_argument( - "--seed", - type=int, - default=0, - help="The PRNG seed" - ) - - args = parser.parse_args() - - mx.random.seed(args.seed) - print("[INFO] Loading model from disk.") - - engine = MistralEngine(args.model_path) - engine.generate( - args.prompt, - args.max_tokens, - args.n_draft, - args.ngram_max, - args.ngram_min, - args.temp, - args.seed - ) - - diff --git a/llms/lookup_decoding/test.py b/llms/lookup_decoding/test.py deleted file mode 100644 index 25385626..00000000 --- a/llms/lookup_decoding/test.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright © 2023 Apple Inc. - -import unittest - -import mistral -import mlx.core as mx -from mlx.utils import tree_map - - -class TestMistral(unittest.TestCase): - def test_model(self): - vocab_size = 100 - L = 32 - args = mistral.ModelArgs( - dim=128, - n_layers=2, - head_dim=32, - hidden_dim=256, - n_heads=4, - n_kv_heads=4, - norm_eps=1e-3, - vocab_size=vocab_size, - ) - - model = mistral.Mistral(args) - inputs = mx.random.randint(0, vocab_size, (L,)) - logits, cache = model(inputs[None]) - self.assertEqual(logits.shape, [1, L, vocab_size]) - self.assertEqual(logits.dtype, mx.float32) - self.assertEqual(len(cache), args.n_layers) - - params = tree_map(lambda p: p.astype(mx.float16), model.parameters()) - model.update(params) - logits, _ = model(inputs[None]) - self.assertEqual(logits.dtype, mx.float16) - - def test_generate(self): - model, tokenizer = mistral.load_model("mistral-7B-v0.1") - prompt = mx.array(tokenizer.encode("This is a test")) - tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(30))] - mx.eval(tokens) - tokens = [t.item() for t in tokens] - expected = [ - 302, - 272, - 11843, - 11837, - 1587, - 28723, - 851, - 349, - 865, - 264, - 1369, - 28723, - 13, - 13, - 3381, - 456, - 654, - 264, - 1353, - 11843, - 28725, - 368, - 682, - 347, - 2240, - 767, - 298, - 511, - 28723, - 13, - ] - self.assertEqual(tokens, expected) - - def benchmark(self): - import time - - model, tokenizer = mistral.load_model("mistral-7B-v0.1") - prompt = mx.random.randint(0, model.vocab_size, (128,)) - - # warmup - for _ in range(2): - generator = mistral.generate(prompt, model) - mx.eval(next(generator)) - - tic = time.time() - its = 5 - for _ in range(its): - generator = mistral.generate(prompt, model) - mx.eval(next(generator)) - toc = time.time() - tps = its * prompt.size / (toc - tic) - print(f"Prompt processing: {tps:.2f} tokens per second") - - # warmup - for _ in range(2): - tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(101))] - mx.eval(tokens) - - time_total = 0.0 - its = 2 - for _ in range(its): - generator = mistral.generate(prompt, model) - mx.eval(next(generator)) - tic = time.time() - tokens = [t for t, _ in zip(generator, range(100))] - mx.eval(tokens) - time_total += time.time() - tic - - tps = len(tokens) * its / time_total - print(f"Token 
generation: {tps:.3f} tokens per second") - - -if __name__ == "__main__": - unittest.main() diff --git a/llms/lookup_decoding/.gitignore b/lookup_decoding/.gitignore similarity index 100% rename from llms/lookup_decoding/.gitignore rename to lookup_decoding/.gitignore diff --git a/llms/lookup_decoding/README.md b/lookup_decoding/README.md similarity index 100% rename from llms/lookup_decoding/README.md rename to lookup_decoding/README.md diff --git a/llms/lookup_decoding/convert.py b/lookup_decoding/convert.py similarity index 100% rename from llms/lookup_decoding/convert.py rename to lookup_decoding/convert.py diff --git a/lookup_decoding/decoder.py b/lookup_decoding/decoder.py new file mode 100644 index 00000000..996ab35d --- /dev/null +++ b/lookup_decoding/decoder.py @@ -0,0 +1,263 @@ +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten, tree_map +import argparse +import time +import json +from mistral import Mistral, Tokenizer, ModelArgs +from pathlib import Path + +class PromptLookupDecoder: + + def __init__(self, model: str) -> None: + model, tokenizer = self.load_model(model) + self.model = model + self.tokenizer = tokenizer + + def load_model(self, folder: str): + model_path = Path(folder) + tokenizer = Tokenizer(str(model_path / "tokenizer.model")) + with open(model_path / "config.json", "r") as f: + config = json.loads(f.read()) + config.pop("sliding_window", None) + config.pop("model_type", None) + quantization = config.pop("quantization", None) + model_args = ModelArgs(**config) + weights = mx.load(str(model_path / "weights.npz")) + weights = tree_unflatten(list(weights.items())) + model = Mistral(model_args) + if quantization is not None: + nn.QuantizedLinear.quantize_module(model, **quantization) + model.update(weights) + mx.eval(model.parameters()) + return model, tokenizer + + def _generate( + self, + x: mx.array, + temp: float = 0.0, + ): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + logits, cache = self.model(x[None]) + y = sample(logits[:, -1, :]) + yield y + + while True: + logits, cache = self.model(x[None, :], cache, next_token_only=True) + x = sample(logits) + yield x + + # Generate without prompt lookup decoding (for testing) + def generate( + self, + prompt, + max_tokens: int = 100, + temp: float = 0.0, + ): + print("[INFO] Starting generation...") + print(prompt, end="", flush=True) + x = mx.array(self.tokenizer.encode(prompt), mx.uint32) + + start = time.time() + for token, n in zip(self._generate(x, temp), range(max_tokens)): + token = token.item() + if token == self.tokenizer.eos_id: + break + print(self.tokenizer.decode([token]), end="", flush=True) + run_time = time.time() - start + print() + print(f"=== GENERATED {n + 1} TOKENS IN {run_time} SECONDS ===") + + """ + Considerations: + - If a match is found but we can't draft n_draft tokens, do we draft as + many as we can or check for a match with a smaller ngram size? + - How do we choose if there are multiple matches? + + This implementation: + - Ignores a match if we can't draft n_draft tokens. This avoids the risk + of only drafting a few tokens. + - We exit upon the first match. This avoids the need to rank matches. 
+ """ + def prompt_lookup(self, prompt: str, max_tokens: int, n_draft: int, + ngram_max: int, ngram_min: int, temp: float, seed: int, + color: bool): + + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + mx.random.seed(seed) + + print("[INFO] Starting generation...") + start = time.time() + print(prompt, end="", flush=True) + prompt = mx.array(self.tokenizer.encode(prompt), mx.uint32) + tokens = prompt + + # prefill model + logit, cache = self.model(prompt[None]) + token = sample(logit[:, -1, :]) + tokens = mx.concatenate([tokens, token]) + prompt_time = time.time() - start + print(self.tokenizer.decode(token.tolist()), end="", flush=True) + + n_drafted = 0 + n_accepted = 0 + n_generated = 1 + n_decoding_steps = 0 + + while True: + # For each decoding step: generate n_draft tokens by searching the prompt + def generate_draft(input_ids): + input_length = input_ids.size + + for ngram_size in range(ngram_max, ngram_min, -1): + ngram = input_ids[-ngram_size:] + + for i in range(input_length - ngram_size): + if mx.all(input_ids[i:i+ngram_size] == ngram): + start_idx = i + ngram_size + end_idx = start_idx + n_draft + if start_idx < input_length - ngram_size: + return input_ids[start_idx:end_idx] + + return mx.array([], dtype=mx.uint32) + + draft_tokens = generate_draft(tokens) + n_drafted += draft_tokens.size + + # Verify draft tokens with the last verified token + verify_tokens = mx.concatenate([tokens[-1:], draft_tokens]) + logits, cache = self.model(verify_tokens[None], cache=cache) + sampled = sample(logits).squeeze(0) + + # Only keep samples that match the draft. + equal_toks = sampled[:-1] == draft_tokens + num_to_accept = (equal_toks.tolist() + [False]).index(False) + new_tokens = sampled[: max(1, num_to_accept + 1)] + + n_accepted += num_to_accept + + # Rewind the cache for unaccepted tokens: + if (num_to_truncate := draft_tokens.size - num_to_accept) > 0: + if num_to_truncate < cache[0][0].shape[2]: + cache = tree_map( + lambda x: x[:, :, :-num_to_truncate, :], cache + ) + else: + cache = [None] * len(self.model.layers) + + n_decoding_steps += 1 + + # Check stop decodig criteria: + for t in new_tokens.tolist()[:-1]: + if t == self.tokenizer.eos_id: + break + if (color): + print("\033[34m" + self.tokenizer.decode([t]) + "\033[30m", end="", flush=True) + else: + print(self.tokenizer.decode([t]), end="", flush=True) + + print(self.tokenizer.decode(new_tokens[-1:].tolist()), end="", flush=True) + + n_generated += new_tokens.size + if n_generated >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id: + break + + tokens = mx.concatenate([tokens, new_tokens]) + + end = time.time() + print() + print("=== PROMPT EVAL IN", round(prompt_time, 2), "SECONDS ===") + print("=== GENERATED", n_generated, "TOKENS IN", round(end - start, 2), "SECONDS ===") + print("=== ACCEPTED", n_accepted, "DRAFT TOKENS ===") + print("=== ACCEPT", round(n_accepted/n_generated * 100, 2), "% ===") + print("=== DECODING STEPS", n_decoding_steps, "===") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Mistral inference script") + parser.add_argument( + "--model-path", + type=str, + default="mlx_model", + help="The path to the model weights and tokenizer", + ) + parser.add_argument( + "--prompt", + help="The message to be processed by the model", + default="This is a test. This is a test. 
This is a", + ) + parser.add_argument( + "--max-tokens", + "-m", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--n-draft", + type=int, + default=10, + help="Number of draft tokens to generate upon prompt lookup match", + ) + parser.add_argument( + "--ngram-max", + type=int, + default=3, + help="Maximum ngrams to match against input during prompt lookup", + ) + parser.add_argument( + "--ngram-min", + type=int, + default=1, + help="Minimum ngrams to match against input during prompt lookup", + ) + parser.add_argument( + "--temp", + help="The sampling temperature.", + type=float, + default=0.0, + ) + parser.add_argument( + "--seed", + type=int, + default=0, + help="The PRNG seed" + ) + parser.add_argument( + "--color", + type=bool, + default=True, + help="Color the accepted draft tokens" + ) + + args = parser.parse_args() + + mx.random.seed(args.seed) + print("[INFO] Loading model from disk.") + + engine = PromptLookupDecoder(args.model_path) + + engine.generate(args.prompt, args.max_tokens, args.temp) + + """ engine.prompt_lookup( + args.prompt, + args.max_tokens, + args.n_draft, + args.ngram_max, + args.ngram_min, + args.temp, + args.seed, + args.color + ) """ + + diff --git a/llms/lookup_decoding/mistral.py b/lookup_decoding/mistral.py similarity index 98% rename from llms/lookup_decoding/mistral.py rename to lookup_decoding/mistral.py index 9880aff7..d481e343 100644 --- a/llms/lookup_decoding/mistral.py +++ b/lookup_decoding/mistral.py @@ -1,5 +1,4 @@ # Copyright © 2023 Apple Inc. - import argparse import json import time @@ -158,6 +157,7 @@ class Mistral(nn.Module): self, x: mx.array, cache=None, + next_token_only: bool = False, ): if cache is not None: offset = cache[0][0].shape[-2] @@ -178,6 +178,9 @@ class Mistral(nn.Module): for e, layer in enumerate(self.layers): x, cache[e] = layer(x, mask, cache[e]) + if next_token_only: + x = x[:, -1] + return self.output(self.norm(x)), cache @@ -272,7 +275,7 @@ if __name__ == "__main__": "--tokens_per_eval", help="The batch size of tokens to generate.", type=int, - default=10, + default=1, ) parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") diff --git a/llms/lookup_decoding/requirements.txt b/lookup_decoding/requirements.txt similarity index 100% rename from llms/lookup_decoding/requirements.txt rename to lookup_decoding/requirements.txt diff --git a/lookup_decoding/test.py b/lookup_decoding/test.py new file mode 100644 index 00000000..f26dca03 --- /dev/null +++ b/lookup_decoding/test.py @@ -0,0 +1,9 @@ +from decoder import PromptLookupDecoder + +prompt = "[INST] Repeat the following phrase 10 times: 'The quick brown fox jumps over the lazy dog.'. Don't say antyhing else. [/INST] " + +engine = PromptLookupDecoder("mlx_model") + +engine.generate(prompt, 250) + +engine.prompt_lookup(prompt, 250, 10, 3, 1, 0.0, 0, True)