mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
updated implementation to match #149
This commit is contained in:
@@ -1,209 +0,0 @@
|
|||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
from mlx.utils import tree_unflatten, tree_map
|
|
||||||
import argparse
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
from mistral import Mistral, Tokenizer, ModelArgs
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
class MistralEngine:
|
|
||||||
|
|
||||||
def __init__(self, model: str) -> None:
|
|
||||||
model, tokenizer = self.load_model(model)
|
|
||||||
self.model = model
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
|
|
||||||
def load_model(self, folder: str):
|
|
||||||
model_path = Path(folder)
|
|
||||||
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
|
||||||
with open(model_path / "config.json", "r") as f:
|
|
||||||
config = json.loads(f.read())
|
|
||||||
config.pop("sliding_window", None)
|
|
||||||
config.pop("model_type", None)
|
|
||||||
quantization = config.pop("quantization", None)
|
|
||||||
model_args = ModelArgs(**config)
|
|
||||||
weights = mx.load(str(model_path / "weights.npz"))
|
|
||||||
weights = tree_unflatten(list(weights.items()))
|
|
||||||
model = Mistral(model_args)
|
|
||||||
if quantization is not None:
|
|
||||||
nn.QuantizedLinear.quantize_module(model, **quantization)
|
|
||||||
model.update(weights)
|
|
||||||
mx.eval(model.parameters())
|
|
||||||
return model, tokenizer
|
|
||||||
|
|
||||||
# note to self: do something about 'self'
|
|
||||||
"""
|
|
||||||
Considerations:
|
|
||||||
- If a match is found but we can't draft n_draft tokens, do we draft as
|
|
||||||
many as we can or check for a match with a smaller ngram size?
|
|
||||||
- How do we choose if there are multiple matches?
|
|
||||||
|
|
||||||
This implementation:
|
|
||||||
- Ignores a match if we can't draft n_draft tokens. This avoids the risk
|
|
||||||
of only drafting a few tokens.
|
|
||||||
- We exit upon the first match. This avoids the need to rank matches.
|
|
||||||
"""
|
|
||||||
def prompt_lookup(self, input_ids, ngram_max, ngram_min, n_draft):
|
|
||||||
input_length = input_ids.size
|
|
||||||
|
|
||||||
for ngram_size in range(ngram_max, ngram_min, -1):
|
|
||||||
ngram = input_ids[0, -ngram_size:]
|
|
||||||
|
|
||||||
for i in range(input_length - ngram_size):
|
|
||||||
if mx.all(input_ids[0, i:i+ngram_size] == ngram):
|
|
||||||
start_idx = i + ngram_size
|
|
||||||
end_idx = start_idx + n_draft
|
|
||||||
if start_idx < input_length - ngram_size:
|
|
||||||
return input_ids[0, start_idx:end_idx]
|
|
||||||
|
|
||||||
return mx.array([])
|
|
||||||
|
|
||||||
|
|
||||||
def generate(self, prompt: str, max_tokens: int, n_draft: int, ngram_max: int, ngram_min: int, temp: float, seed: int):
|
|
||||||
def sample(logits):
|
|
||||||
if temp == 0:
|
|
||||||
return mx.argmax(logits, axis=-1)
|
|
||||||
else:
|
|
||||||
return mx.random.categorical(logits * (1 / temp))
|
|
||||||
|
|
||||||
mx.random.seed(seed)
|
|
||||||
|
|
||||||
print("[INFO] Starting generation...")
|
|
||||||
tic = time.time()
|
|
||||||
print(prompt, end="", flush=True)
|
|
||||||
prompt = mx.array(self.tokenizer.encode(prompt))
|
|
||||||
tokens = prompt # will store all tokens generated (for prompt lookup)
|
|
||||||
|
|
||||||
# prefill model
|
|
||||||
logit, cache = self.model(prompt[None])
|
|
||||||
token = sample(logit[:, -1, :])
|
|
||||||
tokens = mx.concatenate([tokens, token])
|
|
||||||
|
|
||||||
n_drafted = 0
|
|
||||||
n_accepted = 0
|
|
||||||
n_generated = 1
|
|
||||||
n_past = prompt.size
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# For each decoding step: generate n_draft tokens by searching the prompt
|
|
||||||
draft_tokens = self.prompt_lookup(tokens, ngram_max, ngram_min, n_draft)
|
|
||||||
n_drafted += draft_tokens.size
|
|
||||||
|
|
||||||
|
|
||||||
verify_tokens = mx.concatenate([tokens[-1], draft_tokens])
|
|
||||||
logits, cache = self.model(verify_tokens[None, :-1], cache=cache)
|
|
||||||
|
|
||||||
|
|
||||||
logits = logits[:, :-1, :]
|
|
||||||
sampled = sample(logits)
|
|
||||||
|
|
||||||
num_to_accept = 0
|
|
||||||
for i in range(n_draft):
|
|
||||||
if mx.all(sampled[:, i] == draft_tokens[:, i]):
|
|
||||||
num_to_accept += 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
n_past += num_to_accept
|
|
||||||
n_accepted += num_to_accept
|
|
||||||
n_generated += (1 + num_to_accept)
|
|
||||||
|
|
||||||
accepted_tokens = sampled[:, :num_to_accept + 1]
|
|
||||||
tokens = mx.concatenate([tokens, accepted_tokens], axis=1)
|
|
||||||
|
|
||||||
mx.eval(accepted_tokens)
|
|
||||||
s = self.tokenizer.decode([t.item() for t in accepted_tokens])
|
|
||||||
print(s, end="", flush=True)
|
|
||||||
|
|
||||||
# truncate kv cache to keep only accepted tokens
|
|
||||||
# self.model.truncate_kv_cache(n - num_to_accept)
|
|
||||||
cache_length = cache[0][0].shape[2]
|
|
||||||
num_to_truncate = min(num_to_truncate, cache_length)
|
|
||||||
if num_to_truncate == 0:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
cache = tree_map(lambda x: x[:, :, :-num_to_truncate, :], cache)
|
|
||||||
|
|
||||||
if n_accepted >= max_tokens or mx.any(accepted_tokens == self.tokenizer.eos_token_id):
|
|
||||||
break
|
|
||||||
|
|
||||||
mx.eval(tokens)
|
|
||||||
s = self.tokenizer.decode([t.item() for t in tokens])
|
|
||||||
print(s, flush=True)
|
|
||||||
print("------")
|
|
||||||
generation_tps = ntoks / (time.time() - tic)
|
|
||||||
print(
|
|
||||||
f"Tokens per second: prompt {prompt_tps:.3f}, "
|
|
||||||
f"generation {generation_tps:.3f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Mistral inference script")
|
|
||||||
parser.add_argument(
|
|
||||||
"--model-path",
|
|
||||||
type=str,
|
|
||||||
default="mlx_model",
|
|
||||||
help="The path to the model weights and tokenizer",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--prompt",
|
|
||||||
help="The message to be processed by the model",
|
|
||||||
default="This is a test. This is a test. This is a",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-tokens",
|
|
||||||
"-m",
|
|
||||||
type=int,
|
|
||||||
default=100,
|
|
||||||
help="Maximum number of tokens to generate",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--n-draft",
|
|
||||||
type=int,
|
|
||||||
default=10,
|
|
||||||
help="Number of draft tokens to generate upon prompt lookup match",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--ngram-max",
|
|
||||||
type=int,
|
|
||||||
default=3,
|
|
||||||
help="Maximum ngrams to match against input during prompt lookup",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--ngram-min",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Minimum ngrams to match against input during prompt lookup",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--temp",
|
|
||||||
help="The sampling temperature.",
|
|
||||||
type=float,
|
|
||||||
default=0.0,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--seed",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="The PRNG seed"
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
mx.random.seed(args.seed)
|
|
||||||
print("[INFO] Loading model from disk.")
|
|
||||||
|
|
||||||
engine = MistralEngine(args.model_path)
|
|
||||||
engine.generate(
|
|
||||||
args.prompt,
|
|
||||||
args.max_tokens,
|
|
||||||
args.n_draft,
|
|
||||||
args.ngram_max,
|
|
||||||
args.ngram_min,
|
|
||||||
args.temp,
|
|
||||||
args.seed
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
@@ -1,117 +0,0 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import mistral
|
|
||||||
import mlx.core as mx
|
|
||||||
from mlx.utils import tree_map
|
|
||||||
|
|
||||||
|
|
||||||
class TestMistral(unittest.TestCase):
|
|
||||||
def test_model(self):
|
|
||||||
vocab_size = 100
|
|
||||||
L = 32
|
|
||||||
args = mistral.ModelArgs(
|
|
||||||
dim=128,
|
|
||||||
n_layers=2,
|
|
||||||
head_dim=32,
|
|
||||||
hidden_dim=256,
|
|
||||||
n_heads=4,
|
|
||||||
n_kv_heads=4,
|
|
||||||
norm_eps=1e-3,
|
|
||||||
vocab_size=vocab_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
model = mistral.Mistral(args)
|
|
||||||
inputs = mx.random.randint(0, vocab_size, (L,))
|
|
||||||
logits, cache = model(inputs[None])
|
|
||||||
self.assertEqual(logits.shape, [1, L, vocab_size])
|
|
||||||
self.assertEqual(logits.dtype, mx.float32)
|
|
||||||
self.assertEqual(len(cache), args.n_layers)
|
|
||||||
|
|
||||||
params = tree_map(lambda p: p.astype(mx.float16), model.parameters())
|
|
||||||
model.update(params)
|
|
||||||
logits, _ = model(inputs[None])
|
|
||||||
self.assertEqual(logits.dtype, mx.float16)
|
|
||||||
|
|
||||||
def test_generate(self):
|
|
||||||
model, tokenizer = mistral.load_model("mistral-7B-v0.1")
|
|
||||||
prompt = mx.array(tokenizer.encode("This is a test"))
|
|
||||||
tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(30))]
|
|
||||||
mx.eval(tokens)
|
|
||||||
tokens = [t.item() for t in tokens]
|
|
||||||
expected = [
|
|
||||||
302,
|
|
||||||
272,
|
|
||||||
11843,
|
|
||||||
11837,
|
|
||||||
1587,
|
|
||||||
28723,
|
|
||||||
851,
|
|
||||||
349,
|
|
||||||
865,
|
|
||||||
264,
|
|
||||||
1369,
|
|
||||||
28723,
|
|
||||||
13,
|
|
||||||
13,
|
|
||||||
3381,
|
|
||||||
456,
|
|
||||||
654,
|
|
||||||
264,
|
|
||||||
1353,
|
|
||||||
11843,
|
|
||||||
28725,
|
|
||||||
368,
|
|
||||||
682,
|
|
||||||
347,
|
|
||||||
2240,
|
|
||||||
767,
|
|
||||||
298,
|
|
||||||
511,
|
|
||||||
28723,
|
|
||||||
13,
|
|
||||||
]
|
|
||||||
self.assertEqual(tokens, expected)
|
|
||||||
|
|
||||||
def benchmark(self):
|
|
||||||
import time
|
|
||||||
|
|
||||||
model, tokenizer = mistral.load_model("mistral-7B-v0.1")
|
|
||||||
prompt = mx.random.randint(0, model.vocab_size, (128,))
|
|
||||||
|
|
||||||
# warmup
|
|
||||||
for _ in range(2):
|
|
||||||
generator = mistral.generate(prompt, model)
|
|
||||||
mx.eval(next(generator))
|
|
||||||
|
|
||||||
tic = time.time()
|
|
||||||
its = 5
|
|
||||||
for _ in range(its):
|
|
||||||
generator = mistral.generate(prompt, model)
|
|
||||||
mx.eval(next(generator))
|
|
||||||
toc = time.time()
|
|
||||||
tps = its * prompt.size / (toc - tic)
|
|
||||||
print(f"Prompt processing: {tps:.2f} tokens per second")
|
|
||||||
|
|
||||||
# warmup
|
|
||||||
for _ in range(2):
|
|
||||||
tokens = [t for t, _ in zip(mistral.generate(prompt, model), range(101))]
|
|
||||||
mx.eval(tokens)
|
|
||||||
|
|
||||||
time_total = 0.0
|
|
||||||
its = 2
|
|
||||||
for _ in range(its):
|
|
||||||
generator = mistral.generate(prompt, model)
|
|
||||||
mx.eval(next(generator))
|
|
||||||
tic = time.time()
|
|
||||||
tokens = [t for t, _ in zip(generator, range(100))]
|
|
||||||
mx.eval(tokens)
|
|
||||||
time_total += time.time() - tic
|
|
||||||
|
|
||||||
tps = len(tokens) * its / time_total
|
|
||||||
print(f"Token generation: {tps:.3f} tokens per second")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
263
lookup_decoding/decoder.py
Normal file
263
lookup_decoding/decoder.py
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from mlx.utils import tree_unflatten, tree_map
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
from mistral import Mistral, Tokenizer, ModelArgs
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
class PromptLookupDecoder:
|
||||||
|
|
||||||
|
def __init__(self, model: str) -> None:
|
||||||
|
model, tokenizer = self.load_model(model)
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
def load_model(self, folder: str):
|
||||||
|
model_path = Path(folder)
|
||||||
|
tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
|
||||||
|
with open(model_path / "config.json", "r") as f:
|
||||||
|
config = json.loads(f.read())
|
||||||
|
config.pop("sliding_window", None)
|
||||||
|
config.pop("model_type", None)
|
||||||
|
quantization = config.pop("quantization", None)
|
||||||
|
model_args = ModelArgs(**config)
|
||||||
|
weights = mx.load(str(model_path / "weights.npz"))
|
||||||
|
weights = tree_unflatten(list(weights.items()))
|
||||||
|
model = Mistral(model_args)
|
||||||
|
if quantization is not None:
|
||||||
|
nn.QuantizedLinear.quantize_module(model, **quantization)
|
||||||
|
model.update(weights)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
temp: float = 0.0,
|
||||||
|
):
|
||||||
|
def sample(logits):
|
||||||
|
if temp == 0:
|
||||||
|
return mx.argmax(logits, axis=-1)
|
||||||
|
else:
|
||||||
|
return mx.random.categorical(logits * (1 / temp))
|
||||||
|
|
||||||
|
logits, cache = self.model(x[None])
|
||||||
|
y = sample(logits[:, -1, :])
|
||||||
|
yield y
|
||||||
|
|
||||||
|
while True:
|
||||||
|
logits, cache = self.model(x[None, :], cache, next_token_only=True)
|
||||||
|
x = sample(logits)
|
||||||
|
yield x
|
||||||
|
|
||||||
|
# Generate without prompt lookup decoding (for testing)
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
max_tokens: int = 100,
|
||||||
|
temp: float = 0.0,
|
||||||
|
):
|
||||||
|
print("[INFO] Starting generation...")
|
||||||
|
print(prompt, end="", flush=True)
|
||||||
|
x = mx.array(self.tokenizer.encode(prompt), mx.uint32)
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
for token, n in zip(self._generate(x, temp), range(max_tokens)):
|
||||||
|
token = token.item()
|
||||||
|
if token == self.tokenizer.eos_id:
|
||||||
|
break
|
||||||
|
print(self.tokenizer.decode([token]), end="", flush=True)
|
||||||
|
run_time = time.time() - start
|
||||||
|
print()
|
||||||
|
print(f"=== GENERATED {n + 1} TOKENS IN {run_time} SECONDS ===")
|
||||||
|
|
||||||
|
"""
|
||||||
|
Considerations:
|
||||||
|
- If a match is found but we can't draft n_draft tokens, do we draft as
|
||||||
|
many as we can or check for a match with a smaller ngram size?
|
||||||
|
- How do we choose if there are multiple matches?
|
||||||
|
|
||||||
|
This implementation:
|
||||||
|
- Ignores a match if we can't draft n_draft tokens. This avoids the risk
|
||||||
|
of only drafting a few tokens.
|
||||||
|
- We exit upon the first match. This avoids the need to rank matches.
|
||||||
|
"""
|
||||||
|
def prompt_lookup(self, prompt: str, max_tokens: int, n_draft: int,
|
||||||
|
ngram_max: int, ngram_min: int, temp: float, seed: int,
|
||||||
|
color: bool):
|
||||||
|
|
||||||
|
def sample(logits):
|
||||||
|
if temp == 0:
|
||||||
|
return mx.argmax(logits, axis=-1)
|
||||||
|
else:
|
||||||
|
return mx.random.categorical(logits * (1 / temp))
|
||||||
|
|
||||||
|
mx.random.seed(seed)
|
||||||
|
|
||||||
|
print("[INFO] Starting generation...")
|
||||||
|
start = time.time()
|
||||||
|
print(prompt, end="", flush=True)
|
||||||
|
prompt = mx.array(self.tokenizer.encode(prompt), mx.uint32)
|
||||||
|
tokens = prompt
|
||||||
|
|
||||||
|
# prefill model
|
||||||
|
logit, cache = self.model(prompt[None])
|
||||||
|
token = sample(logit[:, -1, :])
|
||||||
|
tokens = mx.concatenate([tokens, token])
|
||||||
|
prompt_time = time.time() - start
|
||||||
|
print(self.tokenizer.decode(token.tolist()), end="", flush=True)
|
||||||
|
|
||||||
|
n_drafted = 0
|
||||||
|
n_accepted = 0
|
||||||
|
n_generated = 1
|
||||||
|
n_decoding_steps = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# For each decoding step: generate n_draft tokens by searching the prompt
|
||||||
|
def generate_draft(input_ids):
|
||||||
|
input_length = input_ids.size
|
||||||
|
|
||||||
|
for ngram_size in range(ngram_max, ngram_min, -1):
|
||||||
|
ngram = input_ids[-ngram_size:]
|
||||||
|
|
||||||
|
for i in range(input_length - ngram_size):
|
||||||
|
if mx.all(input_ids[i:i+ngram_size] == ngram):
|
||||||
|
start_idx = i + ngram_size
|
||||||
|
end_idx = start_idx + n_draft
|
||||||
|
if start_idx < input_length - ngram_size:
|
||||||
|
return input_ids[start_idx:end_idx]
|
||||||
|
|
||||||
|
return mx.array([], dtype=mx.uint32)
|
||||||
|
|
||||||
|
draft_tokens = generate_draft(tokens)
|
||||||
|
n_drafted += draft_tokens.size
|
||||||
|
|
||||||
|
# Verify draft tokens with the last verified token
|
||||||
|
verify_tokens = mx.concatenate([tokens[-1:], draft_tokens])
|
||||||
|
logits, cache = self.model(verify_tokens[None], cache=cache)
|
||||||
|
sampled = sample(logits).squeeze(0)
|
||||||
|
|
||||||
|
# Only keep samples that match the draft.
|
||||||
|
equal_toks = sampled[:-1] == draft_tokens
|
||||||
|
num_to_accept = (equal_toks.tolist() + [False]).index(False)
|
||||||
|
new_tokens = sampled[: max(1, num_to_accept + 1)]
|
||||||
|
|
||||||
|
n_accepted += num_to_accept
|
||||||
|
|
||||||
|
# Rewind the cache for unaccepted tokens:
|
||||||
|
if (num_to_truncate := draft_tokens.size - num_to_accept) > 0:
|
||||||
|
if num_to_truncate < cache[0][0].shape[2]:
|
||||||
|
cache = tree_map(
|
||||||
|
lambda x: x[:, :, :-num_to_truncate, :], cache
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cache = [None] * len(self.model.layers)
|
||||||
|
|
||||||
|
n_decoding_steps += 1
|
||||||
|
|
||||||
|
# Check stop decodig criteria:
|
||||||
|
for t in new_tokens.tolist()[:-1]:
|
||||||
|
if t == self.tokenizer.eos_id:
|
||||||
|
break
|
||||||
|
if (color):
|
||||||
|
print("\033[34m" + self.tokenizer.decode([t]) + "\033[30m", end="", flush=True)
|
||||||
|
else:
|
||||||
|
print(self.tokenizer.decode([t]), end="", flush=True)
|
||||||
|
|
||||||
|
print(self.tokenizer.decode(new_tokens[-1:].tolist()), end="", flush=True)
|
||||||
|
|
||||||
|
n_generated += new_tokens.size
|
||||||
|
if n_generated >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
|
||||||
|
break
|
||||||
|
|
||||||
|
tokens = mx.concatenate([tokens, new_tokens])
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
print()
|
||||||
|
print("=== PROMPT EVAL IN", round(prompt_time, 2), "SECONDS ===")
|
||||||
|
print("=== GENERATED", n_generated, "TOKENS IN", round(end - start, 2), "SECONDS ===")
|
||||||
|
print("=== ACCEPTED", n_accepted, "DRAFT TOKENS ===")
|
||||||
|
print("=== ACCEPT", round(n_accepted/n_generated * 100, 2), "% ===")
|
||||||
|
print("=== DECODING STEPS", n_decoding_steps, "===")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Mistral inference script")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-path",
|
||||||
|
type=str,
|
||||||
|
default="mlx_model",
|
||||||
|
help="The path to the model weights and tokenizer",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt",
|
||||||
|
help="The message to be processed by the model",
|
||||||
|
default="This is a test. This is a test. This is a",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-tokens",
|
||||||
|
"-m",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Maximum number of tokens to generate",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n-draft",
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
help="Number of draft tokens to generate upon prompt lookup match",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-max",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="Maximum ngrams to match against input during prompt lookup",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-min",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Minimum ngrams to match against input during prompt lookup",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temp",
|
||||||
|
help="The sampling temperature.",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="The PRNG seed"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--color",
|
||||||
|
type=bool,
|
||||||
|
default=True,
|
||||||
|
help="Color the accepted draft tokens"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
mx.random.seed(args.seed)
|
||||||
|
print("[INFO] Loading model from disk.")
|
||||||
|
|
||||||
|
engine = PromptLookupDecoder(args.model_path)
|
||||||
|
|
||||||
|
engine.generate(args.prompt, args.max_tokens, args.temp)
|
||||||
|
|
||||||
|
""" engine.prompt_lookup(
|
||||||
|
args.prompt,
|
||||||
|
args.max_tokens,
|
||||||
|
args.n_draft,
|
||||||
|
args.ngram_max,
|
||||||
|
args.ngram_min,
|
||||||
|
args.temp,
|
||||||
|
args.seed,
|
||||||
|
args.color
|
||||||
|
) """
|
||||||
|
|
||||||
|
|
@@ -1,5 +1,4 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
@@ -158,6 +157,7 @@ class Mistral(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
cache=None,
|
cache=None,
|
||||||
|
next_token_only: bool = False,
|
||||||
):
|
):
|
||||||
if cache is not None:
|
if cache is not None:
|
||||||
offset = cache[0][0].shape[-2]
|
offset = cache[0][0].shape[-2]
|
||||||
@@ -178,6 +178,9 @@ class Mistral(nn.Module):
|
|||||||
for e, layer in enumerate(self.layers):
|
for e, layer in enumerate(self.layers):
|
||||||
x, cache[e] = layer(x, mask, cache[e])
|
x, cache[e] = layer(x, mask, cache[e])
|
||||||
|
|
||||||
|
if next_token_only:
|
||||||
|
x = x[:, -1]
|
||||||
|
|
||||||
return self.output(self.norm(x)), cache
|
return self.output(self.norm(x)), cache
|
||||||
|
|
||||||
|
|
||||||
@@ -272,7 +275,7 @@ if __name__ == "__main__":
|
|||||||
"--tokens_per_eval",
|
"--tokens_per_eval",
|
||||||
help="The batch size of tokens to generate.",
|
help="The batch size of tokens to generate.",
|
||||||
type=int,
|
type=int,
|
||||||
default=10,
|
default=1,
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
|
||||||
|
|
9
lookup_decoding/test.py
Normal file
9
lookup_decoding/test.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from decoder import PromptLookupDecoder
|
||||||
|
|
||||||
|
prompt = "[INST] Repeat the following phrase 10 times: 'The quick brown fox jumps over the lazy dog.'. Don't say antyhing else. [/INST] "
|
||||||
|
|
||||||
|
engine = PromptLookupDecoder("mlx_model")
|
||||||
|
|
||||||
|
engine.generate(prompt, 250)
|
||||||
|
|
||||||
|
engine.prompt_lookup(prompt, 250, 10, 3, 1, 0.0, 0, True)
|
Reference in New Issue
Block a user