mlx-examples/llms/speculative_decoding/main.py
dmdaksh 7d7e236061
- Removed unused Python imports (#683)
  - bert/model.py:10: tree_unflatten
  - bert/model.py:2: dataclass
  - bert/model.py:8: numpy
  - cifar/resnet.py:6: Any
  - clip/model.py:15: tree_flatten
  - clip/model.py:9: Union
  - gcn/main.py:8: download_cora
  - gcn/main.py:9: cross_entropy
  - llms/gguf_llm/models.py:12: tree_flatten, tree_unflatten
  - llms/gguf_llm/models.py:9: numpy
  - llms/mixtral/mixtral.py:12: tree_map
  - llms/mlx_lm/models/dbrx.py:2: Dict, Union
  - llms/mlx_lm/tuner/trainer.py:5: partial
  - llms/speculative_decoding/decoder.py:1: dataclass, field
  - llms/speculative_decoding/decoder.py:2: Optional
  - llms/speculative_decoding/decoder.py:5: mlx.nn
  - llms/speculative_decoding/decoder.py:6: numpy
  - llms/speculative_decoding/main.py:2: glob
  - llms/speculative_decoding/main.py:3: json
  - llms/speculative_decoding/main.py:5: Path
  - llms/speculative_decoding/main.py:8: mlx.nn
  - llms/speculative_decoding/model.py:6: tree_unflatten
  - llms/speculative_decoding/model.py:7: AutoTokenizer
  - llms/tests/test_lora.py:13: yaml_loader
  - lora/lora.py:14: tree_unflatten
  - lora/models.py:11: numpy
  - lora/models.py:3: glob
  - speechcommands/kwt.py:1: Any
  - speechcommands/main.py:7: mlx.data
  - stable_diffusion/stable_diffusion/model_io.py:4: partial
  - whisper/benchmark.py:5: sys
  - whisper/test.py:5: subprocess
  - whisper/whisper/audio.py:6: Optional
  - whisper/whisper/decoding.py:8: mlx.nn
2024-04-16 07:50:32 -07:00

96 lines
2.6 KiB
Python

import argparse
import time
import mlx.core as mx
from decoder import SpeculativeDecoder
from mlx.utils import tree_unflatten
from model import Model
from transformers import T5Config
def load_model(model_name: str):
    """Construct a T5-based ``Model`` and fill it with weights from disk.

    The architecture comes from the Hugging Face ``T5Config`` for
    ``model_name``. The parameters are read from the file
    ``"<model_name>.npz"``, a path relative to the current working directory.

    Args:
        model_name: Model identifier (for example ``"t5-small"``). It is used
            both to fetch the config and as the stem of the ``.npz`` file.

    Returns:
        The ``Model`` instance with its parameters loaded and evaluated.
    """
    model = Model(T5Config.from_pretrained(model_name))
    flat_weights = mx.load(f"{model_name}.npz")
    # The loaded archive is a flat name -> array mapping. Rebuild the nested
    # parameter tree that ``Model.update`` expects.
    model.update(tree_unflatten(list(flat_weights.items())))
    # MLX evaluates lazily; force the parameters to materialize now.
    mx.eval(model.parameters())
    return model
def main(args):
    """Load the target and draft models and run one timed generation.

    Args:
        args: Parsed command-line namespace providing ``seed``,
            ``model_name``, ``draft_model_name``, ``delta``, ``num_draft``,
            ``prompt``, ``max_tokens`` and ``regular_decode``.

    When ``args.regular_decode`` is set, plain decoding is used. Otherwise
    speculative decoding runs and its acceptance statistics are printed.
    In both modes the elapsed generation time is reported.
    """
    mx.random.seed(args.seed)
    spec_decoder = SpeculativeDecoder(
        model=load_model(args.model_name),
        draft_model=load_model(args.draft_model_name),
        # The tokenizer is selected by the target model's name; presumably the
        # draft model shares its vocabulary -- confirm when mixing families.
        tokenizer=args.model_name,
        delta=args.delta,
        num_draft=args.num_draft,
    )

    print(args.prompt)

    # Time only the decoding call. perf_counter is monotonic and
    # high-resolution, unlike time.time(), which can jump if the system clock
    # is adjusted.
    stats = None
    tic = time.perf_counter()
    if args.regular_decode:
        spec_decoder.generate(args.prompt, max_tokens=args.max_tokens)
    else:
        stats = spec_decoder.speculative_decode(
            args.prompt, max_tokens=args.max_tokens
        )
    toc = time.perf_counter()

    if not args.regular_decode:
        print("=" * 10)
        print(f"Accepted {stats['n_accepted']} / {stats['n_draft']}.")
        print(f"Decoding steps {stats['n_steps']}.")
    print("=" * 10)
    print(f"Full generation time {toc - tic:.3f}")
if __name__ == "__main__":
    # Command-line entry point: parse options and run generation.
    parser = argparse.ArgumentParser(
        description="Generate text with MLX T5 models using speculative decoding."
    )
    parser.add_argument(
        "--num-draft",
        type=int,
        default=5,
        help="Number of draft tokens to use per decoding step.",
    )
    parser.add_argument(
        "--model-name",
        help="Name of the model.",
        default="t5-small",
    )
    parser.add_argument(
        "--draft-model-name",
        help="Name of the draft model.",
        default="t5-small",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="PRNG seed.",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=100,
        help="Maximum number of tokens to generate.",
    )
    parser.add_argument(
        "--prompt",
        default=(
            "translate English to French: Let's go to the store and buy some "
            "groceries including eggs, avocadoes, and bread."
        ),
        help="The prompt processed by the model.",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=0.1,
        help="Lenience for accepting the proposal tokens.",
    )
    parser.add_argument(
        "--regular-decode",
        action="store_true",
        help="Use regular decoding instead of speculative decoding.",
    )
    args = parser.parse_args()
    main(args)