mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Merge branch 'main' into feat/batch_generate
This commit is contained in:
commit
9ee726cf18
3
.gitignore
vendored
3
.gitignore
vendored
@ -6,6 +6,9 @@ __pycache__/
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Vim
|
||||
*.swp
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
|
@ -20,6 +20,31 @@ The `mlx-lm` package also has:
|
||||
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
|
||||
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
|
||||
|
||||
### Quick Start
|
||||
|
||||
To generate text with an LLM use:
|
||||
|
||||
```bash
|
||||
mlx_lm.generate --prompt "Hi!"
|
||||
```
|
||||
|
||||
To chat with an LLM use:
|
||||
|
||||
```bash
|
||||
mlx_lm.chat
|
||||
```
|
||||
|
||||
This will give you a chat REPL that you can use to interact with the LLM. The
|
||||
chat context is preserved during the lifetime of the REPL.
|
||||
|
||||
Commands in `mlx-lm` typically take command line options which let you specify
|
||||
the model, sampling parameters, and more. Use `-h` to see a list of available
|
||||
options for a command, e.g.:
|
||||
|
||||
```bash
|
||||
mlx_lm.generate -h
|
||||
```
|
||||
|
||||
### Python API
|
||||
|
||||
You can use `mlx-lm` as a module:
|
||||
@ -138,7 +163,7 @@ mlx_lm.convert \
|
||||
|
||||
### Long Prompts and Generations
|
||||
|
||||
MLX LM has some tools to scale efficiently to long prompts and generations:
|
||||
`mlx-lm` has some tools to scale efficiently to long prompts and generations:
|
||||
|
||||
- A rotating fixed-size key-value cache.
|
||||
- Prompt caching
|
||||
@ -155,14 +180,14 @@ different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:
|
||||
cat prompt.txt | mlx_lm.cache_prompt \
|
||||
--model mistralai/Mistral-7B-Instruct-v0.3 \
|
||||
--prompt - \
|
||||
--kv-cache-file mistral_prompt.safetensors
|
||||
--prompt-cache-file mistral_prompt.safetensors
|
||||
```
|
||||
|
||||
Then use the cached prompt with `mlx_lm.generate`:
|
||||
|
||||
```
|
||||
mlx_lm.generate \
|
||||
--kv-cache-file mistral_prompt.safetensors \
|
||||
--prompt-cache-file mistral_prompt.safetensors \
|
||||
--prompt "\nSummarize the above text."
|
||||
```
|
||||
|
||||
@ -170,9 +195,15 @@ The cached prompt is treated as a prefix to the supplied prompt. Also notice
|
||||
when using a cached prompt, the model to use is read from the cache and need
|
||||
not be supplied explicitly.
|
||||
|
||||
Prompt caching can also be used in the Python API in order to to avoid
|
||||
recomputing the prompt. This is useful in multi-turn dialogues or across
|
||||
requests that use the same context. See the
|
||||
[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
|
||||
for more usage details.
|
||||
|
||||
### Supported Models
|
||||
|
||||
MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
|
||||
`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
|
||||
run is not supported, file an
|
||||
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
|
||||
submit a pull request.
|
||||
|
26
llms/a.py
26
llms/a.py
@ -1,26 +0,0 @@
|
||||
import mlx_lm
|
||||
|
||||
# model, tokenizer = mlx_lm.load("mlx-community/SmolLM-1.7B-Instruct-fp16")
|
||||
model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Qwen2-0.5B-8bit-Instruct")
|
||||
draft_model, draft_tokenizer = mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")
|
||||
|
||||
# https://github.com/hemingkx/Spec-Bench/blob/main/data/spec_bench/question.jsonl
|
||||
prompt = "Develop a Python program that reads all the text files under a directory and returns top-5 words with the most number of occurrences."
|
||||
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
mlx_lm.generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
verbose=True,
|
||||
max_tokens=500,
|
||||
temp=1.0,
|
||||
min_p=0.1,
|
||||
repetition_penalty=1.2,
|
||||
# draft_model=draft_model,
|
||||
)
|
41
llms/b.py
41
llms/b.py
@ -1,41 +0,0 @@
|
||||
import mlx_lm
|
||||
import random
|
||||
import string
|
||||
|
||||
model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Qwen2-0.5B-8bit-Instruct")
|
||||
|
||||
capital_letters = string.ascii_uppercase
|
||||
distinct_pairs = [
|
||||
(a, b) for i, a in enumerate(capital_letters) for b in capital_letters[i + 1 :]
|
||||
]
|
||||
|
||||
num_prompts = 16
|
||||
prompt_template = "Think of a real word containing both the letters {l1} and {l2}. Then, say 3 sentences which use the word."
|
||||
prompts = [
|
||||
prompt_template.format(l1=p[0], l2=p[1])
|
||||
for p in random.sample(distinct_pairs, num_prompts)
|
||||
]
|
||||
prompts = [
|
||||
"Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?",
|
||||
"James writes a 3-page letter to 2 different friends twice a week. How many pages does he write a year?",
|
||||
"Tina makes $18.00 an hour. If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage. If she works 10 hours every day for 5 days, how much money does she make?"
|
||||
]
|
||||
prompts = [
|
||||
tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
|
||||
response = mlx_lm.batch_generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompts=prompts,
|
||||
max_tokens=512,
|
||||
verbose=True,
|
||||
temp=1.0,
|
||||
min_p=0.1,
|
||||
repetition_penalty=1.2,
|
||||
)
|
11
llms/c.py
11
llms/c.py
@ -1,11 +0,0 @@
|
||||
import mlx_lm
|
||||
|
||||
model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Meta-Llama-3.1-8B-4bit")
|
||||
|
||||
for s in mlx_lm.stream_generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt="Meta Llama 3.1 is a ",
|
||||
max_tokens=100,
|
||||
):
|
||||
print(s, end="", flush=True)
|
11
llms/d.py
11
llms/d.py
@ -1,11 +0,0 @@
|
||||
import mlx_lm
|
||||
|
||||
model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Meta-Llama-3.1-8B-4bit")
|
||||
|
||||
for s in mlx_lm.stream_generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt=["Meta Llama 3.1 is a ", "Google Gemma 2 is a "],
|
||||
max_tokens=20,
|
||||
):
|
||||
print(s[0].ljust(30) + s[1], flush=True)
|
@ -1,21 +0,0 @@
|
||||
## Steps to reproduce
|
||||
|
||||
Run the following with and without `prefill_step_size=2` commented out:
|
||||
|
||||
```py
|
||||
import mlx_lm
|
||||
|
||||
model, tokenizer = mlx_lm.load('/Users/llwu/models/mlx/Meta-Llama-3.1-8B-4bit')
|
||||
|
||||
mlx_lm.generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt="69 + 420= ",
|
||||
verbose=True,
|
||||
max_tokens=10,
|
||||
max_kv_size=5,
|
||||
prefill_step_size=2,
|
||||
)
|
||||
```
|
||||
|
||||
The output is different. I notice that the RotatingKVCache has length 5 with prefill and length 7 without.
|
@ -1,3 +1,3 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
__version__ = "0.18.2"
|
||||
__version__ = "0.19.1"
|
||||
|
@ -7,13 +7,14 @@ import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .utils import load, make_kv_caches
|
||||
from .models.cache import make_prompt_cache, save_prompt_cache
|
||||
from .utils import load
|
||||
|
||||
|
||||
def setup_arg_parser():
|
||||
"""Set up and return the argument parser."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Cache the KV cache of a prompt to be reused with mlx_lm.generate"
|
||||
description="Cache the state of a prompt to be reused with mlx_lm.generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
@ -60,7 +61,9 @@ def setup_arg_parser():
|
||||
help="Set the maximum key-value cache size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-cache-file", help="The file to save the KV caches in", required=True
|
||||
"--prompt-cache-file",
|
||||
help="The file to save the prompt cache in",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
@ -115,7 +118,7 @@ def main():
|
||||
else:
|
||||
prompt = args.prompt
|
||||
|
||||
cache = make_kv_caches(model, args.max_kv_size)
|
||||
cache = make_prompt_cache(model, args.max_kv_size)
|
||||
y = mx.array(tokenizer.encode(prompt))
|
||||
|
||||
# Process the prompt
|
||||
@ -137,16 +140,12 @@ def main():
|
||||
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
|
||||
|
||||
print("Saving...")
|
||||
cache_dict = {}
|
||||
for i, c in enumerate(cache):
|
||||
cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
|
||||
cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
|
||||
metadata = {}
|
||||
metadata["model"] = args.model
|
||||
metadata["chat_template"] = tokenizer.chat_template
|
||||
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
||||
metadata["max_kv_size"] = str(args.max_kv_size)
|
||||
mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
|
||||
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
|
||||
save_prompt_cache(args.prompt_cache_file, cache, metadata)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
82
llms/mlx_lm/chat.py
Normal file
82
llms/mlx_lm/chat.py
Normal file
@ -0,0 +1,82 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
|
||||
from .utils import load, stream_generate
|
||||
|
||||
DEFAULT_TEMP = 0.0
|
||||
DEFAULT_TOP_P = 1.0
|
||||
DEFAULT_SEED = 0
|
||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
|
||||
|
||||
def setup_arg_parser():
|
||||
"""Set up and return the argument parser."""
|
||||
parser = argparse.ArgumentParser(description="Chat with an LLM")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
default=DEFAULT_MODEL,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-path",
|
||||
type=str,
|
||||
help="Optional path for the trained adapter weights and config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
||||
parser.add_argument(
|
||||
"--max-kv-size",
|
||||
type=int,
|
||||
help="Set the maximum key-value cache size",
|
||||
default=None,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
model, tokenizer = load(
|
||||
args.model,
|
||||
adapter_path=args.adapter_path,
|
||||
tokenizer_config={"trust_remote_code": True},
|
||||
)
|
||||
|
||||
print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.")
|
||||
prompt_cache = make_prompt_cache(model, args.max_kv_size)
|
||||
while True:
|
||||
query = input(">> ")
|
||||
if query == "q":
|
||||
break
|
||||
messages = [{"role": "user", "content": query}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
for response in stream_generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt,
|
||||
temp=args.temp,
|
||||
top_p=args.top_p,
|
||||
prompt_cache=prompt_cache,
|
||||
):
|
||||
print(response, flush=True, end="")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
53
llms/mlx_lm/examples/chat.py
Normal file
53
llms/mlx_lm/examples/chat.py
Normal file
@ -0,0 +1,53 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
"""
|
||||
An example of a multi-turn chat with prompt caching.
|
||||
"""
|
||||
|
||||
from mlx_lm import generate, load
|
||||
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
|
||||
|
||||
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
|
||||
|
||||
# Make the initial prompt cache for the model
|
||||
prompt_cache = make_prompt_cache(model)
|
||||
|
||||
# User turn
|
||||
prompt = "Hi my name is <Name>."
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Assistant response
|
||||
response = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
verbose=True,
|
||||
temp=0.0,
|
||||
prompt_cache=prompt_cache,
|
||||
)
|
||||
|
||||
# User turn
|
||||
prompt = "What's my name?"
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Assistant response
|
||||
response = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
verbose=True,
|
||||
temp=0.0,
|
||||
prompt_cache=prompt_cache,
|
||||
)
|
||||
|
||||
# Save the prompt cache to disk to reuse it at a later time
|
||||
save_prompt_cache("mistral_prompt.safetensors", prompt_cache)
|
||||
|
||||
# Load the prompt cache from disk
|
||||
prompt_cache = load_prompt_cache("mistral_prompt.safetensors")
|
@ -1,3 +1,5 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from mlx_lm import generate, load
|
||||
|
||||
# Specify the checkpoint
|
||||
|
@ -6,13 +6,15 @@ import sys
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .models.cache import load_prompt_cache
|
||||
from .utils import generate, load
|
||||
|
||||
DEFAULT_PROMPT = "hello"
|
||||
DEFAULT_MAX_TOKENS = 100
|
||||
DEFAULT_TEMP = 0.6
|
||||
DEFAULT_TEMP = 0.0
|
||||
DEFAULT_TOP_P = 1.0
|
||||
DEFAULT_SEED = 0
|
||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
|
||||
|
||||
def str2bool(string):
|
||||
@ -25,7 +27,11 @@ def setup_arg_parser():
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
help="The path to the local model directory or Hugging Face repo.",
|
||||
help=(
|
||||
"The path to the local model directory or Hugging Face repo. "
|
||||
f"If no model is specified, then {DEFAULT_MODEL} is used."
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-path",
|
||||
@ -96,7 +102,7 @@ def setup_arg_parser():
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-cache-file",
|
||||
"--prompt-cache-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A file containing saved KV caches to avoid recomputing them",
|
||||
@ -131,24 +137,6 @@ def colorprint_by_t0(s, t0):
|
||||
colorprint(color, s)
|
||||
|
||||
|
||||
def load_kv_cache_from_file(kv_cache_file):
|
||||
if kv_cache_file is None:
|
||||
return None, None
|
||||
|
||||
kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True)
|
||||
cache_per_layer = {}
|
||||
for k, x in kv_cache.items():
|
||||
layer, kv_type = k.split("_")
|
||||
if layer not in cache_per_layer:
|
||||
cache_per_layer[layer] = {}
|
||||
cache_per_layer[layer][kv_type] = x
|
||||
|
||||
cache_history = [None] * len(cache_per_layer)
|
||||
for layer, c in cache_per_layer.items():
|
||||
cache_history[int(layer)] = (c["keys"], c["values"])
|
||||
return cache_history, metadata
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
@ -158,22 +146,33 @@ def main():
|
||||
if args.cache_limit_gb is not None:
|
||||
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
|
||||
|
||||
# Load the kv cache and metadata if a kv cache file is provided
|
||||
cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file)
|
||||
# Load the prompt cache and metadata if a cache file is provided
|
||||
using_cache = args.prompt_cache_file is not None
|
||||
if using_cache:
|
||||
prompt_cache, metadata = load_prompt_cache(
|
||||
args.prompt_cache_file, return_metadata=True
|
||||
)
|
||||
|
||||
# Building tokenizer_config
|
||||
tokenizer_config = (
|
||||
{} if cache_history is None else json.loads(metadata["tokenizer_config"])
|
||||
{} if not using_cache else json.loads(metadata["tokenizer_config"])
|
||||
)
|
||||
if args.trust_remote_code:
|
||||
tokenizer_config["trust_remote_code"] = True
|
||||
if args.eos_token is not None:
|
||||
tokenizer_config["eos_token"] = args.eos_token
|
||||
|
||||
# If no model path is provided then use the one in the kv cache history
|
||||
model_path = args.model
|
||||
if cache_history is not None and model_path is None:
|
||||
model_path = metadata["model"]
|
||||
if using_cache:
|
||||
if model_path is None:
|
||||
model_path = metadata["model"]
|
||||
elif model_path != metadata["model"]:
|
||||
raise ValueError(
|
||||
f"Providing a different model ({model_path}) than that "
|
||||
f"used to create the prompt cache ({metadata['model']}) "
|
||||
"is an error."
|
||||
)
|
||||
model_path = model_path or DEFAULT_MODEL
|
||||
|
||||
model, tokenizer = load(
|
||||
model_path,
|
||||
@ -184,7 +183,7 @@ def main():
|
||||
if args.use_default_chat_template:
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = tokenizer.default_chat_template
|
||||
elif cache_history is not None:
|
||||
elif using_cache:
|
||||
tokenizer.chat_template = metadata["chat_template"]
|
||||
|
||||
if not args.ignore_chat_template and (
|
||||
@ -203,7 +202,7 @@ def main():
|
||||
|
||||
# Treat the prompt as a suffix assuming that the prefix is in the
|
||||
# stored kv cache.
|
||||
if cache_history is not None:
|
||||
if using_cache:
|
||||
test_prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": "<query>"}],
|
||||
tokenize=False,
|
||||
@ -217,12 +216,6 @@ def main():
|
||||
raise ValueError("Cannot use --colorize with --verbose=False")
|
||||
formatter = colorprint_by_t0 if args.colorize else None
|
||||
|
||||
# Determine the max kv size from the kv cache or passed arguments
|
||||
max_kv_size = args.max_kv_size
|
||||
if cache_history is not None:
|
||||
max_kv_size = metadata["max_kv_size"]
|
||||
max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
|
||||
|
||||
response = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
@ -232,8 +225,8 @@ def main():
|
||||
formatter=formatter,
|
||||
temp=args.temp,
|
||||
top_p=args.top_p,
|
||||
max_kv_size=max_kv_size,
|
||||
cache_history=cache_history,
|
||||
max_kv_size=args.max_kv_size,
|
||||
prompt_cache=prompt_cache if using_cache else None,
|
||||
)
|
||||
if not args.verbose:
|
||||
print(response)
|
||||
|
@ -2,145 +2,9 @@
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class KVCache:
|
||||
|
||||
def __init__(self, head_dim, n_kv_heads):
|
||||
self.n_kv_heads = n_kv_heads
|
||||
if isinstance(head_dim, int):
|
||||
self.k_head_dim = self.v_head_dim = head_dim
|
||||
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
|
||||
self.k_head_dim, self.v_head_dim = head_dim
|
||||
else:
|
||||
raise ValueError("head_dim must be an int or a tuple of two ints")
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.step = 256
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
prev = self.offset
|
||||
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
|
||||
B = keys.shape[0]
|
||||
n_steps = (self.step + keys.shape[2] - 1) // self.step
|
||||
k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
|
||||
v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
|
||||
new_k = mx.zeros(k_shape, keys.dtype)
|
||||
new_v = mx.zeros(v_shape, values.dtype)
|
||||
if self.keys is not None:
|
||||
if prev % self.step != 0:
|
||||
self.keys = self.keys[..., :prev, :]
|
||||
self.values = self.values[..., :prev, :]
|
||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||
else:
|
||||
self.keys, self.values = new_k, new_v
|
||||
|
||||
self.offset += keys.shape[2]
|
||||
self.keys[..., prev : self.offset, :] = keys
|
||||
self.values[..., prev : self.offset, :] = values
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.keys, self.values
|
||||
|
||||
|
||||
class RotatingKVCache:
|
||||
|
||||
def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
|
||||
self.n_kv_heads = n_kv_heads
|
||||
if isinstance(head_dim, int):
|
||||
self.k_head_dim = self.v_head_dim = head_dim
|
||||
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
|
||||
self.k_head_dim, self.v_head_dim = head_dim
|
||||
else:
|
||||
raise ValueError("head_dim must be an int or a tuple of two ints")
|
||||
self.keep = keep
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.max_size = max_size
|
||||
self.step = step
|
||||
self._idx = 0
|
||||
|
||||
def _trim(self, trim_size, v, append=None):
|
||||
to_cat = []
|
||||
if trim_size > 0:
|
||||
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
|
||||
else:
|
||||
to_cat = [v]
|
||||
if append is not None:
|
||||
to_cat.append(append)
|
||||
return mx.concatenate(to_cat, axis=2)
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
prev = self.offset
|
||||
B, _, S = keys.shape[:3]
|
||||
|
||||
# Prefill mode
|
||||
if S > 1:
|
||||
if self.keys is None:
|
||||
self.keys = keys
|
||||
self.values = values
|
||||
else:
|
||||
# The largest size is self.max_size + S - 1 to ensure
|
||||
# every token gets at least self.max_size context
|
||||
trim_size = self.keys.shape[2] - self.max_size + 1
|
||||
self.keys = self._trim(trim_size, self.keys, keys)
|
||||
self.values = self._trim(trim_size, self.values, values)
|
||||
self.offset += S
|
||||
self._idx = self.keys.shape[2]
|
||||
return self.keys, self.values
|
||||
|
||||
# Generation mode
|
||||
# May not have hit the max size yet, so potentially
|
||||
# keep growing the cache
|
||||
if self.keys is None or (
|
||||
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
|
||||
):
|
||||
new_size = min(self.step, self.max_size - prev)
|
||||
k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
|
||||
v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
|
||||
new_k = mx.zeros(k_shape, keys.dtype)
|
||||
new_v = mx.zeros(v_shape, values.dtype)
|
||||
if self.keys is not None:
|
||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||
else:
|
||||
self.keys, self.values = new_k, new_v
|
||||
self._idx = prev
|
||||
|
||||
# Trim if needed
|
||||
trim_size = self.keys.shape[2] - self.max_size
|
||||
if trim_size > 0:
|
||||
self.keys = self._trim(trim_size, self.keys)
|
||||
self.values = self._trim(trim_size, self.values)
|
||||
self._idx = self.max_size
|
||||
|
||||
# Rotate
|
||||
if self._idx == self.max_size:
|
||||
self._idx = self.keep
|
||||
|
||||
# Assign
|
||||
self.keys[..., self._idx : self._idx + 1, :] = keys
|
||||
self.values[..., self._idx : self._idx + 1, :] = values
|
||||
self.offset += 1
|
||||
self._idx += 1
|
||||
|
||||
# If the buffer is not full, slice off the end
|
||||
if self.offset < self.max_size:
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
return self.keys, self.values
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.keys, self.values
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -156,25 +20,30 @@ class BaseModelArgs:
|
||||
)
|
||||
|
||||
|
||||
def create_additive_causal_mask(N: int, offset: int = 0):
|
||||
def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
|
||||
rinds = mx.arange(offset + N)
|
||||
linds = mx.arange(offset, offset + N) if offset else rinds
|
||||
mask = linds[:, None] < rinds[None]
|
||||
linds = linds[:, None]
|
||||
rinds = rinds[None]
|
||||
mask = linds < rinds
|
||||
if window_size is not None:
|
||||
mask = mask | (linds > rinds + window_size)
|
||||
return mask * -1e9
|
||||
|
||||
|
||||
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
||||
T = h.shape[1]
|
||||
if T > 1:
|
||||
window_size = None
|
||||
offset = 0
|
||||
if cache is not None and cache[0] is not None:
|
||||
c = cache[0]
|
||||
if isinstance(c, RotatingKVCache):
|
||||
if hasattr(c, "max_size"):
|
||||
offset = min(c.max_size - 1, c.offset)
|
||||
window_size = c.max_size
|
||||
else:
|
||||
offset = c.offset
|
||||
else:
|
||||
offset = 0
|
||||
mask = create_additive_causal_mask(T, offset)
|
||||
mask = create_causal_mask(T, offset, window_size=window_size)
|
||||
mask = mask.astype(h.dtype)
|
||||
else:
|
||||
mask = None
|
||||
|
333
llms/mlx_lm/models/cache.py
Normal file
333
llms/mlx_lm/models/cache.py
Normal file
@ -0,0 +1,333 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
|
||||
|
||||
def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
|
||||
"""
|
||||
Construct the model's cache for use when cgeneration.
|
||||
|
||||
This function will defer the cache construction to the model if it has a
|
||||
``make_cache`` method, otherwise it will make a default KV cache.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The language model.
|
||||
max_kv_size (Optional[int]): If provided and the model does not have a
|
||||
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
|
||||
size of ``max_kv_size``
|
||||
"""
|
||||
if hasattr(model, "make_cache"):
|
||||
return model.make_cache()
|
||||
|
||||
num_layers = len(model.layers)
|
||||
if max_kv_size is not None:
|
||||
return [
|
||||
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
|
||||
]
|
||||
else:
|
||||
return [KVCache() for _ in range(num_layers)]
|
||||
|
||||
|
||||
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
|
||||
"""
|
||||
Save a pre-computed prompt cache to a file.
|
||||
|
||||
Args:
|
||||
file_name (str): The ``.safetensors`` file name.
|
||||
cache (List[Any]): The model state.
|
||||
metadata (Dict[str, str]): Optional metadata to save along with model
|
||||
state.
|
||||
"""
|
||||
cache_data = [c.state for c in cache]
|
||||
cache_info = [c.meta_state for c in cache]
|
||||
cache_data = dict(tree_flatten(cache_data))
|
||||
cache_classes = [type(c).__name__ for c in cache]
|
||||
cache_metadata = [cache_info, metadata, cache_classes]
|
||||
cache_metadata = dict(tree_flatten(cache_metadata))
|
||||
mx.save_safetensors(file_name, cache_data, cache_metadata)
|
||||
|
||||
|
||||
def load_prompt_cache(file_name, return_metadata=False):
|
||||
"""
|
||||
Load a prompt cache from a file.
|
||||
|
||||
Args:
|
||||
file_name (str): The ``.safetensors`` file name.
|
||||
return_metadata (bool): Whether or not to return metadata.
|
||||
Default: ``False``.
|
||||
|
||||
Returns:
|
||||
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
|
||||
the metadata if requested.
|
||||
"""
|
||||
arrays, cache_metadata = mx.load(file_name, return_metadata=True)
|
||||
arrays = tree_unflatten(list(arrays.items()))
|
||||
cache_metadata = tree_unflatten(list(cache_metadata.items()))
|
||||
info, metadata, classes = cache_metadata
|
||||
cache = [globals()[c]() for c in classes]
|
||||
for c, state, meta_state in zip(cache, arrays, info):
|
||||
c.state = state
|
||||
c.meta_state = meta_state
|
||||
if return_metadata:
|
||||
return cache, metadata
|
||||
return cache
|
||||
|
||||
|
||||
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
|
||||
"""
|
||||
Trim the model's cache by the given number of tokens.
|
||||
|
||||
This function will trim the cache if possible (in-place) and return the
|
||||
number of tokens that were trimmed.
|
||||
|
||||
Args:
|
||||
cache (List[Any]): The model's cache.
|
||||
num_tokens (int): The number of tokens to trim.
|
||||
|
||||
Returns:
|
||||
(int): The number of tokens that were trimmed.
|
||||
"""
|
||||
if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
|
||||
return 0
|
||||
return [c.trim(num_tokens) for c in cache][0]
|
||||
|
||||
|
||||
class _BaseCache:
|
||||
@property
|
||||
def state(self):
|
||||
return []
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
if v is not None and v:
|
||||
raise ValueError("This cache has no state but a state was set.")
|
||||
|
||||
@property
|
||||
def meta_state(self):
|
||||
return ""
|
||||
|
||||
@meta_state.setter
|
||||
def meta_state(self, v):
|
||||
if v is not None and v:
|
||||
raise ValueError("This cache has no meta_state but a meta_state was set.")
|
||||
|
||||
def is_trimmable(self):
|
||||
return False
|
||||
|
||||
|
||||
class KVCache(_BaseCache):
|
||||
def __init__(self):
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.step = 256
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
prev = self.offset
|
||||
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
|
||||
B, n_kv_heads, _, k_head_dim = keys.shape
|
||||
v_head_dim = values.shape[3]
|
||||
n_steps = (self.step + keys.shape[2] - 1) // self.step
|
||||
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
|
||||
v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
|
||||
new_k = mx.zeros(k_shape, keys.dtype)
|
||||
new_v = mx.zeros(v_shape, values.dtype)
|
||||
if self.keys is not None:
|
||||
if prev % self.step != 0:
|
||||
self.keys = self.keys[..., :prev, :]
|
||||
self.values = self.values[..., :prev, :]
|
||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||
else:
|
||||
self.keys, self.values = new_k, new_v
|
||||
|
||||
self.offset += keys.shape[2]
|
||||
self.keys[..., prev : self.offset, :] = keys
|
||||
self.values[..., prev : self.offset, :] = values
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
if self.offset == self.keys.shape[2]:
|
||||
return self.keys, self.values
|
||||
else:
|
||||
return (
|
||||
self.keys[..., : self.offset, :],
|
||||
self.values[..., : self.offset, :],
|
||||
)
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.keys, self.values = v
|
||||
self.offset = self.keys.shape[2]
|
||||
|
||||
def is_trimmable(self):
|
||||
return True
|
||||
|
||||
def trim(self, n):
|
||||
n = min(self.offset, n)
|
||||
self.offset -= n
|
||||
return n
|
||||
|
||||
|
||||
class RotatingKVCache(_BaseCache):
|
||||
|
||||
def __init__(self, max_size=None, keep=0, step=256):
|
||||
self.keep = keep
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.max_size = max_size
|
||||
self.step = step
|
||||
self._idx = 0
|
||||
|
||||
def _trim(self, trim_size, v, append=None):
|
||||
to_cat = []
|
||||
if trim_size > 0:
|
||||
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
|
||||
else:
|
||||
to_cat = [v]
|
||||
if append is not None:
|
||||
to_cat.append(append)
|
||||
return mx.concatenate(to_cat, axis=2)
|
||||
|
||||
def _temporal_order(self, v):
|
||||
"""
|
||||
Rearrange the cache into temporal order, slicing off the end if unused.
|
||||
"""
|
||||
if self._idx == v.shape[2]:
|
||||
return v
|
||||
elif self._idx < self.offset:
|
||||
return mx.concatenate(
|
||||
[
|
||||
v[..., : self.keep, :],
|
||||
v[..., self._idx :, :],
|
||||
v[..., self.keep : self._idx, :],
|
||||
],
|
||||
axis=2,
|
||||
)
|
||||
else:
|
||||
return v[..., : self._idx, :]
|
||||
|
||||
def _update_concat(self, keys, values):
|
||||
if self.keys is None:
|
||||
self.keys = keys
|
||||
self.values = values
|
||||
else:
|
||||
# Put the keys/values in temporal order to
|
||||
# preserve context
|
||||
self.keys = self._temporal_order(self.keys)
|
||||
self.values = self._temporal_order(self.values)
|
||||
|
||||
# The largest size is self.max_size + S - 1 to ensure
|
||||
# every token gets at least self.max_size context
|
||||
trim_size = self._idx - self.max_size + 1
|
||||
self.keys = self._trim(trim_size, self.keys, keys)
|
||||
self.values = self._trim(trim_size, self.values, values)
|
||||
self.offset += keys.shape[2]
|
||||
self._idx = self.keys.shape[2]
|
||||
return self.keys, self.values
|
||||
|
||||
def _update_in_place(self, keys, values):
|
||||
# May not have hit the max size yet, so potentially
|
||||
# keep growing the cache
|
||||
B, n_kv_heads, S, k_head_dim = keys.shape
|
||||
prev = self.offset
|
||||
if self.keys is None or (
|
||||
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
|
||||
):
|
||||
v_head_dim = values.shape[3]
|
||||
new_size = min(self.step, self.max_size - prev)
|
||||
k_shape = (B, n_kv_heads, new_size, k_head_dim)
|
||||
v_shape = (B, n_kv_heads, new_size, v_head_dim)
|
||||
new_k = mx.zeros(k_shape, keys.dtype)
|
||||
new_v = mx.zeros(v_shape, values.dtype)
|
||||
if self.keys is not None:
|
||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||
else:
|
||||
self.keys, self.values = new_k, new_v
|
||||
self._idx = prev
|
||||
|
||||
# Trim if needed
|
||||
trim_size = self.keys.shape[2] - self.max_size
|
||||
if trim_size > 0:
|
||||
self.keys = self._trim(trim_size, self.keys)
|
||||
self.values = self._trim(trim_size, self.values)
|
||||
self._idx = self.max_size
|
||||
|
||||
# Rotate
|
||||
if self._idx == self.max_size:
|
||||
self._idx = self.keep
|
||||
|
||||
# Assign
|
||||
self.keys[..., self._idx : self._idx + S, :] = keys
|
||||
self.values[..., self._idx : self._idx + S, :] = values
|
||||
self.offset += S
|
||||
self._idx += S
|
||||
|
||||
# If the buffer is not full, slice off the end
|
||||
if self.offset < self.max_size:
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
return self.keys, self.values
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
if keys.shape[2] == 1:
|
||||
return self._update_in_place(keys, values)
|
||||
return self._update_concat(keys, values)
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
if self.offset < self.keys.shape[2]:
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
else:
|
||||
return self.keys, self.values
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.keys, self.values = v
|
||||
|
||||
@property
|
||||
def meta_state(self):
|
||||
return tuple(
|
||||
map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
|
||||
)
|
||||
|
||||
@meta_state.setter
|
||||
def meta_state(self, v):
|
||||
self.keep, self.max_size, self.step, self.offset, self._idx = map(
|
||||
int,
|
||||
v,
|
||||
)
|
||||
|
||||
def is_trimmable(self):
|
||||
return self.offset < self.max_size
|
||||
|
||||
def trim(self, n):
|
||||
n = min(self.offset, n)
|
||||
self.offset -= n
|
||||
self._idx -= n
|
||||
return n
|
||||
|
||||
|
||||
class MambaCache(_BaseCache):
|
||||
def __init__(self):
|
||||
self.cache = [None, None]
|
||||
|
||||
def __setitem__(self, idx, value):
|
||||
self.cache[idx] = value
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.cache[idx]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.cache
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.cache = v
|
@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -69,7 +69,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -129,7 +129,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = self.input_layernorm(x)
|
||||
attn_h = self.self_attn(h, mask, cache)
|
||||
@ -190,11 +190,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -49,7 +49,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
|
||||
qkv = self.Wqkv(x)
|
||||
@ -92,7 +92,7 @@ class NormAttnNorm(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = self.attn(self.norm_1(x), mask=mask, cache=cache)
|
||||
x = h + x
|
||||
@ -179,7 +179,7 @@ class DecoderLayer(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r, h = self.norm_attn_norm(x, mask, cache)
|
||||
out = self.ffn(h) + r
|
||||
@ -249,11 +249,3 @@ class Model(nn.Module):
|
||||
experts = [(s, sv.T) for s, sv in experts]
|
||||
new_weights.update(experts)
|
||||
return new_weights
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.d_model // self.args.n_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.attn_config["kv_n_heads"]
|
||||
|
@ -1,10 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@ -77,7 +77,7 @@ class DeepseekAttention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, _ = x.shape
|
||||
|
||||
@ -108,8 +108,8 @@ class DeepseekMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelArgs,
|
||||
hidden_size: int | None = None,
|
||||
intermediate_size: int | None = None,
|
||||
hidden_size: Optional[int] = None,
|
||||
intermediate_size: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -188,7 +188,7 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@ -210,7 +210,7 @@ class DeepseekModel(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(x)
|
||||
mask = create_attention_mask(h, cache)
|
||||
@ -235,7 +235,7 @@ class Model(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
return self.lm_head(out)
|
||||
@ -256,11 +256,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -2,12 +2,12 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ class ModelArgs(BaseModelArgs):
|
||||
max_position_embeddings: int = 2048
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000.0
|
||||
rope_scaling: Optional[Dict] = None
|
||||
rope_scaling: Dict = None
|
||||
attention_bias: bool = False
|
||||
|
||||
|
||||
@ -172,12 +172,11 @@ class DeepseekV2Attention(nn.Module):
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
if self.config.rope_scaling is not None:
|
||||
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||
self.scale = self.scale * mscale * mscale
|
||||
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||
self.scale = self.scale * mscale * mscale
|
||||
|
||||
rope_kwargs = {
|
||||
key: self.config.rope_scaling[key]
|
||||
@ -202,7 +201,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -347,7 +346,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@ -370,7 +369,7 @@ class DeepseekV2Model(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(x)
|
||||
mask = create_attention_mask(h, cache)
|
||||
@ -395,7 +394,7 @@ class Model(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
return self.lm_head(out)
|
||||
@ -416,14 +415,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return (
|
||||
self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
|
||||
self.args.v_head_dim,
|
||||
)
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -60,7 +60,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -113,7 +113,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@ -173,11 +173,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.head_dim
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -64,7 +64,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
@ -135,13 +135,11 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x.astype(mx.float32)), mask, cache)
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + self.post_attention_layernorm(r)
|
||||
r = self.mlp(self.pre_feedforward_layernorm(h).astype(mx.float16)).astype(
|
||||
mx.float32
|
||||
)
|
||||
r = self.mlp(self.pre_feedforward_layernorm(h))
|
||||
out = h + self.post_feedforward_layernorm(r)
|
||||
return out
|
||||
|
||||
@ -200,11 +198,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.head_dim
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -46,7 +46,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -100,7 +100,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.ln_1(x), mask, cache)
|
||||
h = x + r
|
||||
@ -196,11 +196,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.n_embd // self.args.n_head
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -57,7 +57,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -114,7 +114,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.ln_1(x), mask, cache)
|
||||
h = x + r
|
||||
@ -184,11 +184,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.n_embd // self.args.n_head
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -60,7 +60,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -120,7 +120,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
residual = x
|
||||
# NeoX runs attention and feedforward network in parallel.
|
||||
@ -214,11 +214,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -116,7 +116,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -171,7 +171,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attention(self.attention_norm(x), mask, cache)
|
||||
h = x + r
|
||||
@ -236,11 +236,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -1,12 +1,12 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -171,7 +171,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -233,7 +233,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@ -303,13 +303,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return (
|
||||
self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
|
||||
)
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -7,6 +7,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .cache import MambaCache
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -45,21 +46,6 @@ class ModelArgs(BaseModelArgs):
|
||||
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
||||
|
||||
|
||||
class MambaCache:
|
||||
def __init__(self):
|
||||
self.cache = [None, None]
|
||||
|
||||
def __setitem__(self, idx, value):
|
||||
self.cache[idx] = value
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.cache[idx]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.cache
|
||||
|
||||
|
||||
class DepthWiseConv1d(nn.Module):
|
||||
def __init__(self, channels, kernel_size, bias=True, padding=0):
|
||||
super().__init__()
|
||||
@ -223,7 +209,7 @@ class Model(nn.Module):
|
||||
weights[k] = v.moveaxis(2, 1)
|
||||
return weights
|
||||
|
||||
def make_cache(self, batch_size: int = 1):
|
||||
def make_cache(self):
|
||||
return [MambaCache() for _ in range(len(self.layers))]
|
||||
|
||||
@property
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -85,7 +85,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
):
|
||||
B, L, _ = x.shape
|
||||
|
||||
@ -135,7 +135,7 @@ class DecoderLayer(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
|
||||
@ -205,11 +205,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -66,7 +66,7 @@ class MixtralAttention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -138,7 +138,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@ -215,11 +215,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -2,12 +2,12 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -94,7 +94,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, _ = x.shape
|
||||
|
||||
@ -151,7 +151,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@ -215,13 +215,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return (
|
||||
self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
|
||||
)
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -1,8 +1,8 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from sys import exit
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -13,7 +13,7 @@ try:
|
||||
import hf_olmo
|
||||
except ImportError:
|
||||
print("To run olmo install ai2-olmo: pip install ai2-olmo")
|
||||
exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -68,7 +68,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -98,7 +98,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attend(self.att_norm(x), mask, cache)
|
||||
h = x + r
|
||||
@ -174,11 +174,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.transformer.blocks
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.d_model // self.args.n_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.n_heads
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -80,7 +80,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -152,7 +152,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.attn_norm(x), mask, cache)
|
||||
h = x + r
|
||||
@ -218,11 +218,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.head_dim
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_kv_heads
|
||||
|
@ -162,19 +162,11 @@ class Model(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: mx.array = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
y = self.model(x, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -1,12 +1,12 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
from .su_rope import SuScaledRotaryEmbedding
|
||||
|
||||
|
||||
@ -84,7 +84,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -143,7 +143,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@ -202,11 +202,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -3,12 +3,12 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -22,14 +22,14 @@ class ModelArgs(BaseModelArgs):
|
||||
num_attention_heads: int
|
||||
layer_norm_epsilon: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: Optional[int] = None
|
||||
num_key_value_heads: int
|
||||
mup_attn_multiplier: float = 1.0
|
||||
mup_use_scaling: bool = True
|
||||
mup_embedding_multiplier: float = 10.0
|
||||
mup_width_multiplier: float = 8.0
|
||||
rope_embedding_base: float = 1000000
|
||||
rope_position_scale: float = 1.0
|
||||
blocksparse_block_size: Tuple[int] = (64,)
|
||||
blocksparse_block_size: int = 64
|
||||
blocksparse_num_local_blocks: int = 16
|
||||
blocksparse_vert_stride: int = 8
|
||||
|
||||
@ -61,7 +61,6 @@ class Attention(nn.Module):
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.n_q_per_kv = n_heads // n_kv_heads
|
||||
|
||||
@ -161,7 +160,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -230,7 +229,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@ -304,16 +303,8 @@ class Model(nn.Module):
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {
|
||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||
}
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -173,6 +173,7 @@ class PhiMoEModel(nn.Module):
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.args = args
|
||||
self.model = PhiMoEModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True)
|
||||
@ -208,11 +209,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -168,8 +168,8 @@ class Model(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache: mx.array = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
mask = create_attention_mask(x, cache)
|
||||
|
||||
y = self.transformer(x, mask, cache)
|
||||
@ -193,11 +193,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.model_dim // self.args.num_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_heads
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -62,8 +62,8 @@ class Attention(nn.Module):
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
bsz, q_len, _ = hidden_states.shape
|
||||
|
||||
queries = self.q_proj(hidden_states)
|
||||
@ -127,8 +127,8 @@ class PlamoDecoderLayer(nn.Module):
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> Tuple[Any, ...]:
|
||||
cache: Optional[Any] = None,
|
||||
):
|
||||
# from LlamaDecoder
|
||||
residual = hidden_states
|
||||
|
||||
@ -169,8 +169,8 @@ class PlamoModel(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None,
|
||||
) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
mask = create_attention_mask(h, cache)
|
||||
@ -197,19 +197,11 @@ class Model(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
out = self.model(inputs, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_attention_heads // self.args.n_shared_head
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -149,19 +148,11 @@ class Model(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache: mx.array = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
y = self.transformer(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_attention_heads
|
||||
|
@ -1,12 +1,12 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -70,7 +70,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -124,7 +124,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@ -196,11 +196,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -2,12 +2,12 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@ -70,7 +70,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -162,7 +162,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@ -236,11 +236,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -7,13 +7,13 @@ from typing import List, Literal, Optional
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
from .cache import MambaCache, RotatingKVCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
attention_bias: bool
|
||||
conv1d_width: int
|
||||
hidden_size: int
|
||||
@ -36,59 +36,6 @@ class ModelArgs(BaseModelArgs):
|
||||
self.block_types = self._block_types
|
||||
|
||||
|
||||
def create_window_causal_mask(N: int, window_size: int):
|
||||
inds = mx.arange(N)
|
||||
linds = inds[:, None]
|
||||
rinds = inds[None]
|
||||
mask = (linds < rinds) | (linds > rinds + window_size)
|
||||
return mask * -1e9
|
||||
|
||||
|
||||
class RecurrentCache:
|
||||
|
||||
def __init__(self):
|
||||
self._cache = (None, None)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self._cache[idx]
|
||||
|
||||
def update(self, conv_state, recurrent_state):
|
||||
self._cache = (conv_state, recurrent_state)
|
||||
|
||||
def state(self):
|
||||
return self._cache
|
||||
|
||||
|
||||
class WindowKVCache:
|
||||
|
||||
def __init__(self, window_size):
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.window_size = window_size
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
# TODO consider using rotating buffer here
|
||||
# especially for very long generations
|
||||
def _update(x, v):
|
||||
t = x.shape[2] - self.window_size
|
||||
if t > 0:
|
||||
x = x[..., t:, :]
|
||||
return mx.concatenate([x, v], axis=2)
|
||||
|
||||
self.offset += keys.shape[2]
|
||||
if self.keys is None:
|
||||
self.keys = keys
|
||||
self.values = values
|
||||
else:
|
||||
self.keys = _update(self.keys, keys)
|
||||
self.values = _update(self.values, values)
|
||||
return self.keys, self.values
|
||||
|
||||
def state(self):
|
||||
return self.keys, self.values
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
@ -136,31 +83,22 @@ class Conv1d(nn.Module):
|
||||
kernel_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.weight = mx.zeros((kernel_size, channels))
|
||||
self.weight = mx.zeros((channels, kernel_size, 1))
|
||||
self.bias = mx.zeros((channels,))
|
||||
|
||||
def __call__(self, x, cache=None):
|
||||
w = self.weight.T[..., None]
|
||||
kw, groups = self.weight.shape
|
||||
if cache is not None:
|
||||
l = []
|
||||
# Pad the cache if needed
|
||||
if cache.shape[1] < kw - 1:
|
||||
l.append(
|
||||
mx.zeros(
|
||||
(x.shape[0], kw - 1 - cache.shape[1], groups), dtype=x.dtype
|
||||
)
|
||||
)
|
||||
l.extend([cache, x])
|
||||
x = mx.concatenate(l, axis=1)
|
||||
y = (x * w.swapaxes(0, 2)).sum(axis=1, keepdims=True)
|
||||
else:
|
||||
y = mx.conv_general(x, w, padding=([kw - 1], [0]), groups=groups)
|
||||
B, L, C = x.shape
|
||||
groups, K, _ = self.weight.shape
|
||||
|
||||
# The cache is always kw - 1
|
||||
cache = x[:, max(x.shape[1] - kw + 1, 0) :, :]
|
||||
if cache is not None:
|
||||
x = mx.concatenate([cache, x], axis=1)
|
||||
else:
|
||||
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
|
||||
|
||||
y = mx.conv_general(x, self.weight, groups=groups)
|
||||
y = y + self.bias
|
||||
return y, cache
|
||||
|
||||
return y, x[:, -K + 1 :, :]
|
||||
|
||||
|
||||
class RGLRU(nn.Module):
|
||||
@ -269,19 +207,9 @@ class RecurrentBlock(nn.Module):
|
||||
# x branch.
|
||||
x = self.linear_x(x)
|
||||
if cache is None:
|
||||
conv_state, recurrent_state = (None, None)
|
||||
else:
|
||||
conv_state, recurrent_state = cache[0], cache[1]
|
||||
x, conv_state = self.conv_1d(
|
||||
x=x,
|
||||
cache=conv_state,
|
||||
)
|
||||
x, recurrent_state = self.rg_lru(
|
||||
x=x,
|
||||
cache=recurrent_state,
|
||||
)
|
||||
if cache is not None:
|
||||
cache.update(conv_state, recurrent_state)
|
||||
cache = [None, None]
|
||||
x, cache[0] = self.conv_1d(x=x, cache=cache[0])
|
||||
x, cache[1] = self.rg_lru(x=x, cache=cache[1])
|
||||
|
||||
x = x * y
|
||||
x = self.linear_out(x)
|
||||
@ -467,12 +395,14 @@ class Griffin(nn.Module):
|
||||
if self.scale_by_sqrt_dim:
|
||||
x = x * math.sqrt(x.shape[-1])
|
||||
|
||||
mask = None
|
||||
if x.shape[1] > 1:
|
||||
mask = create_window_causal_mask(
|
||||
x.shape[1], self.config.attention_window_size
|
||||
)
|
||||
mask = mask.astype(x.dtype)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for i, block in enumerate(self.layers):
|
||||
if block.temporal_block_type != "recurrent":
|
||||
mask_cache = [cache[i]]
|
||||
|
||||
mask = create_attention_mask(x, mask_cache)
|
||||
|
||||
for i, block in enumerate(self.layers):
|
||||
x = block(x, mask=mask, cache=cache[i])
|
||||
@ -485,6 +415,7 @@ class Model(nn.Module):
|
||||
def __init__(self, config):
|
||||
self.args = config
|
||||
self.model = Griffin(config)
|
||||
self.model_type = config.model_type
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(self, tokens: mx.array, cache=None) -> mx.array:
|
||||
@ -508,10 +439,9 @@ class Model(nn.Module):
|
||||
return self.model.layers
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
for k, v in weights.items():
|
||||
if "conv_1d.weight" in k and v.ndim == 3:
|
||||
weights[k] = v.squeeze(1).T
|
||||
weights[k] = v.moveaxis(2, 1)
|
||||
if "lm_head.weight" not in weights:
|
||||
self.pop("lm_head")
|
||||
return weights
|
||||
@ -520,7 +450,7 @@ class Model(nn.Module):
|
||||
cache = []
|
||||
for layer in self.layers:
|
||||
if layer.temporal_block_type == "recurrent":
|
||||
cache.append(RecurrentCache())
|
||||
cache.append(MambaCache())
|
||||
else:
|
||||
cache.append(WindowKVCache(self.args.attention_window_size))
|
||||
cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
|
||||
return cache
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -198,8 +197,8 @@ class Model(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache: mx.array = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
mask = create_attention_mask(x, cache)
|
||||
y = self.model(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
@ -207,11 +206,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -1,12 +1,12 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -45,7 +45,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@ -100,7 +100,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@ -164,11 +164,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@ -18,7 +18,7 @@ from mlx.utils import tree_flatten
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
# Local imports
|
||||
from .models.base import KVCache, RotatingKVCache
|
||||
from .models import base, cache
|
||||
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
|
||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||
from .tuner.utils import dequantize as dequantize_model
|
||||
@ -124,26 +124,6 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
|
||||
return logits
|
||||
|
||||
|
||||
def make_kv_caches(
|
||||
model: nn.Module, max_kv_size: Optional[int] = None
|
||||
) -> List[Union[KVCache, RotatingKVCache]]:
|
||||
if hasattr(model, "make_cache"):
|
||||
return model.make_cache()
|
||||
|
||||
kv_heads = (
|
||||
[model.n_kv_heads] * len(model.layers)
|
||||
if isinstance(model.n_kv_heads, int)
|
||||
else model.n_kv_heads
|
||||
)
|
||||
if max_kv_size is not None:
|
||||
return [
|
||||
RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
|
||||
for n in kv_heads
|
||||
]
|
||||
else:
|
||||
return [KVCache(model.head_dim, n) for n in kv_heads]
|
||||
|
||||
|
||||
def generate_step(
|
||||
prompts: mx.array,
|
||||
model: nn.Module,
|
||||
@ -155,7 +135,7 @@ def generate_step(
|
||||
min_tokens_to_keep: int = 1,
|
||||
prefill_step_size: int = 512,
|
||||
max_kv_size: Optional[int] = None,
|
||||
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
||||
prompt_cache: Optional[Any] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||
@ -180,6 +160,8 @@ def generate_step(
|
||||
prefill_step_size (int): Step size for processing the prompt.
|
||||
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
||||
entries (except the first 4 tokens) will be overwritten.
|
||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||
provided, the cache will be updated in place.
|
||||
logit_bias (dictionary, optional): Additive logit bias.
|
||||
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||
A list of functions that take tokens and logits and return the processed
|
||||
@ -243,20 +225,13 @@ def generate_step(
|
||||
tokens = None
|
||||
|
||||
# Create the KV cache for generation
|
||||
cache = make_kv_caches(model, max_kv_size)
|
||||
|
||||
if cache_history is not None:
|
||||
if len(cache_history) != len(cache):
|
||||
raise ValueError("Wrong number of layers in the cache history")
|
||||
|
||||
# Set the history in the cache objects and evaluate them to prepare for
|
||||
# generation.
|
||||
for c, h in zip(cache, cache_history):
|
||||
c.update_and_fetch(h[0], h[1])
|
||||
mx.eval([c.state for c in cache])
|
||||
if prompt_cache is None:
|
||||
prompt_cache = cache.make_prompt_cache(model, max_kv_size)
|
||||
elif len(prompt_cache) != len(model.layers):
|
||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||
|
||||
def _step(y):
|
||||
logits = model(y, cache=cache)
|
||||
logits = model(y, cache=prompt_cache)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if logits_processor:
|
||||
@ -270,7 +245,7 @@ def generate_step(
|
||||
return y, logprobs
|
||||
|
||||
while y.shape[1] > prefill_step_size:
|
||||
model(y[:, :prefill_step_size], cache=cache)
|
||||
model(y[:, :prefill_step_size], cache=prompt_cache)
|
||||
mx.eval([c.state for c in cache])
|
||||
y = y[:, prefill_step_size:]
|
||||
|
||||
@ -312,9 +287,9 @@ def stream_generate(
|
||||
detokenizer = tokenizer.detokenizer
|
||||
|
||||
detokenizer.reset()
|
||||
for (token, _), n in zip(
|
||||
generate_step(prompt_tokens[None], model, **kwargs),
|
||||
for _, (token, _) in zip(
|
||||
range(max_tokens),
|
||||
generate_step(prompt_tokens, model, **kwargs),
|
||||
):
|
||||
token = token.item()
|
||||
if token == tokenizer.eos_token_id:
|
||||
@ -365,9 +340,9 @@ def generate(
|
||||
tic = time.perf_counter()
|
||||
detokenizer.reset()
|
||||
|
||||
for (token, logprobs), n in zip(
|
||||
generate_step(prompt_tokens[None], model, **kwargs),
|
||||
for n, (token, logprobs) in zip(
|
||||
range(max_tokens),
|
||||
generate_step(prompt_tokens[None], model, **kwargs),
|
||||
):
|
||||
token = token.item()
|
||||
if n == 0:
|
||||
|
@ -32,6 +32,7 @@ setup(
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
|
||||
"mlx_lm.chat = mlx_lm.chat:main",
|
||||
"mlx_lm.convert = mlx_lm.convert:main",
|
||||
"mlx_lm.fuse = mlx_lm.fuse:main",
|
||||
"mlx_lm.generate = mlx_lm.generate:main",
|
||||
|
@ -1,17 +1,15 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
from mlx_lm.models.base import KVCache, RotatingKVCache
|
||||
from mlx_lm.utils import make_kv_caches
|
||||
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
||||
|
||||
|
||||
class TestModels(unittest.TestCase):
|
||||
|
||||
def test_kv_cache(self):
|
||||
cache = KVCache(32, 4)
|
||||
cache = KVCache()
|
||||
|
||||
k = mx.ones((1, 4, 1, 32), mx.float16)
|
||||
v = mx.ones((1, 4, 1, 32), mx.float16)
|
||||
@ -32,7 +30,7 @@ class TestModels(unittest.TestCase):
|
||||
|
||||
def test_rotating_kv_cache(self):
|
||||
b, h, d = 1, 2, 32
|
||||
cache = RotatingKVCache(d, h, max_size=8, step=4)
|
||||
cache = RotatingKVCache(max_size=8, step=4)
|
||||
|
||||
k = mx.random.uniform(shape=(b, h, 2, d))
|
||||
v = mx.random.uniform(shape=(b, h, 2, d))
|
||||
@ -65,7 +63,7 @@ class TestModels(unittest.TestCase):
|
||||
idx %= 8
|
||||
|
||||
# Try with nonzero keep
|
||||
cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2)
|
||||
cache = RotatingKVCache(max_size=8, step=4, keep=2)
|
||||
|
||||
# Check a large update
|
||||
k = mx.random.uniform(shape=(b, h, 20, d))
|
||||
@ -88,6 +86,46 @@ class TestModels(unittest.TestCase):
|
||||
if idx >= 8:
|
||||
idx = 2
|
||||
|
||||
def test_rotating_kv_cache_chat_mode(self):
|
||||
# Test that the rotating kv cache can handle
|
||||
# alternating prompt/prefill with generation
|
||||
d = 4
|
||||
h = 2
|
||||
cache = RotatingKVCache(max_size=18, step=4)
|
||||
|
||||
x = mx.random.uniform(shape=(1, h, 8, d))
|
||||
k, v = cache.update_and_fetch(x, x)
|
||||
self.assertEqual(k.shape[2], 8)
|
||||
self.assertEqual(cache.offset, 8)
|
||||
|
||||
x = mx.random.uniform(shape=(1, h, 1, d))
|
||||
k, v = cache.update_and_fetch(x, x)
|
||||
self.assertEqual(k.shape[2], 9)
|
||||
self.assertEqual(cache.offset, 9)
|
||||
self.assertTrue(mx.allclose(x, k[..., 8:9, :]))
|
||||
|
||||
x = mx.random.uniform(shape=(1, h, 2, d))
|
||||
k, v = cache.update_and_fetch(x, x)
|
||||
self.assertEqual(k.shape[2], 11)
|
||||
self.assertEqual(cache.offset, 11)
|
||||
self.assertTrue(mx.allclose(x, k[..., 9:11, :]))
|
||||
|
||||
x = mx.random.uniform(shape=(1, h, 3, d))
|
||||
k, v = cache.update_and_fetch(x, x)
|
||||
self.assertEqual(k.shape[2], 14)
|
||||
self.assertEqual(cache.offset, 14)
|
||||
self.assertTrue(mx.allclose(x, k[..., 11:14, :]))
|
||||
|
||||
x = mx.random.uniform(shape=(1, h, 6, d))
|
||||
k, v = cache.update_and_fetch(x, x)
|
||||
self.assertEqual(cache.offset, 20)
|
||||
self.assertTrue(mx.allclose(x, k[..., -6:, :]))
|
||||
|
||||
x = mx.random.uniform(shape=(1, h, 2, d))
|
||||
k, v = cache.update_and_fetch(x, x)
|
||||
self.assertEqual(cache.offset, 22)
|
||||
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
|
||||
|
||||
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
||||
|
||||
self.assertEqual(len(model.layers), num_layers)
|
||||
@ -101,7 +139,7 @@ class TestModels(unittest.TestCase):
|
||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||
self.assertEqual(outputs.dtype, t)
|
||||
|
||||
cache = make_kv_caches(model)
|
||||
cache = make_prompt_cache(model)
|
||||
outputs = model(inputs, cache)
|
||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||
self.assertEqual(outputs.dtype, t)
|
||||
@ -549,6 +587,179 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_deepseek(self):
|
||||
from mlx_lm.models import deepseek
|
||||
|
||||
args = deepseek.ModelArgs(
|
||||
model_type="deepseek",
|
||||
vocab_size=1024,
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
moe_intermediate_size=256,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=8,
|
||||
num_key_value_heads=4,
|
||||
)
|
||||
model = deepseek.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_deepseek_v2(self):
|
||||
from mlx_lm.models import deepseek_v2
|
||||
|
||||
args = deepseek_v2.ModelArgs(
|
||||
model_type="deepseek_v2",
|
||||
vocab_size=1024,
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
moe_intermediate_size=256,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
kv_lora_rank=4,
|
||||
q_lora_rank=4,
|
||||
qk_rope_head_dim=32,
|
||||
v_head_dim=16,
|
||||
qk_nope_head_dim=32,
|
||||
rope_scaling={
|
||||
"beta_fast": 32,
|
||||
"beta_slow": 1,
|
||||
"factor": 40,
|
||||
"mscale": 1.0,
|
||||
"mscale_all_dim": 1.0,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"type": "yarn",
|
||||
},
|
||||
)
|
||||
model = deepseek_v2.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_gemma2(self):
|
||||
from mlx_lm.models import gemma2
|
||||
|
||||
args = gemma2.ModelArgs(
|
||||
model_type="gemma2",
|
||||
hidden_size=128,
|
||||
num_hidden_layers=4,
|
||||
intermediate_size=256,
|
||||
num_attention_heads=2,
|
||||
head_dim=32,
|
||||
rms_norm_eps=1e-4,
|
||||
vocab_size=1024,
|
||||
num_key_value_heads=2,
|
||||
)
|
||||
model = gemma2.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_gpt_bigcode(self):
|
||||
from mlx_lm.models import gpt_bigcode
|
||||
|
||||
args = gpt_bigcode.ModelArgs(
|
||||
model_type="gpt_bigcode",
|
||||
n_embd=128,
|
||||
n_layer=128,
|
||||
n_inner=256,
|
||||
n_head=4,
|
||||
n_positions=1000,
|
||||
layer_norm_epsilon=1e-5,
|
||||
vocab_size=1024,
|
||||
)
|
||||
model = gpt_bigcode.Model(args)
|
||||
self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)
|
||||
|
||||
def test_nemotron(self):
|
||||
from mlx_lm.models import nemotron
|
||||
|
||||
args = nemotron.ModelArgs(
|
||||
model_type="nemotron",
|
||||
hidden_size=128,
|
||||
hidden_act="gelu",
|
||||
num_hidden_layers=4,
|
||||
intermediate_size=256,
|
||||
num_attention_heads=4,
|
||||
norm_eps=1e-5,
|
||||
vocab_size=1024,
|
||||
num_key_value_heads=2,
|
||||
)
|
||||
model = nemotron.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_phi3small(self):
|
||||
from mlx_lm.models import phi3small
|
||||
|
||||
args = phi3small.ModelArgs(
|
||||
model_type="phi3small",
|
||||
hidden_size=128,
|
||||
dense_attention_every_n_layers=2,
|
||||
ff_intermediate_size=256,
|
||||
gegelu_limit=1.0,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
layer_norm_epsilon=1e-4,
|
||||
vocab_size=1000,
|
||||
)
|
||||
model = phi3small.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_phimoe(self):
|
||||
from mlx_lm.models import phimoe
|
||||
|
||||
args = phimoe.ModelArgs(
|
||||
model_type="phimoe",
|
||||
vocab_size=320,
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=4,
|
||||
rope_scaling={
|
||||
"long_factor": [1.0] * 16,
|
||||
"long_mscale": 1.243163121016122,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"short_factor": [1.0] * 16,
|
||||
"short_mscale": 1.243163121016122,
|
||||
"type": "longrope",
|
||||
},
|
||||
)
|
||||
model = phimoe.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_recurrent_gemma(self):
|
||||
from mlx_lm.models import recurrent_gemma
|
||||
|
||||
args = recurrent_gemma.ModelArgs(
|
||||
model_type="recurrent_gemma",
|
||||
hidden_size=128,
|
||||
attention_bias=False,
|
||||
conv1d_width=3,
|
||||
intermediate_size=256,
|
||||
logits_soft_cap=1.0,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=4,
|
||||
num_key_value_heads=2,
|
||||
rms_norm_eps=1e-4,
|
||||
rope_theta=1000,
|
||||
attention_window_size=1024,
|
||||
vocab_size=1000,
|
||||
block_types=["recurrent", "recurrent", "attention"],
|
||||
)
|
||||
model = recurrent_gemma.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
220
llms/tests/test_prompt_cache.py
Normal file
220
llms/tests/test_prompt_cache.py
Normal file
@ -0,0 +1,220 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.cache import (
|
||||
KVCache,
|
||||
MambaCache,
|
||||
RotatingKVCache,
|
||||
load_prompt_cache,
|
||||
make_prompt_cache,
|
||||
save_prompt_cache,
|
||||
trim_prompt_cache,
|
||||
)
|
||||
from mlx_lm.utils import generate_step, load
|
||||
|
||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
|
||||
|
||||
class TestPromptCache(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
||||
cls.test_dir = cls.test_dir_fid.name
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.test_dir_fid.cleanup()
|
||||
|
||||
def test_save_load(self):
|
||||
cache = [KVCache() for _ in range(4)]
|
||||
for c in cache:
|
||||
x = mx.random.uniform(shape=(1, 8, 10, 4))
|
||||
c.update_and_fetch(x, x)
|
||||
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
|
||||
save_prompt_cache(cache_file, cache)
|
||||
loaded_cache = load_prompt_cache(cache_file)
|
||||
self.assertTrue(len(cache), len(loaded_cache))
|
||||
for c, lc in zip(cache, loaded_cache):
|
||||
self.assertEqual(c.offset, lc.offset)
|
||||
self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
|
||||
self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
|
||||
|
||||
# Test with metadata
|
||||
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
|
||||
metadata = {"a": "b", "c": "d"}
|
||||
save_prompt_cache(cache_file, cache, metadata)
|
||||
_, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True)
|
||||
self.assertEqual(metadata, loaded_metadata)
|
||||
|
||||
def test_save_load_rotating_cache(self):
|
||||
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
|
||||
|
||||
# Test with rotating cache
|
||||
cache = [RotatingKVCache(max_size=8, keep=2) for _ in range(4)]
|
||||
for c in cache:
|
||||
x = mx.random.uniform(shape=(1, 8, 10, 4))
|
||||
c.update_and_fetch(x, x)
|
||||
|
||||
save_prompt_cache(cache_file, cache)
|
||||
loaded_cache = load_prompt_cache(cache_file)
|
||||
self.assertTrue(len(cache), len(loaded_cache))
|
||||
for c, lc in zip(cache, loaded_cache):
|
||||
self.assertEqual(c.offset, lc.offset)
|
||||
self.assertEqual(c.keep, lc.keep)
|
||||
self.assertEqual(c.max_size, lc.max_size)
|
||||
self.assertEqual(c.step, lc.step)
|
||||
self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
|
||||
self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
|
||||
|
||||
# Do a couple single token updates to get a rotation
|
||||
for _ in range(2):
|
||||
for c in cache:
|
||||
x = mx.random.uniform(shape=(1, 8, 1, 4))
|
||||
c.update_and_fetch(x, x)
|
||||
|
||||
save_prompt_cache(cache_file, cache)
|
||||
loaded_cache = load_prompt_cache(cache_file)
|
||||
|
||||
for c, lc in zip(cache, loaded_cache):
|
||||
x = mx.random.uniform(shape=(1, 8, 1, 4))
|
||||
k, v = c.update_and_fetch(x, x)
|
||||
lk, lv = lc.update_and_fetch(x, x)
|
||||
self.assertEqual(c.offset, lc.offset)
|
||||
self.assertTrue(mx.array_equal(k, lk))
|
||||
self.assertTrue(mx.array_equal(v, lv))
|
||||
|
||||
def test_save_load_mixed_cache(self):
|
||||
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
|
||||
|
||||
cache = [MambaCache(), KVCache(), RotatingKVCache(8), MambaCache()]
|
||||
for c in cache:
|
||||
if isinstance(c, MambaCache):
|
||||
c[0] = mx.random.uniform(shape=(4, 4, 4))
|
||||
c[1] = mx.random.uniform(shape=(4, 4, 4))
|
||||
else:
|
||||
x = mx.random.uniform(shape=(4, 4, 7, 4))
|
||||
y = mx.random.uniform(shape=(4, 4, 7, 4))
|
||||
c.update_and_fetch(x, y)
|
||||
|
||||
save_prompt_cache(cache_file, cache)
|
||||
loaded_cache = load_prompt_cache(cache_file)
|
||||
for c, lc in zip(cache, loaded_cache):
|
||||
if isinstance(c, MambaCache):
|
||||
self.assertTrue(mx.array_equal(c[0], lc[0]))
|
||||
self.assertTrue(mx.array_equal(c[1], lc[1]))
|
||||
else:
|
||||
x = mx.random.uniform(shape=(4, 4, 1, 4))
|
||||
y = mx.random.uniform(shape=(4, 4, 1, 4))
|
||||
k, v = c.update_and_fetch(x, y)
|
||||
lk, lv = lc.update_and_fetch(x, y)
|
||||
self.assertEqual(c.offset, lc.offset)
|
||||
self.assertTrue(mx.array_equal(k, lk))
|
||||
self.assertTrue(mx.array_equal(v, lv))
|
||||
|
||||
def test_cache_with_generate(self):
|
||||
model, tokenizer = load(HF_MODEL_PATH)
|
||||
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
|
||||
results = zip(range(4), generate_step(prompt, model))
|
||||
toks, all_logits = zip(*(r[1] for r in results))
|
||||
|
||||
prompt_cache = make_prompt_cache(model)
|
||||
i = 0
|
||||
for _, (tok, logits) in zip(
|
||||
range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
|
||||
):
|
||||
self.assertEqual(tok, toks[i])
|
||||
self.assertTrue(mx.allclose(logits, all_logits[i]))
|
||||
i += 1
|
||||
|
||||
for _, (tok, logits) in zip(
|
||||
range(1),
|
||||
generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
|
||||
):
|
||||
i += 1
|
||||
self.assertEqual(tok, toks[i])
|
||||
self.assertTrue(mx.allclose(logits, all_logits[i]))
|
||||
|
||||
def test_trim_cache(self):
|
||||
cache = [KVCache() for _ in range(2)]
|
||||
for c in cache:
|
||||
x = mx.random.uniform(shape=(1, 8, 10, 4))
|
||||
c.update_and_fetch(x, x)
|
||||
|
||||
# Trim
|
||||
num_trimmed = trim_prompt_cache(cache, 7)
|
||||
self.assertEqual(num_trimmed, 7)
|
||||
|
||||
# Trim more tokens than remain
|
||||
num_trimmed = trim_prompt_cache(cache, 4)
|
||||
self.assertEqual(num_trimmed, 3)
|
||||
|
||||
# Can't trim mamba cache
|
||||
cache = [MambaCache() for _ in range(2)]
|
||||
for c in cache:
|
||||
c.state = mx.zeros((5, 5))
|
||||
num_trimmed = trim_prompt_cache(cache, 7)
|
||||
self.assertEqual(num_trimmed, 0)
|
||||
|
||||
# All cache's have to be trimmable
|
||||
cache = [MambaCache(), KVCache()]
|
||||
cache[0].state = mx.zeros((5, 5))
|
||||
x = mx.random.uniform(shape=(1, 8, 10, 4))
|
||||
cache[1].update_and_fetch(x, x)
|
||||
num_trimmed = trim_prompt_cache(cache, 1)
|
||||
self.assertEqual(num_trimmed, 0)
|
||||
|
||||
cache = [RotatingKVCache(max_size=6) for _ in range(2)]
|
||||
for c in cache:
|
||||
x = mx.random.uniform(shape=(1, 8, 5, 4))
|
||||
c.update_and_fetch(x, x)
|
||||
|
||||
num_trimmed = trim_prompt_cache(cache, 4)
|
||||
self.assertEqual(num_trimmed, 4)
|
||||
|
||||
# Can't trim fixed-size KV cache after processing
|
||||
# more than max_kv_size tokens
|
||||
for c in cache:
|
||||
x = mx.random.uniform(shape=(1, 8, 10, 4))
|
||||
c.update_and_fetch(x, x)
|
||||
|
||||
num_trimmed = trim_prompt_cache(cache, 4)
|
||||
self.assertEqual(num_trimmed, 0)
|
||||
|
||||
def test_trim_cache_with_generate(self):
|
||||
model, tokenizer = load(HF_MODEL_PATH)
|
||||
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
|
||||
|
||||
prompt_cache = make_prompt_cache(model)
|
||||
|
||||
# Generate one token so we process the full prompt
|
||||
last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache))
|
||||
last_tok = mx.array([last_tok])
|
||||
|
||||
# Generate two more tokens
|
||||
results = zip(
|
||||
range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
|
||||
)
|
||||
toks, all_logits = zip(*(r[1] for r in results))
|
||||
|
||||
# To get back to the cache just after processing the prompt,
|
||||
# trim by 3 tokens
|
||||
trim_prompt_cache(prompt_cache, 3)
|
||||
|
||||
# Generate the same thing again
|
||||
results = zip(
|
||||
range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
|
||||
)
|
||||
second_toks, second_all_logits = zip(*(r[1] for r in results))
|
||||
self.assertEqual(toks, second_toks)
|
||||
self.assertTrue(
|
||||
all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user