mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
More cache improvements (#1015)
* fix rotating kv cache for chat use case
* reorg + fixes to caching, unify prompt caching across types and use cases for e.g. caching during a chat
* nit in chat
* fix tests
* fix tests
* fix tests
* docs
* chat command
* comments + docs
* Define meta_state on all Cache implementations
* fixes + trim_prompt_cache api
* fix default model

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
parent 9bc53fc210 · commit fca087be49
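The commit message above references a unified prompt-cache API with a new `trim_prompt_cache` function. The functions used below are defined in `llms/mlx_lm/models/cache.py` later in this diff; the model name, prompt text, and file name are placeholders in this minimal sketch, which assumes the package at this commit is installed:

```python
# Minimal sketch of the unified prompt-cache API added by this commit.
# Prompt text, file name, and metadata values here are placeholders.
from mlx_lm import generate, load
from mlx_lm.models.cache import (
    load_prompt_cache,
    make_prompt_cache,
    save_prompt_cache,
    trim_prompt_cache,
)

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# One cache object per layer, reused across calls to generate().
prompt_cache = make_prompt_cache(model)
generate(model, tokenizer, prompt="Hello!", prompt_cache=prompt_cache, verbose=True)

# Rewind the cache by a few tokens (returns how many were actually trimmed),
# then persist it with string metadata and restore it later.
trimmed = trim_prompt_cache(prompt_cache, 4)
save_prompt_cache("prompt_cache.safetensors", prompt_cache, {"note": "demo"})
prompt_cache, metadata = load_prompt_cache(
    "prompt_cache.safetensors", return_metadata=True
)
```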
.gitignore (vendored) · 3 lines changed

@@ -6,6 +6,9 @@ __pycache__/
 # C extensions
 *.so

+# Vim
+*.swp
+
 # Distribution / packaging
 .Python
 build/

@@ -20,6 +20,31 @@ The `mlx-lm` package also has:
 - [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
 - [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)

+### Quick Start
+
+To generate text with an LLM use:
+
+```bash
+mlx_lm.generate --prompt "Hi!"
+```
+
+To chat with an LLM use:
+
+```bash
+mlx_lm.chat
+```
+
+This will give you a chat REPL that you can use to interact with the LLM. The
+chat context is preserved during the lifetime of the REPL.
+
+Commands in `mlx-lm` typically take command line options which let you specify
+the model, sampling parameters, and more. Use `-h` to see a list of available
+options for a command, e.g.:
+
+```bash
+mlx_lm.generate -h
+```
+
 ### Python API

 You can use `mlx-lm` as a module:

@@ -138,7 +163,7 @@ mlx_lm.convert \

 ### Long Prompts and Generations

-MLX LM has some tools to scale efficiently to long prompts and generations:
+`mlx-lm` has some tools to scale efficiently to long prompts and generations:

 - A rotating fixed-size key-value cache.
 - Prompt caching

@@ -155,14 +180,14 @@ different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:
 cat prompt.txt | mlx_lm.cache_prompt \
   --model mistralai/Mistral-7B-Instruct-v0.3 \
   --prompt - \
-  --kv-cache-file mistral_prompt.safetensors
+  --prompt-cache-file mistral_prompt.safetensors
 ```

 Then use the cached prompt with `mlx_lm.generate`:

 ```
 mlx_lm.generate \
-  --kv-cache-file mistral_prompt.safetensors \
+  --prompt-cache-file mistral_prompt.safetensors \
   --prompt "\nSummarize the above text."
 ```


@@ -170,9 +195,15 @@ The cached prompt is treated as a prefix to the supplied prompt. Also notice
 when using a cached prompt, the model to use is read from the cache and need
 not be supplied explicitly.

+Prompt caching can also be used in the Python API in order to avoid
+recomputing the prompt. This is useful in multi-turn dialogues or across
+requests that use the same context. See the
+[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
+for more usage details.
+
 ### Supported Models

-MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
+`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
 run is not supported, file an
 [issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
 submit a pull request.
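Since `mlx_lm.cache_prompt` now writes the cache with `save_prompt_cache` (see `cache_prompt.py` below), a file produced by the CLI flow described in the README above should also be usable from the Python API. A hedged sketch, assuming `mistral_prompt.safetensors` was created as in the README example:

```python
# Sketch: reuse a prompt cache written by `mlx_lm.cache_prompt` from Python.
# Assumes mistral_prompt.safetensors exists, created as in the README example.
from mlx_lm import generate, load
from mlx_lm.models.cache import load_prompt_cache

prompt_cache, metadata = load_prompt_cache(
    "mistral_prompt.safetensors", return_metadata=True
)

# The model used to build the cache is recorded in the metadata.
model, tokenizer = load(metadata["model"])

# The cached prompt acts as a prefix; only the suffix below is processed.
response = generate(
    model,
    tokenizer,
    prompt="\nSummarize the above text.",
    prompt_cache=prompt_cache,
    verbose=True,
)
```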
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.18.2"
+__version__ = "0.19.1"

@@ -7,13 +7,14 @@ import time

 import mlx.core as mx

-from .utils import load, make_kv_caches
+from .models.cache import make_prompt_cache, save_prompt_cache
+from .utils import load


 def setup_arg_parser():
     """Set up and return the argument parser."""
     parser = argparse.ArgumentParser(
-        description="Cache the KV cache of a prompt to be reused with mlx_lm.generate"
+        description="Cache the state of a prompt to be reused with mlx_lm.generate"
     )
     parser.add_argument(
         "--model",

@@ -60,7 +61,9 @@ def setup_arg_parser():
         help="Set the maximum key-value cache size",
     )
     parser.add_argument(
-        "--kv-cache-file", help="The file to save the KV caches in", required=True
+        "--prompt-cache-file",
+        help="The file to save the prompt cache in",
+        required=True,
     )
     parser.add_argument(
         "--prompt",

@@ -115,7 +118,7 @@ def main():
     else:
         prompt = args.prompt

-    cache = make_kv_caches(model, args.max_kv_size)
+    cache = make_prompt_cache(model, args.max_kv_size)
     y = mx.array(tokenizer.encode(prompt))

     # Process the prompt

@@ -137,16 +140,12 @@ def main():
     print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")

     print("Saving...")
-    cache_dict = {}
-    for i, c in enumerate(cache):
-        cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
-        cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
     metadata = {}
     metadata["model"] = args.model
     metadata["chat_template"] = tokenizer.chat_template
     metadata["tokenizer_config"] = json.dumps(tokenizer_config)
-    metadata["max_kv_size"] = str(args.max_kv_size)
-    mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
+    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
+    save_prompt_cache(args.prompt_cache_file, cache, metadata)


 if __name__ == "__main__":
llms/mlx_lm/chat.py · 82 lines (new file)

@@ -0,0 +1,82 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import json

import mlx.core as mx

from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
from .utils import load, stream_generate

DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="Chat with an LLM")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
    )
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
    parser.add_argument(
        "--max-kv-size",
        type=int,
        help="Set the maximum key-value cache size",
        default=None,
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()

    mx.random.seed(args.seed)

    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config={"trust_remote_code": True},
    )

    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
    prompt_cache = make_prompt_cache(model, args.max_kv_size)
    while True:
        query = input(">> ")
        if query == "q":
            break
        messages = [{"role": "user", "content": query}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        for response in stream_generate(
            model,
            tokenizer,
            prompt,
            temp=args.temp,
            top_p=args.top_p,
            prompt_cache=prompt_cache,
        ):
            print(response, flush=True, end="")
        print()


if __name__ == "__main__":
    main()
llms/mlx_lm/examples/chat.py · 53 lines (new file)

@@ -0,0 +1,53 @@
# Copyright © 2024 Apple Inc.

"""
An example of a multi-turn chat with prompt caching.
"""

from mlx_lm import generate, load
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Make the initial prompt cache for the model
prompt_cache = make_prompt_cache(model)

# User turn
prompt = "Hi my name is <Name>."
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# Assistant response
response = generate(
    model,
    tokenizer,
    prompt=prompt,
    verbose=True,
    temp=0.0,
    prompt_cache=prompt_cache,
)

# User turn
prompt = "What's my name?"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# Assistant response
response = generate(
    model,
    tokenizer,
    prompt=prompt,
    verbose=True,
    temp=0.0,
    prompt_cache=prompt_cache,
)

# Save the prompt cache to disk to reuse it at a later time
save_prompt_cache("mistral_prompt.safetensors", prompt_cache)

# Load the prompt cache from disk
prompt_cache = load_prompt_cache("mistral_prompt.safetensors")
@@ -1,3 +1,5 @@
+# Copyright © 2024 Apple Inc.
+
 from mlx_lm import generate, load

 # Specify the checkpoint

@@ -6,13 +6,15 @@ import sys

 import mlx.core as mx

+from .models.cache import load_prompt_cache
 from .utils import generate, load

 DEFAULT_PROMPT = "hello"
 DEFAULT_MAX_TOKENS = 100
-DEFAULT_TEMP = 0.6
+DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
+DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"


 def str2bool(string):

@@ -25,7 +27,11 @@ def setup_arg_parser():
     parser.add_argument(
         "--model",
         type=str,
-        help="The path to the local model directory or Hugging Face repo.",
+        help=(
+            "The path to the local model directory or Hugging Face repo. "
+            f"If no model is specified, then {DEFAULT_MODEL} is used."
+        ),
+        default=None,
     )
     parser.add_argument(
         "--adapter-path",

@@ -96,7 +102,7 @@ def setup_arg_parser():
         default=None,
     )
     parser.add_argument(
-        "--kv-cache-file",
+        "--prompt-cache-file",
         type=str,
         default=None,
         help="A file containing saved KV caches to avoid recomputing them",

@@ -131,24 +137,6 @@ def colorprint_by_t0(s, t0):
     colorprint(color, s)


-def load_kv_cache_from_file(kv_cache_file):
-    if kv_cache_file is None:
-        return None, None
-
-    kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True)
-    cache_per_layer = {}
-    for k, x in kv_cache.items():
-        layer, kv_type = k.split("_")
-        if layer not in cache_per_layer:
-            cache_per_layer[layer] = {}
-        cache_per_layer[layer][kv_type] = x
-
-    cache_history = [None] * len(cache_per_layer)
-    for layer, c in cache_per_layer.items():
-        cache_history[int(layer)] = (c["keys"], c["values"])
-    return cache_history, metadata
-
-
 def main():
     parser = setup_arg_parser()
     args = parser.parse_args()

@@ -158,22 +146,33 @@ def main():
     if args.cache_limit_gb is not None:
         mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)

-    # Load the kv cache and metadata if a kv cache file is provided
-    cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file)
+    # Load the prompt cache and metadata if a cache file is provided
+    using_cache = args.prompt_cache_file is not None
+    if using_cache:
+        prompt_cache, metadata = load_prompt_cache(
+            args.prompt_cache_file, return_metadata=True
+        )

     # Building tokenizer_config
     tokenizer_config = (
-        {} if cache_history is None else json.loads(metadata["tokenizer_config"])
+        {} if not using_cache else json.loads(metadata["tokenizer_config"])
     )
     if args.trust_remote_code:
         tokenizer_config["trust_remote_code"] = True
     if args.eos_token is not None:
         tokenizer_config["eos_token"] = args.eos_token

-    # If no model path is provided then use the one in the kv cache history
     model_path = args.model
-    if cache_history is not None and model_path is None:
-        model_path = metadata["model"]
+    if using_cache:
+        if model_path is None:
+            model_path = metadata["model"]
+        elif model_path != metadata["model"]:
+            raise ValueError(
+                f"Providing a different model ({model_path}) than that "
+                f"used to create the prompt cache ({metadata['model']}) "
+                "is an error."
+            )
+    model_path = model_path or DEFAULT_MODEL

     model, tokenizer = load(
         model_path,

@@ -184,7 +183,7 @@ def main():
     if args.use_default_chat_template:
         if tokenizer.chat_template is None:
             tokenizer.chat_template = tokenizer.default_chat_template
-    elif cache_history is not None:
+    elif using_cache:
         tokenizer.chat_template = metadata["chat_template"]

     if not args.ignore_chat_template and (

@@ -203,7 +202,7 @@ def main():

         # Treat the prompt as a suffix assuming that the prefix is in the
         # stored kv cache.
-        if cache_history is not None:
+        if using_cache:
             test_prompt = tokenizer.apply_chat_template(
                 [{"role": "user", "content": "<query>"}],
                 tokenize=False,

@@ -217,12 +216,6 @@ def main():
         raise ValueError("Cannot use --colorize with --verbose=False")
     formatter = colorprint_by_t0 if args.colorize else None

-    # Determine the max kv size from the kv cache or passed arguments
-    max_kv_size = args.max_kv_size
-    if cache_history is not None:
-        max_kv_size = metadata["max_kv_size"]
-        max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
-
     response = generate(
         model,
         tokenizer,

@@ -232,8 +225,8 @@ def main():
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
-        max_kv_size=max_kv_size,
-        cache_history=cache_history,
+        max_kv_size=args.max_kv_size,
+        prompt_cache=prompt_cache if using_cache else None,
     )
     if not args.verbose:
         print(response)
@@ -2,145 +2,9 @@

 import inspect
 from dataclasses import dataclass
-from typing import Any, List, Optional
+from typing import Any, Optional

 import mlx.core as mx
-import mlx.nn as nn
-
-
-class KVCache:
-
-    def __init__(self, head_dim, n_kv_heads):
-        self.n_kv_heads = n_kv_heads
-        if isinstance(head_dim, int):
-            self.k_head_dim = self.v_head_dim = head_dim
-        elif isinstance(head_dim, tuple) and len(head_dim) == 2:
-            self.k_head_dim, self.v_head_dim = head_dim
-        else:
-            raise ValueError("head_dim must be an int or a tuple of two ints")
-        self.keys = None
-        self.values = None
-        self.offset = 0
-        self.step = 256
-
-    def update_and_fetch(self, keys, values):
-        prev = self.offset
-        if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
-            B = keys.shape[0]
-            n_steps = (self.step + keys.shape[2] - 1) // self.step
-            k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
-            v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
-            new_k = mx.zeros(k_shape, keys.dtype)
-            new_v = mx.zeros(v_shape, values.dtype)
-            if self.keys is not None:
-                if prev % self.step != 0:
-                    self.keys = self.keys[..., :prev, :]
-                    self.values = self.values[..., :prev, :]
-                self.keys = mx.concatenate([self.keys, new_k], axis=2)
-                self.values = mx.concatenate([self.values, new_v], axis=2)
-            else:
-                self.keys, self.values = new_k, new_v
-
-        self.offset += keys.shape[2]
-        self.keys[..., prev : self.offset, :] = keys
-        self.values[..., prev : self.offset, :] = values
-        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
-
-    @property
-    def state(self):
-        return self.keys, self.values
-
-
-class RotatingKVCache:
-
-    def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
-        self.n_kv_heads = n_kv_heads
-        if isinstance(head_dim, int):
-            self.k_head_dim = self.v_head_dim = head_dim
-        elif isinstance(head_dim, tuple) and len(head_dim) == 2:
-            self.k_head_dim, self.v_head_dim = head_dim
-        else:
-            raise ValueError("head_dim must be an int or a tuple of two ints")
-        self.keep = keep
-        self.keys = None
-        self.values = None
-        self.offset = 0
-        self.max_size = max_size
-        self.step = step
-        self._idx = 0
-
-    def _trim(self, trim_size, v, append=None):
-        to_cat = []
-        if trim_size > 0:
-            to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
-        else:
-            to_cat = [v]
-        if append is not None:
-            to_cat.append(append)
-        return mx.concatenate(to_cat, axis=2)
-
-    def update_and_fetch(self, keys, values):
-        prev = self.offset
-        B, _, S = keys.shape[:3]
-
-        # Prefill mode
-        if S > 1:
-            if self.keys is None:
-                self.keys = keys
-                self.values = values
-            else:
-                # The largest size is self.max_size + S - 1 to ensure
-                # every token gets at least self.max_size context
-                trim_size = self.keys.shape[2] - self.max_size + 1
-                self.keys = self._trim(trim_size, self.keys, keys)
-                self.values = self._trim(trim_size, self.values, values)
-            self.offset += S
-            self._idx = self.keys.shape[2]
-            return self.keys, self.values
-
-        # Generation mode
-        # May not have hit the max size yet, so potentially
-        # keep growing the cache
-        if self.keys is None or (
-            prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
-        ):
-            new_size = min(self.step, self.max_size - prev)
-            k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
-            v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
-            new_k = mx.zeros(k_shape, keys.dtype)
-            new_v = mx.zeros(v_shape, values.dtype)
-            if self.keys is not None:
-                self.keys = mx.concatenate([self.keys, new_k], axis=2)
-                self.values = mx.concatenate([self.values, new_v], axis=2)
-            else:
-                self.keys, self.values = new_k, new_v
-            self._idx = prev
-
-        # Trim if needed
-        trim_size = self.keys.shape[2] - self.max_size
-        if trim_size > 0:
-            self.keys = self._trim(trim_size, self.keys)
-            self.values = self._trim(trim_size, self.values)
-            self._idx = self.max_size
-
-        # Rotate
-        if self._idx == self.max_size:
-            self._idx = self.keep
-
-        # Assign
-        self.keys[..., self._idx : self._idx + 1, :] = keys
-        self.values[..., self._idx : self._idx + 1, :] = values
-        self.offset += 1
-        self._idx += 1
-
-        # If the buffer is not full, slice off the end
-        if self.offset < self.max_size:
-            return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
-        return self.keys, self.values
-
-    @property
-    def state(self):
-        return self.keys, self.values


 @dataclass

@@ -156,25 +20,30 @@ class BaseModelArgs:
         )


-def create_additive_causal_mask(N: int, offset: int = 0):
+def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
     rinds = mx.arange(offset + N)
     linds = mx.arange(offset, offset + N) if offset else rinds
-    mask = linds[:, None] < rinds[None]
+    linds = linds[:, None]
+    rinds = rinds[None]
+    mask = linds < rinds
+    if window_size is not None:
+        mask = mask | (linds > rinds + window_size)
     return mask * -1e9


 def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
     T = h.shape[1]
     if T > 1:
+        window_size = None
+        offset = 0
         if cache is not None and cache[0] is not None:
             c = cache[0]
-            if isinstance(c, RotatingKVCache):
+            if hasattr(c, "max_size"):
                 offset = min(c.max_size - 1, c.offset)
+                window_size = c.max_size
             else:
                 offset = c.offset
-        else:
-            offset = 0
-        mask = create_additive_causal_mask(T, offset)
+        mask = create_causal_mask(T, offset, window_size=window_size)
         mask = mask.astype(h.dtype)
     else:
         mask = None
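The `window_size` argument added to `create_causal_mask` above pairs with the rotating cache: besides masking future positions, it also masks anything more than `window_size` tokens in the past. A small sketch of the effect, assuming the module as patched above is importable:

```python
# Sketch: effect of the new window_size argument on the causal mask.
# With window_size=2, position i may attend to positions i-2 .. i only;
# future positions and older history both get the -1e9 additive penalty.
from mlx_lm.models.base import create_causal_mask

mask = create_causal_mask(4, offset=0, window_size=2)
print(mask)  # (4, 4) additive mask: 0 where attention is allowed, -1e9 elsewhere
```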
llms/mlx_lm/models/cache.py · 333 lines (new file)

@@ -0,0 +1,333 @@
# Copyright © 2023-2024 Apple Inc.

from typing import Any, Dict, List, Optional

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten


def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
    """
    Construct the model's cache for use during generation.

    This function will defer the cache construction to the model if it has a
    ``make_cache`` method, otherwise it will make a default KV cache.

    Args:
        model (nn.Module): The language model.
        max_kv_size (Optional[int]): If provided and the model does not have a
            ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
            size of ``max_kv_size``
    """
    if hasattr(model, "make_cache"):
        return model.make_cache()

    num_layers = len(model.layers)
    if max_kv_size is not None:
        return [
            RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
        ]
    else:
        return [KVCache() for _ in range(num_layers)]


def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
    """
    Save a pre-computed prompt cache to a file.

    Args:
        file_name (str): The ``.safetensors`` file name.
        cache (List[Any]): The model state.
        metadata (Dict[str, str]): Optional metadata to save along with model
            state.
    """
    cache_data = [c.state for c in cache]
    cache_info = [c.meta_state for c in cache]
    cache_data = dict(tree_flatten(cache_data))
    cache_classes = [type(c).__name__ for c in cache]
    cache_metadata = [cache_info, metadata, cache_classes]
    cache_metadata = dict(tree_flatten(cache_metadata))
    mx.save_safetensors(file_name, cache_data, cache_metadata)


def load_prompt_cache(file_name, return_metadata=False):
    """
    Load a prompt cache from a file.

    Args:
        file_name (str): The ``.safetensors`` file name.
        return_metadata (bool): Whether or not to return metadata.
            Default: ``False``.

    Returns:
        List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
            the metadata if requested.
    """
    arrays, cache_metadata = mx.load(file_name, return_metadata=True)
    arrays = tree_unflatten(list(arrays.items()))
    cache_metadata = tree_unflatten(list(cache_metadata.items()))
    info, metadata, classes = cache_metadata
    cache = [globals()[c]() for c in classes]
    for c, state, meta_state in zip(cache, arrays, info):
        c.state = state
        c.meta_state = meta_state
    if return_metadata:
        return cache, metadata
    return cache


def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
    """
    Trim the model's cache by the given number of tokens.

    This function will trim the cache if possible (in-place) and return the
    number of tokens that were trimmed.

    Args:
        cache (List[Any]): The model's cache.
        num_tokens (int): The number of tokens to trim.

    Returns:
        (int): The number of tokens that were trimmed.
    """
    if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
        return 0
    return [c.trim(num_tokens) for c in cache][0]


class _BaseCache:
    @property
    def state(self):
        return []

    @state.setter
    def state(self, v):
        if v is not None and v:
            raise ValueError("This cache has no state but a state was set.")

    @property
    def meta_state(self):
        return ""

    @meta_state.setter
    def meta_state(self, v):
        if v is not None and v:
            raise ValueError("This cache has no meta_state but a meta_state was set.")

    def is_trimmable(self):
        return False


class KVCache(_BaseCache):
    def __init__(self):
        self.keys = None
        self.values = None
        self.offset = 0
        self.step = 256

    def update_and_fetch(self, keys, values):
        prev = self.offset
        if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
            B, n_kv_heads, _, k_head_dim = keys.shape
            v_head_dim = values.shape[3]
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
            v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is not None:
                if prev % self.step != 0:
                    self.keys = self.keys[..., :prev, :]
                    self.values = self.values[..., :prev, :]
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v

        self.offset += keys.shape[2]
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

    @property
    def state(self):
        if self.offset == self.keys.shape[2]:
            return self.keys, self.values
        else:
            return (
                self.keys[..., : self.offset, :],
                self.values[..., : self.offset, :],
            )

    @state.setter
    def state(self, v):
        self.keys, self.values = v
        self.offset = self.keys.shape[2]

    def is_trimmable(self):
        return True

    def trim(self, n):
        n = min(self.offset, n)
        self.offset -= n
        return n


class RotatingKVCache(_BaseCache):

    def __init__(self, max_size=None, keep=0, step=256):
        self.keep = keep
        self.keys = None
        self.values = None
        self.offset = 0
        self.max_size = max_size
        self.step = step
        self._idx = 0

    def _trim(self, trim_size, v, append=None):
        to_cat = []
        if trim_size > 0:
            to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
        else:
            to_cat = [v]
        if append is not None:
            to_cat.append(append)
        return mx.concatenate(to_cat, axis=2)

    def _temporal_order(self, v):
        """
        Rearrange the cache into temporal order, slicing off the end if unused.
        """
        if self._idx == v.shape[2]:
            return v
        elif self._idx < self.offset:
            return mx.concatenate(
                [
                    v[..., : self.keep, :],
                    v[..., self._idx :, :],
                    v[..., self.keep : self._idx, :],
                ],
                axis=2,
            )
        else:
            return v[..., : self._idx, :]

    def _update_concat(self, keys, values):
        if self.keys is None:
            self.keys = keys
            self.values = values
        else:
            # Put the keys/values in temporal order to
            # preserve context
            self.keys = self._temporal_order(self.keys)
            self.values = self._temporal_order(self.values)

            # The largest size is self.max_size + S - 1 to ensure
            # every token gets at least self.max_size context
            trim_size = self._idx - self.max_size + 1
            self.keys = self._trim(trim_size, self.keys, keys)
            self.values = self._trim(trim_size, self.values, values)
        self.offset += keys.shape[2]
        self._idx = self.keys.shape[2]
        return self.keys, self.values

    def _update_in_place(self, keys, values):
        # May not have hit the max size yet, so potentially
        # keep growing the cache
        B, n_kv_heads, S, k_head_dim = keys.shape
        prev = self.offset
        if self.keys is None or (
            prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
        ):
            v_head_dim = values.shape[3]
            new_size = min(self.step, self.max_size - prev)
            k_shape = (B, n_kv_heads, new_size, k_head_dim)
            v_shape = (B, n_kv_heads, new_size, v_head_dim)
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is not None:
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v
            self._idx = prev

        # Trim if needed
        trim_size = self.keys.shape[2] - self.max_size
        if trim_size > 0:
            self.keys = self._trim(trim_size, self.keys)
            self.values = self._trim(trim_size, self.values)
            self._idx = self.max_size

        # Rotate
        if self._idx == self.max_size:
            self._idx = self.keep

        # Assign
        self.keys[..., self._idx : self._idx + S, :] = keys
        self.values[..., self._idx : self._idx + S, :] = values
        self.offset += S
        self._idx += S

        # If the buffer is not full, slice off the end
        if self.offset < self.max_size:
            return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
        return self.keys, self.values

    def update_and_fetch(self, keys, values):
        if keys.shape[2] == 1:
            return self._update_in_place(keys, values)
        return self._update_concat(keys, values)

    @property
    def state(self):
        if self.offset < self.keys.shape[2]:
            return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
        else:
            return self.keys, self.values

    @state.setter
    def state(self, v):
        self.keys, self.values = v

    @property
    def meta_state(self):
        return tuple(
            map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
        )

    @meta_state.setter
    def meta_state(self, v):
        self.keep, self.max_size, self.step, self.offset, self._idx = map(
            int,
            v,
        )

    def is_trimmable(self):
        return self.offset < self.max_size

    def trim(self, n):
        n = min(self.offset, n)
        self.offset -= n
        self._idx -= n
        return n


class MambaCache(_BaseCache):
    def __init__(self):
        self.cache = [None, None]

    def __setitem__(self, idx, value):
        self.cache[idx] = value

    def __getitem__(self, idx):
        return self.cache[idx]

    @property
    def state(self):
        return self.cache

    @state.setter
    def state(self, v):
        self.cache = v
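Each cache class above exposes `state` (arrays) and `meta_state` (strings), which is what `save_prompt_cache`/`load_prompt_cache` serialize along with the class names, so the right cache types are rebuilt on load. A hedged round-trip sketch; the model name, prompt, and sizes are placeholders:

```python
# Sketch: save/load round trip for a rotating cache, based on cache.py above.
# Model name, prompt, and sizes are placeholders.
from mlx_lm import generate, load
from mlx_lm.models.cache import (
    load_prompt_cache,
    make_prompt_cache,
    save_prompt_cache,
)

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# With max_kv_size set, make_prompt_cache builds RotatingKVCache objects
# (keeping 4 initial tokens, per the keep=4 default used in make_prompt_cache).
cache = make_prompt_cache(model, max_kv_size=512)
generate(model, tokenizer, prompt="Hello!", max_tokens=8, prompt_cache=cache)

# Keys/values are stored as safetensors arrays; keep/max_size/step/offset/_idx
# travel through meta_state, and the class names are recorded so that
# load_prompt_cache can reconstruct the same cache types.
save_prompt_cache("rotating_prompt.safetensors", cache, {"note": "demo"})
restored = load_prompt_cache("rotating_prompt.safetensors")
assert type(restored[0]).__name__ == "RotatingKVCache"
```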
@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn

@@ -69,7 +69,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         B, L, D = x.shape

@@ -129,7 +129,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         h = self.input_layernorm(x)
         attn_h = self.self_attn(h, mask, cache)

@@ -190,11 +190,3 @@ class Model(nn.Module):
     @property
     def layers(self):
         return self.model.layers
-
-    @property
-    def head_dim(self):
-        return self.args.hidden_size // self.args.num_attention_heads
-
-    @property
-    def n_kv_heads(self):
-        return self.args.num_key_value_heads

@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn

@@ -49,7 +49,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:

         qkv = self.Wqkv(x)

@@ -92,7 +92,7 @@ class NormAttnNorm(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
    ) -> mx.array:
         h = self.attn(self.norm_1(x), mask=mask, cache=cache)
         x = h + x

@@ -179,7 +179,7 @@ class DecoderLayer(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         r, h = self.norm_attn_norm(x, mask, cache)
         out = self.ffn(h) + r

@@ -249,11 +249,3 @@ class Model(nn.Module):
         experts = [(s, sv.T) for s, sv in experts]
         new_weights.update(experts)
         return new_weights
-
-    @property
-    def head_dim(self):
-        return self.args.d_model // self.args.n_heads
-
-    @property
-    def n_kv_heads(self):
-        return self.args.attn_config["kv_n_heads"]

@@ -1,10 +1,10 @@
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Any, Dict, Optional

 import mlx.core as mx
 import mlx.nn as nn

-from .base import BaseModelArgs, KVCache, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask
 from .switch_layers import SwitchGLU


@@ -77,7 +77,7 @@ class DeepseekAttention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[KVCache] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         B, L, _ = x.shape

@@ -108,8 +108,8 @@ class DeepseekMLP(nn.Module):
     def __init__(
         self,
         config: ModelArgs,
-        hidden_size: int | None = None,
-        intermediate_size: int | None = None,
+        hidden_size: Optional[int] = None,
+        intermediate_size: Optional[int] = None,
     ):
         super().__init__()
         self.config = config

@@ -188,7 +188,7 @@ class DeepseekDecoderLayer(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[KVCache] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r

@@ -210,7 +210,7 @@ class DeepseekModel(nn.Module):
     def __call__(
         self,
         x: mx.array,
-        cache: Optional[KVCache] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
         mask = create_attention_mask(h, cache)

@@ -235,7 +235,7 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
-        cache: Optional[KVCache] = None,
+        cache: Optional[Any] = None,
     ):
         out = self.model(inputs, cache)
         return self.lm_head(out)

@@ -256,11 +256,3 @@ class Model(nn.Module):
     @property
     def layers(self):
         return self.model.layers
-
-    @property
-    def head_dim(self):
-        return self.args.hidden_size // self.args.num_attention_heads
-
-    @property
-    def n_kv_heads(self):
-        return self.args.num_key_value_heads

@@ -2,12 +2,12 @@

 import math
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn

-from .base import BaseModelArgs, KVCache, create_attention_mask
+from .base import BaseModelArgs, create_attention_mask
 from .switch_layers import SwitchGLU


@@ -38,7 +38,7 @@ class ModelArgs(BaseModelArgs):
     max_position_embeddings: int = 2048
     rms_norm_eps: float = 1e-6
     rope_theta: float = 10000.0
-    rope_scaling: Optional[Dict] = None
+    rope_scaling: Dict = None
     attention_bias: bool = False


@@ -172,12 +172,11 @@ class DeepseekV2Attention(nn.Module):
             bias=config.attention_bias,
         )

-        if self.config.rope_scaling is not None:
-            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
-            scaling_factor = self.config.rope_scaling["factor"]
-            if mscale_all_dim:
-                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
-                self.scale = self.scale * mscale * mscale
+        mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
+        scaling_factor = self.config.rope_scaling["factor"]
+        if mscale_all_dim:
+            mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+            self.scale = self.scale * mscale * mscale

         rope_kwargs = {
             key: self.config.rope_scaling[key]

@@ -202,7 +201,7 @@ class DeepseekV2Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[KVCache] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         B, L, D = x.shape

@@ -347,7 +346,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[KVCache] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r

@@ -370,7 +369,7 @@ class DeepseekV2Model(nn.Module):
     def __call__(
         self,
         x: mx.array,
-        cache: Optional[KVCache] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
         mask = create_attention_mask(h, cache)

@@ -395,7 +394,7 @@ class Model(nn.Module):
     def __call__(
         self,
         inputs: mx.array,
-        cache: Optional[KVCache] = None,
+        cache: Optional[Any] = None,
     ):
         out = self.model(inputs, cache)
         return self.lm_head(out)

@@ -416,14 +415,3 @@ class Model(nn.Module):
     @property
     def layers(self):
         return self.model.layers
-
-    @property
-    def head_dim(self):
-        return (
-            self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
-            self.args.v_head_dim,
-        )
-
-    @property
-    def n_kv_heads(self):
-        return self.args.num_key_value_heads

@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn

@@ -60,7 +60,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         B, L, D = x.shape

@@ -113,7 +113,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r

@@ -173,11 +173,3 @@ class Model(nn.Module):
     @property
     def layers(self):
         return self.model.layers
-
-    @property
-    def head_dim(self):
-        return self.args.head_dim
-
-    @property
-    def n_kv_heads(self):
-        return self.args.num_key_value_heads

@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn

@@ -64,7 +64,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         B, L, D = x.shape
         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

@@ -135,13 +135,11 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
-        r = self.self_attn(self.input_layernorm(x.astype(mx.float32)), mask, cache)
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + self.post_attention_layernorm(r)
-        r = self.mlp(self.pre_feedforward_layernorm(h).astype(mx.float16)).astype(
-            mx.float32
-        )
+        r = self.mlp(self.pre_feedforward_layernorm(h))
         out = h + self.post_feedforward_layernorm(r)
         return out

@@ -200,11 +198,3 @@ class Model(nn.Module):
     @property
     def layers(self):
         return self.model.layers
-
-    @property
-    def head_dim(self):
-        return self.args.head_dim
-
-    @property
-    def n_kv_heads(self):
-        return self.args.num_key_value_heads

@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn

@@ -46,7 +46,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         B, L, D = x.shape

@@ -100,7 +100,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         r = self.attn(self.ln_1(x), mask, cache)
         h = x + r

@@ -196,11 +196,3 @@ class Model(nn.Module):
     @property
     def layers(self):
         return self.model.h
-
-    @property
-    def head_dim(self):
-        return self.args.n_embd // self.args.n_head
-
-    @property
-    def n_kv_heads(self):
-        return self.args.num_key_value_heads

@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn

@@ -57,7 +57,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         B, L, D = x.shape

@@ -114,7 +114,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         r = self.attn(self.ln_1(x), mask, cache)
         h = x + r

@@ -184,11 +184,3 @@ class Model(nn.Module):
     @property
     def layers(self):
         return self.transformer.h
-
-    @property
-    def head_dim(self):
-        return self.args.n_embd // self.args.n_head
-
-    @property
-    def n_kv_heads(self):
-        return self.args.num_key_value_heads

@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn

@@ -60,7 +60,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         B, L, D = x.shape

@@ -120,7 +120,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         residual = x
         # NeoX runs attention and feedforward network in parallel.

@@ -214,11 +214,3 @@ class Model(nn.Module):
     @property
     def layers(self):
         return self.model.h
-
-    @property
-    def head_dim(self):
-        return self.args.hidden_size // self.args.num_attention_heads
-
-    @property
-    def n_kv_heads(self):
-        return self.args.num_key_value_heads

@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn

@@ -116,7 +116,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
|
|
||||||
@ -171,7 +171,7 @@ class TransformerBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r = self.attention(self.attention_norm(x), mask, cache)
|
r = self.attention(self.attention_norm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
@ -236,11 +236,3 @@ class Model(nn.Module):
|
|||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return self.args.hidden_size // self.args.num_attention_heads
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_key_value_heads
|
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -171,7 +171,7 @@ class Attention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[KVCache] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
|
|
||||||
@ -233,7 +233,7 @@ class TransformerBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[KVCache] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
@ -303,13 +303,3 @@ class Model(nn.Module):
|
|||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return (
|
|
||||||
self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_key_value_heads
|
|
||||||
|
@@ -7,6 +7,7 @@ import mlx.core as mx
 import mlx.nn as nn

 from .base import BaseModelArgs
+from .cache import MambaCache


 @dataclass
@@ -45,21 +46,6 @@ class ModelArgs(BaseModelArgs):
         self.time_step_rank = math.ceil(self.hidden_size / 16)


-class MambaCache:
-    def __init__(self):
-        self.cache = [None, None]
-
-    def __setitem__(self, idx, value):
-        self.cache[idx] = value
-
-    def __getitem__(self, idx):
-        return self.cache[idx]
-
-    @property
-    def state(self):
-        return self.cache
-
-
 class DepthWiseConv1d(nn.Module):
     def __init__(self, channels, kernel_size, bias=True, padding=0):
         super().__init__()
@@ -223,7 +209,7 @@ class Model(nn.Module):
                 weights[k] = v.moveaxis(2, 1)
         return weights

-    def make_cache(self, batch_size: int = 1):
+    def make_cache(self):
         return [MambaCache() for _ in range(len(self.layers))]

     @property
@ -1,7 +1,7 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -85,7 +85,7 @@ class Attention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Any] = None,
|
||||||
):
|
):
|
||||||
B, L, _ = x.shape
|
B, L, _ = x.shape
|
||||||
|
|
||||||
@ -135,7 +135,7 @@ class DecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
|
h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
|
||||||
@ -205,11 +205,3 @@ class Model(nn.Module):
|
|||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return self.args.hidden_size // self.args.num_attention_heads
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_key_value_heads
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -66,7 +66,7 @@ class MixtralAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
|
|
||||||
@ -138,7 +138,7 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
@ -215,11 +215,3 @@ class Model(nn.Module):
|
|||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return self.args.hidden_size // self.args.num_attention_heads
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_key_value_heads
|
|
||||||
|
@ -2,12 +2,12 @@
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, Optional, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -94,7 +94,7 @@ class Attention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[KVCache] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
B, L, _ = x.shape
|
B, L, _ = x.shape
|
||||||
|
|
||||||
@ -151,7 +151,7 @@ class TransformerBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[KVCache] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
@ -215,13 +215,3 @@ class Model(nn.Module):
|
|||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return (
|
|
||||||
self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_key_value_heads
|
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import sys
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from sys import exit
|
from typing import Any, Optional, Tuple
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -13,7 +13,7 @@ try:
|
|||||||
import hf_olmo
|
import hf_olmo
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("To run olmo install ai2-olmo: pip install ai2-olmo")
|
print("To run olmo install ai2-olmo: pip install ai2-olmo")
|
||||||
exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -68,7 +68,7 @@ class TransformerBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
|
|
||||||
@ -98,7 +98,7 @@ class TransformerBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r = self.attend(self.att_norm(x), mask, cache)
|
r = self.attend(self.att_norm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
@ -174,11 +174,3 @@ class Model(nn.Module):
|
|||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.transformer.blocks
|
return self.model.transformer.blocks
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return self.args.d_model // self.args.n_heads
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.n_heads
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -80,7 +80,7 @@ class Attention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
|
|
||||||
@ -152,7 +152,7 @@ class TransformerBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r = self.attn(self.attn_norm(x), mask, cache)
|
r = self.attn(self.attn_norm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
@ -218,11 +218,3 @@ class Model(nn.Module):
|
|||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.transformer.layers
|
return self.transformer.layers
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return self.args.head_dim
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_kv_heads
|
|
||||||
|
@@ -162,19 +162,11 @@ class Model(nn.Module):
     def __call__(
         self,
         x: mx.array,
-        cache: mx.array = None,
-    ) -> Tuple[mx.array, mx.array]:
+        cache=None,
+    ) -> mx.array:
         y = self.model(x, cache)
         return self.lm_head(y)

     @property
     def layers(self):
         return self.model.layers
-
-    @property
-    def head_dim(self):
-        return self.args.hidden_size // self.args.num_attention_heads
-
-    @property
-    def n_kv_heads(self):
-        return self.args.num_key_value_heads
@ -1,12 +1,12 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask
|
||||||
from .su_rope import SuScaledRotaryEmbedding
|
from .su_rope import SuScaledRotaryEmbedding
|
||||||
|
|
||||||
|
|
||||||
@ -84,7 +84,7 @@ class Attention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[KVCache] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
|
|
||||||
@ -143,7 +143,7 @@ class TransformerBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[KVCache] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
@ -202,11 +202,3 @@ class Model(nn.Module):
|
|||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return self.args.hidden_size // self.args.num_attention_heads
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_key_value_heads
|
|
||||||
|
@ -3,12 +3,12 @@
|
|||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Any, Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -22,14 +22,14 @@ class ModelArgs(BaseModelArgs):
|
|||||||
num_attention_heads: int
|
num_attention_heads: int
|
||||||
layer_norm_epsilon: float
|
layer_norm_epsilon: float
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
num_key_value_heads: Optional[int] = None
|
num_key_value_heads: int
|
||||||
mup_attn_multiplier: float = 1.0
|
mup_attn_multiplier: float = 1.0
|
||||||
mup_use_scaling: bool = True
|
mup_use_scaling: bool = True
|
||||||
mup_embedding_multiplier: float = 10.0
|
mup_embedding_multiplier: float = 10.0
|
||||||
mup_width_multiplier: float = 8.0
|
mup_width_multiplier: float = 8.0
|
||||||
rope_embedding_base: float = 1000000
|
rope_embedding_base: float = 1000000
|
||||||
rope_position_scale: float = 1.0
|
rope_position_scale: float = 1.0
|
||||||
blocksparse_block_size: Tuple[int] = (64,)
|
blocksparse_block_size: int = 64
|
||||||
blocksparse_num_local_blocks: int = 16
|
blocksparse_num_local_blocks: int = 16
|
||||||
blocksparse_vert_stride: int = 8
|
blocksparse_vert_stride: int = 8
|
||||||
|
|
||||||
@ -61,7 +61,6 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
dim = args.hidden_size
|
dim = args.hidden_size
|
||||||
self.n_heads = n_heads = args.num_attention_heads
|
self.n_heads = n_heads = args.num_attention_heads
|
||||||
assert args.num_key_value_heads is not None
|
|
||||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||||
self.n_q_per_kv = n_heads // n_kv_heads
|
self.n_q_per_kv = n_heads // n_kv_heads
|
||||||
|
|
||||||
@ -161,7 +160,7 @@ class Attention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[KVCache] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
|
|
||||||
@ -230,7 +229,7 @@ class TransformerBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[KVCache] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
@ -304,16 +303,8 @@ class Model(nn.Module):
|
|||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return self.args.hidden_size // self.args.num_attention_heads
|
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
# Remove unused precomputed rotary freqs
|
# Remove unused precomputed rotary freqs
|
||||||
return {
|
return {
|
||||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_key_value_heads
|
|
||||||
|
@ -173,6 +173,7 @@ class PhiMoEModel(nn.Module):
|
|||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
def __init__(self, args: ModelArgs):
|
def __init__(self, args: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.model_type = args.model_type
|
||||||
self.args = args
|
self.args = args
|
||||||
self.model = PhiMoEModel(args)
|
self.model = PhiMoEModel(args)
|
||||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True)
|
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True)
|
||||||
@ -208,11 +209,3 @@ class Model(nn.Module):
|
|||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return self.args.hidden_size // self.args.num_attention_heads
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_key_value_heads
|
|
||||||
|
@ -168,8 +168,8 @@ class Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: mx.array = None,
|
mask: mx.array = None,
|
||||||
cache: mx.array = None,
|
cache=None,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> mx.array:
|
||||||
mask = create_attention_mask(x, cache)
|
mask = create_attention_mask(x, cache)
|
||||||
|
|
||||||
y = self.transformer(x, mask, cache)
|
y = self.transformer(x, mask, cache)
|
||||||
@ -193,11 +193,3 @@ class Model(nn.Module):
|
|||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.transformer.h
|
return self.transformer.h
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return self.args.model_dim // self.args.num_heads
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_heads
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, List, Optional, Tuple, Union
|
from typing import Any, Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -62,8 +62,8 @@ class Attention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: mx.array,
|
hidden_states: mx.array,
|
||||||
attention_mask: Optional[mx.array] = None,
|
attention_mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
|
) -> mx.array:
|
||||||
bsz, q_len, _ = hidden_states.shape
|
bsz, q_len, _ = hidden_states.shape
|
||||||
|
|
||||||
queries = self.q_proj(hidden_states)
|
queries = self.q_proj(hidden_states)
|
||||||
@ -127,8 +127,8 @@ class PlamoDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: mx.array,
|
hidden_states: mx.array,
|
||||||
attention_mask: Optional[mx.array] = None,
|
attention_mask: Optional[mx.array] = None,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> Tuple[Any, ...]:
|
):
|
||||||
# from LlamaDecoder
|
# from LlamaDecoder
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
@ -169,8 +169,8 @@ class PlamoModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
|
) -> mx.array:
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
@ -197,19 +197,11 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> mx.array:
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, cache)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers.layers
|
return self.model.layers.layers
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return self.args.hidden_size // self.args.num_attention_heads
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_attention_heads // self.args.n_shared_head
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -149,19 +148,11 @@ class Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: mx.array = None,
|
mask: mx.array = None,
|
||||||
cache: mx.array = None,
|
cache=None,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> mx.array:
|
||||||
y = self.transformer(x, mask, cache)
|
y = self.transformer(x, mask, cache)
|
||||||
return self.lm_head(y)
|
return self.lm_head(y)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.transformer.h
|
return self.transformer.h
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return self.args.hidden_size // self.args.num_attention_heads
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_attention_heads
|
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -70,7 +70,7 @@ class Attention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[KVCache] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
|
|
||||||
@ -124,7 +124,7 @@ class TransformerBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[KVCache] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
@ -196,11 +196,3 @@ class Model(nn.Module):
|
|||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return self.args.hidden_size // self.args.num_attention_heads
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_key_value_heads
|
|
||||||
|
@ -2,12 +2,12 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask
|
||||||
from .switch_layers import SwitchGLU
|
from .switch_layers import SwitchGLU
|
||||||
|
|
||||||
|
|
||||||
@ -70,7 +70,7 @@ class Attention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[KVCache] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
|
|
||||||
@ -162,7 +162,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[KVCache] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
@ -236,11 +236,3 @@ class Model(nn.Module):
|
|||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return self.args.hidden_size // self.args.num_attention_heads
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_key_value_heads
|
|
||||||
|
@@ -7,13 +7,13 @@ from typing import List, Literal, Optional
 import mlx.core as mx
 import mlx.nn as nn

-from .base import BaseModelArgs
+from .base import BaseModelArgs, create_attention_mask
+from .cache import MambaCache, RotatingKVCache


 @dataclass
 class ModelArgs(BaseModelArgs):
     model_type: str
-    hidden_size: int
     attention_bias: bool
     conv1d_width: int
     hidden_size: int
@@ -36,59 +36,6 @@ class ModelArgs(BaseModelArgs):
         self.block_types = self._block_types


-def create_window_causal_mask(N: int, window_size: int):
-    inds = mx.arange(N)
-    linds = inds[:, None]
-    rinds = inds[None]
-    mask = (linds < rinds) | (linds > rinds + window_size)
-    return mask * -1e9
-
-
-class RecurrentCache:
-
-    def __init__(self):
-        self._cache = (None, None)
-
-    def __getitem__(self, idx):
-        return self._cache[idx]
-
-    def update(self, conv_state, recurrent_state):
-        self._cache = (conv_state, recurrent_state)
-
-    def state(self):
-        return self._cache
-
-
-class WindowKVCache:
-
-    def __init__(self, window_size):
-        self.keys = None
-        self.values = None
-        self.offset = 0
-        self.window_size = window_size
-
-    def update_and_fetch(self, keys, values):
-        # TODO consider using rotating buffer here
-        # especially for very long generations
-        def _update(x, v):
-            t = x.shape[2] - self.window_size
-            if t > 0:
-                x = x[..., t:, :]
-            return mx.concatenate([x, v], axis=2)
-
-        self.offset += keys.shape[2]
-        if self.keys is None:
-            self.keys = keys
-            self.values = values
-        else:
-            self.keys = _update(self.keys, keys)
-            self.values = _update(self.values, values)
-        return self.keys, self.values
-
-    def state(self):
-        return self.keys, self.values
-
-
 class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
@@ -136,31 +83,22 @@ class Conv1d(nn.Module):
         kernel_size: int,
     ):
         super().__init__()
-        self.weight = mx.zeros((kernel_size, channels))
+        self.weight = mx.zeros((channels, kernel_size, 1))
         self.bias = mx.zeros((channels,))

     def __call__(self, x, cache=None):
-        w = self.weight.T[..., None]
-        kw, groups = self.weight.shape
-        if cache is not None:
-            l = []
-            # Pad the cache if needed
-            if cache.shape[1] < kw - 1:
-                l.append(
-                    mx.zeros(
-                        (x.shape[0], kw - 1 - cache.shape[1], groups), dtype=x.dtype
-                    )
-                )
-            l.extend([cache, x])
-            x = mx.concatenate(l, axis=1)
-            y = (x * w.swapaxes(0, 2)).sum(axis=1, keepdims=True)
-        else:
-            y = mx.conv_general(x, w, padding=([kw - 1], [0]), groups=groups)
-
-        # The cache is always kw - 1
-        cache = x[:, max(x.shape[1] - kw + 1, 0) :, :]
-
-        y = y + self.bias
-        return y, cache
+        B, L, C = x.shape
+        groups, K, _ = self.weight.shape
+
+        if cache is not None:
+            x = mx.concatenate([cache, x], axis=1)
+        else:
+            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
+
+        y = mx.conv_general(x, self.weight, groups=groups)
+        y = y + self.bias
+
+        return y, x[:, -K + 1 :, :]


 class RGLRU(nn.Module):
@@ -269,19 +207,9 @@ class RecurrentBlock(nn.Module):
         # x branch.
         x = self.linear_x(x)
         if cache is None:
-            conv_state, recurrent_state = (None, None)
-        else:
-            conv_state, recurrent_state = cache[0], cache[1]
-        x, conv_state = self.conv_1d(
-            x=x,
-            cache=conv_state,
-        )
-        x, recurrent_state = self.rg_lru(
-            x=x,
-            cache=recurrent_state,
-        )
-        if cache is not None:
-            cache.update(conv_state, recurrent_state)
+            cache = [None, None]
+        x, cache[0] = self.conv_1d(x=x, cache=cache[0])
+        x, cache[1] = self.rg_lru(x=x, cache=cache[1])

         x = x * y
         x = self.linear_out(x)
@@ -467,12 +395,14 @@ class Griffin(nn.Module):
         if self.scale_by_sqrt_dim:
             x = x * math.sqrt(x.shape[-1])

-        mask = None
-        if x.shape[1] > 1:
-            mask = create_window_causal_mask(
-                x.shape[1], self.config.attention_window_size
-            )
-            mask = mask.astype(x.dtype)
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        for i, block in enumerate(self.layers):
+            if block.temporal_block_type != "recurrent":
+                mask_cache = [cache[i]]
+
+        mask = create_attention_mask(x, mask_cache)

         for i, block in enumerate(self.layers):
             x = block(x, mask=mask, cache=cache[i])
@@ -485,6 +415,7 @@ class Model(nn.Module):
     def __init__(self, config):
         self.args = config
         self.model = Griffin(config)
+        self.model_type = config.model_type
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

     def __call__(self, tokens: mx.array, cache=None) -> mx.array:
@@ -508,10 +439,9 @@ class Model(nn.Module):
         return self.model.layers

     def sanitize(self, weights):
-        # Remove unused precomputed rotary freqs
         for k, v in weights.items():
             if "conv_1d.weight" in k and v.ndim == 3:
-                weights[k] = v.squeeze(1).T
+                weights[k] = v.moveaxis(2, 1)
         if "lm_head.weight" not in weights:
             self.pop("lm_head")
         return weights
@@ -520,7 +450,7 @@ class Model(nn.Module):
         cache = []
         for layer in self.layers:
             if layer.temporal_block_type == "recurrent":
-                cache.append(RecurrentCache())
+                cache.append(MambaCache())
             else:
-                cache.append(WindowKVCache(self.args.attention_window_size))
+                cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
         return cache
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -198,8 +197,8 @@ class Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: mx.array = None,
|
mask: mx.array = None,
|
||||||
cache: mx.array = None,
|
cache=None,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> mx.array:
|
||||||
mask = create_attention_mask(x, cache)
|
mask = create_attention_mask(x, cache)
|
||||||
y = self.model(x, mask, cache)
|
y = self.model(x, mask, cache)
|
||||||
return self.lm_head(y)
|
return self.lm_head(y)
|
||||||
@ -207,11 +206,3 @@ class Model(nn.Module):
|
|||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return self.args.hidden_size // self.args.num_attention_heads
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_key_value_heads
|
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple
|
from typing import Any, Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -45,7 +45,7 @@ class Attention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[KVCache] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
B, L, D = x.shape
|
B, L, D = x.shape
|
||||||
|
|
||||||
@ -100,7 +100,7 @@ class TransformerBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: Optional[mx.array] = None,
|
mask: Optional[mx.array] = None,
|
||||||
cache: Optional[KVCache] = None,
|
cache: Optional[Any] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
@ -164,11 +164,3 @@ class Model(nn.Module):
|
|||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers
|
||||||
|
|
||||||
@property
|
|
||||||
def head_dim(self):
|
|
||||||
return self.args.hidden_size // self.args.num_attention_heads
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n_kv_heads(self):
|
|
||||||
return self.args.num_key_value_heads
|
|
||||||
|
@@ -18,7 +18,7 @@ from mlx.utils import tree_flatten
 from transformers import PreTrainedTokenizer

 # Local imports
-from .models.base import KVCache, RotatingKVCache
+from .models import base, cache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
@@ -124,26 +124,6 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
     return logits


-def make_kv_caches(
-    model: nn.Module, max_kv_size: Optional[int] = None
-) -> List[Union[KVCache, RotatingKVCache]]:
-    if hasattr(model, "make_cache"):
-        return model.make_cache()
-
-    kv_heads = (
-        [model.n_kv_heads] * len(model.layers)
-        if isinstance(model.n_kv_heads, int)
-        else model.n_kv_heads
-    )
-    if max_kv_size is not None:
-        return [
-            RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
-            for n in kv_heads
-        ]
-    else:
-        return [KVCache(model.head_dim, n) for n in kv_heads]
-
-
 def generate_step(
     prompt: mx.array,
     model: nn.Module,
@@ -155,7 +135,7 @@ def generate_step(
     min_tokens_to_keep: int = 1,
     prefill_step_size: int = 512,
     max_kv_size: Optional[int] = None,
-    cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
+    prompt_cache: Optional[Any] = None,
     logit_bias: Optional[Dict[int, float]] = None,
     logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
@@ -180,6 +160,8 @@ def generate_step(
        prefill_step_size (int): Step size for processing the prompt.
        max_kv_size (int, optional): Maximum size of the key-value cache. Old
          entries (except the first 4 tokens) will be overwritten.
+       prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
+         provided, the cache will be updated in place.
        logit_bias (dictionary, optional): Additive logit bias.
        logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
          A list of functions that take tokens and logits and return the processed
@@ -237,20 +219,13 @@ def generate_step(
     tokens = None

     # Create the KV cache for generation
-    cache = make_kv_caches(model, max_kv_size)
-
-    if cache_history is not None:
-        if len(cache_history) != len(cache):
-            raise ValueError("Wrong number of layers in the cache history")
-
-        # Set the history in the cache objects and evaluate them to prepare for
-        # generation.
-        for c, h in zip(cache, cache_history):
-            c.update_and_fetch(h[0], h[1])
-        mx.eval([c.state for c in cache])
+    if prompt_cache is None:
+        prompt_cache = cache.make_prompt_cache(model, max_kv_size)
+    elif len(prompt_cache) != len(model.layers):
+        raise ValueError("Wrong number of layers in the prompt cache.")

     def _step(y):
-        logits = model(y[None], cache=cache)
+        logits = model(y[None], cache=prompt_cache)
         logits = logits[:, -1, :]

         if logits_processor:
@@ -305,9 +280,9 @@ def stream_generate(
     detokenizer = tokenizer.detokenizer

     detokenizer.reset()
-    for (token, _), n in zip(
-        generate_step(prompt_tokens, model, **kwargs),
+    for n, (token, _) in zip(
         range(max_tokens),
+        generate_step(prompt_tokens, model, **kwargs),
     ):
         if token == tokenizer.eos_token_id:
             break
@@ -357,9 +332,9 @@ def generate(
     tic = time.perf_counter()
     detokenizer.reset()

-    for (token, logprobs), n in zip(
-        generate_step(prompt_tokens, model, **kwargs),
+    for n, (token, logprobs) in zip(
         range(max_tokens),
+        generate_step(prompt_tokens, model, **kwargs),
     ):
         if n == 0:
             prompt_time = time.perf_counter() - tic
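The hunks above replace `cache_history` with a reusable `prompt_cache` argument on `generate_step` and move cache construction into `cache.make_prompt_cache`. A minimal usage sketch of that API, assuming the `load` helper and the import paths shown elsewhere in this diff (the model id is an arbitrary placeholder, not part of the commit):

```python
import mlx.core as mx

from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.utils import generate_step

model, tokenizer = load("<path-or-hub-id-of-a-model>")

# Build the per-layer cache once; generate_step updates it in place,
# so a follow-up turn only needs to prefill the newly added tokens.
prompt_cache = make_prompt_cache(model)

prompt = mx.array(tokenizer.encode("Hello!"))
for _, (token, _) in zip(range(16), generate_step(prompt, model, prompt_cache=prompt_cache)):
    pass  # consume up to 16 sampled tokens; the cache now holds prompt + generation

follow_up = mx.array(tokenizer.encode(" Tell me more."))
for _, (token, _) in zip(range(16), generate_step(follow_up, model, prompt_cache=prompt_cache)):
    pass
```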
@@ -32,6 +32,7 @@ setup(
     entry_points={
         "console_scripts": [
             "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
+            "mlx_lm.chat = mlx_lm.chat:main",
             "mlx_lm.convert = mlx_lm.convert:main",
             "mlx_lm.fuse = mlx_lm.fuse:main",
             "mlx_lm.generate = mlx_lm.generate:main",
@ -1,17 +1,15 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.utils import tree_map
|
from mlx.utils import tree_map
|
||||||
from mlx_lm.models.base import KVCache, RotatingKVCache
|
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
||||||
from mlx_lm.utils import make_kv_caches
|
|
||||||
|
|
||||||
|
|
||||||
class TestModels(unittest.TestCase):
|
class TestModels(unittest.TestCase):
|
||||||
|
|
||||||
def test_kv_cache(self):
|
def test_kv_cache(self):
|
||||||
cache = KVCache(32, 4)
|
cache = KVCache()
|
||||||
|
|
||||||
k = mx.ones((1, 4, 1, 32), mx.float16)
|
k = mx.ones((1, 4, 1, 32), mx.float16)
|
||||||
v = mx.ones((1, 4, 1, 32), mx.float16)
|
v = mx.ones((1, 4, 1, 32), mx.float16)
|
||||||
@ -32,7 +30,7 @@ class TestModels(unittest.TestCase):
|
|||||||
|
|
||||||
def test_rotating_kv_cache(self):
|
def test_rotating_kv_cache(self):
|
||||||
b, h, d = 1, 2, 32
|
b, h, d = 1, 2, 32
|
||||||
cache = RotatingKVCache(d, h, max_size=8, step=4)
|
cache = RotatingKVCache(max_size=8, step=4)
|
||||||
|
|
||||||
k = mx.random.uniform(shape=(b, h, 2, d))
|
k = mx.random.uniform(shape=(b, h, 2, d))
|
||||||
v = mx.random.uniform(shape=(b, h, 2, d))
|
v = mx.random.uniform(shape=(b, h, 2, d))
|
||||||
@ -65,7 +63,7 @@ class TestModels(unittest.TestCase):
|
|||||||
idx %= 8
|
idx %= 8
|
||||||
|
|
||||||
# Try with nonzero keep
|
# Try with nonzero keep
|
||||||
cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2)
|
cache = RotatingKVCache(max_size=8, step=4, keep=2)
|
||||||
|
|
||||||
# Check a large update
|
# Check a large update
|
||||||
k = mx.random.uniform(shape=(b, h, 20, d))
|
k = mx.random.uniform(shape=(b, h, 20, d))
|
||||||
@ -88,6 +86,46 @@ class TestModels(unittest.TestCase):
|
|||||||
if idx >= 8:
|
if idx >= 8:
|
||||||
idx = 2
|
idx = 2
|
||||||
|
|
||||||
|
def test_rotating_kv_cache_chat_mode(self):
|
||||||
|
# Test that the rotating kv cache can handle
|
||||||
|
# alternating prompt/prefill with generation
|
||||||
|
d = 4
|
||||||
|
h = 2
|
||||||
|
cache = RotatingKVCache(max_size=18, step=4)
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(1, h, 8, d))
|
||||||
|
k, v = cache.update_and_fetch(x, x)
|
||||||
|
self.assertEqual(k.shape[2], 8)
|
||||||
|
self.assertEqual(cache.offset, 8)
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(1, h, 1, d))
|
||||||
|
k, v = cache.update_and_fetch(x, x)
|
||||||
|
self.assertEqual(k.shape[2], 9)
|
||||||
|
self.assertEqual(cache.offset, 9)
|
||||||
|
self.assertTrue(mx.allclose(x, k[..., 8:9, :]))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(1, h, 2, d))
|
||||||
|
k, v = cache.update_and_fetch(x, x)
|
||||||
|
self.assertEqual(k.shape[2], 11)
|
||||||
|
self.assertEqual(cache.offset, 11)
|
||||||
|
self.assertTrue(mx.allclose(x, k[..., 9:11, :]))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(1, h, 3, d))
|
||||||
|
k, v = cache.update_and_fetch(x, x)
|
||||||
|
self.assertEqual(k.shape[2], 14)
|
||||||
|
self.assertEqual(cache.offset, 14)
|
||||||
|
self.assertTrue(mx.allclose(x, k[..., 11:14, :]))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(1, h, 6, d))
|
||||||
|
k, v = cache.update_and_fetch(x, x)
|
||||||
|
self.assertEqual(cache.offset, 20)
|
||||||
|
self.assertTrue(mx.allclose(x, k[..., -6:, :]))
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(1, h, 2, d))
|
||||||
|
k, v = cache.update_and_fetch(x, x)
|
||||||
|
self.assertEqual(cache.offset, 22)
|
||||||
|
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
|
||||||
|
|
||||||
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
||||||
|
|
||||||
self.assertEqual(len(model.layers), num_layers)
|
self.assertEqual(len(model.layers), num_layers)
|
||||||
@ -101,7 +139,7 @@ class TestModels(unittest.TestCase):
|
|||||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||||
self.assertEqual(outputs.dtype, t)
|
self.assertEqual(outputs.dtype, t)
|
||||||
|
|
||||||
cache = make_kv_caches(model)
|
cache = make_prompt_cache(model)
|
||||||
outputs = model(inputs, cache)
|
outputs = model(inputs, cache)
|
||||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||||
self.assertEqual(outputs.dtype, t)
|
self.assertEqual(outputs.dtype, t)
|
||||||

@@ -549,6 +587,179 @@ class TestModels(unittest.TestCase):
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_deepseek(self):
        from mlx_lm.models import deepseek

        args = deepseek.ModelArgs(
            model_type="deepseek",
            vocab_size=1024,
            hidden_size=128,
            intermediate_size=256,
            moe_intermediate_size=256,
            num_hidden_layers=4,
            num_attention_heads=8,
            num_key_value_heads=4,
        )
        model = deepseek.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_deepseek_v2(self):
        from mlx_lm.models import deepseek_v2

        args = deepseek_v2.ModelArgs(
            model_type="deepseek_v2",
            vocab_size=1024,
            hidden_size=128,
            intermediate_size=256,
            moe_intermediate_size=256,
            num_hidden_layers=4,
            num_attention_heads=4,
            num_key_value_heads=2,
            kv_lora_rank=4,
            q_lora_rank=4,
            qk_rope_head_dim=32,
            v_head_dim=16,
            qk_nope_head_dim=32,
            rope_scaling={
                "beta_fast": 32,
                "beta_slow": 1,
                "factor": 40,
                "mscale": 1.0,
                "mscale_all_dim": 1.0,
                "original_max_position_embeddings": 4096,
                "type": "yarn",
            },
        )
        model = deepseek_v2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gemma2(self):
        from mlx_lm.models import gemma2

        args = gemma2.ModelArgs(
            model_type="gemma2",
            hidden_size=128,
            num_hidden_layers=4,
            intermediate_size=256,
            num_attention_heads=2,
            head_dim=32,
            rms_norm_eps=1e-4,
            vocab_size=1024,
            num_key_value_heads=2,
        )
        model = gemma2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gpt_bigcode(self):
        from mlx_lm.models import gpt_bigcode

        args = gpt_bigcode.ModelArgs(
            model_type="gpt_bigcode",
            n_embd=128,
            n_layer=128,
            n_inner=256,
            n_head=4,
            n_positions=1000,
            layer_norm_epsilon=1e-5,
            vocab_size=1024,
        )
        model = gpt_bigcode.Model(args)
        self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)

    def test_nemotron(self):
        from mlx_lm.models import nemotron

        args = nemotron.ModelArgs(
            model_type="nemotron",
            hidden_size=128,
            hidden_act="gelu",
            num_hidden_layers=4,
            intermediate_size=256,
            num_attention_heads=4,
            norm_eps=1e-5,
            vocab_size=1024,
            num_key_value_heads=2,
        )
        model = nemotron.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phi3small(self):
        from mlx_lm.models import phi3small

        args = phi3small.ModelArgs(
            model_type="phi3small",
            hidden_size=128,
            dense_attention_every_n_layers=2,
            ff_intermediate_size=256,
            gegelu_limit=1.0,
            num_hidden_layers=4,
            num_attention_heads=4,
            num_key_value_heads=2,
            layer_norm_epsilon=1e-4,
            vocab_size=1000,
        )
        model = phi3small.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phimoe(self):
        from mlx_lm.models import phimoe

        args = phimoe.ModelArgs(
            model_type="phimoe",
            vocab_size=320,
            hidden_size=128,
            intermediate_size=256,
            num_hidden_layers=4,
            num_attention_heads=4,
            num_key_value_heads=4,
            rope_scaling={
                "long_factor": [1.0] * 16,
                "long_mscale": 1.243163121016122,
                "original_max_position_embeddings": 4096,
                "short_factor": [1.0] * 16,
                "short_mscale": 1.243163121016122,
                "type": "longrope",
            },
        )
        model = phimoe.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_recurrent_gemma(self):
        from mlx_lm.models import recurrent_gemma

        args = recurrent_gemma.ModelArgs(
            model_type="recurrent_gemma",
            hidden_size=128,
            attention_bias=False,
            conv1d_width=3,
            intermediate_size=256,
            logits_soft_cap=1.0,
            num_attention_heads=4,
            num_hidden_layers=4,
            num_key_value_heads=2,
            rms_norm_eps=1e-4,
            rope_theta=1000,
            attention_window_size=1024,
            vocab_size=1000,
            block_types=["recurrent", "recurrent", "attention"],
        )
        model = recurrent_gemma.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )


if __name__ == "__main__":
    unittest.main()

llms/tests/test_prompt_cache.py (new file, 220 lines)
@@ -0,0 +1,220 @@
# Copyright © 2024 Apple Inc.

import os
import tempfile
import unittest

import mlx.core as mx
from mlx_lm.models.cache import (
    KVCache,
    MambaCache,
    RotatingKVCache,
    load_prompt_cache,
    make_prompt_cache,
    save_prompt_cache,
    trim_prompt_cache,
)
from mlx_lm.utils import generate_step, load

HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"


class TestPromptCache(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name

    @classmethod
    def tearDownClass(cls):
        cls.test_dir_fid.cleanup()

    def test_save_load(self):
        cache = [KVCache() for _ in range(4)]
        for c in cache:
            x = mx.random.uniform(shape=(1, 8, 10, 4))
            c.update_and_fetch(x, x)
        cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
        save_prompt_cache(cache_file, cache)
        loaded_cache = load_prompt_cache(cache_file)
        self.assertEqual(len(cache), len(loaded_cache))
        for c, lc in zip(cache, loaded_cache):
            self.assertEqual(c.offset, lc.offset)
            self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
            self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))

        # Test with metadata
        cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
        metadata = {"a": "b", "c": "d"}
        save_prompt_cache(cache_file, cache, metadata)
        _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True)
        self.assertEqual(metadata, loaded_metadata)
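
    # save_prompt_cache / load_prompt_cache round-trip the per-layer cache
    # state, plus optional string metadata, through a safetensors file.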

    def test_save_load_rotating_cache(self):
        cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")

        # Test with rotating cache
        cache = [RotatingKVCache(max_size=8, keep=2) for _ in range(4)]
        for c in cache:
            x = mx.random.uniform(shape=(1, 8, 10, 4))
            c.update_and_fetch(x, x)

        save_prompt_cache(cache_file, cache)
        loaded_cache = load_prompt_cache(cache_file)
        self.assertEqual(len(cache), len(loaded_cache))
        for c, lc in zip(cache, loaded_cache):
            self.assertEqual(c.offset, lc.offset)
            self.assertEqual(c.keep, lc.keep)
            self.assertEqual(c.max_size, lc.max_size)
            self.assertEqual(c.step, lc.step)
            self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
            self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))

        # Do a couple single token updates to get a rotation
        for _ in range(2):
            for c in cache:
                x = mx.random.uniform(shape=(1, 8, 1, 4))
                c.update_and_fetch(x, x)

        save_prompt_cache(cache_file, cache)
        loaded_cache = load_prompt_cache(cache_file)

        for c, lc in zip(cache, loaded_cache):
            x = mx.random.uniform(shape=(1, 8, 1, 4))
            k, v = c.update_and_fetch(x, x)
            lk, lv = lc.update_and_fetch(x, x)
            self.assertEqual(c.offset, lc.offset)
            self.assertTrue(mx.array_equal(k, lk))
            self.assertTrue(mx.array_equal(v, lv))

    def test_save_load_mixed_cache(self):
        cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")

        cache = [MambaCache(), KVCache(), RotatingKVCache(8), MambaCache()]
        for c in cache:
            if isinstance(c, MambaCache):
                c[0] = mx.random.uniform(shape=(4, 4, 4))
                c[1] = mx.random.uniform(shape=(4, 4, 4))
            else:
                x = mx.random.uniform(shape=(4, 4, 7, 4))
                y = mx.random.uniform(shape=(4, 4, 7, 4))
                c.update_and_fetch(x, y)

        save_prompt_cache(cache_file, cache)
        loaded_cache = load_prompt_cache(cache_file)
        for c, lc in zip(cache, loaded_cache):
            if isinstance(c, MambaCache):
                self.assertTrue(mx.array_equal(c[0], lc[0]))
                self.assertTrue(mx.array_equal(c[1], lc[1]))
            else:
                x = mx.random.uniform(shape=(4, 4, 1, 4))
                y = mx.random.uniform(shape=(4, 4, 1, 4))
                k, v = c.update_and_fetch(x, y)
                lk, lv = lc.update_and_fetch(x, y)
                self.assertEqual(c.offset, lc.offset)
                self.assertTrue(mx.array_equal(k, lk))
                self.assertTrue(mx.array_equal(v, lv))

    def test_cache_with_generate(self):
        model, tokenizer = load(HF_MODEL_PATH)
        prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
        results = zip(range(4), generate_step(prompt, model))
        toks, all_logits = zip(*(r[1] for r in results))

        prompt_cache = make_prompt_cache(model)
        i = 0
        for _, (tok, logits) in zip(
            range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
        ):
            self.assertEqual(tok, toks[i])
            self.assertTrue(mx.allclose(logits, all_logits[i]))
            i += 1

        for _, (tok, logits) in zip(
            range(1),
            generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
        ):
            i += 1
            self.assertEqual(tok, toks[i])
            self.assertTrue(mx.allclose(logits, all_logits[i]))
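
    # With a prompt_cache, generate_step reproduces the uncached generation
    # token for token, and the same cache can be reused to continue generating
    # from where the previous call stopped.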

    def test_trim_cache(self):
        cache = [KVCache() for _ in range(2)]
        for c in cache:
            x = mx.random.uniform(shape=(1, 8, 10, 4))
            c.update_and_fetch(x, x)

        # Trim
        num_trimmed = trim_prompt_cache(cache, 7)
        self.assertEqual(num_trimmed, 7)

        # Trim more tokens than remain
        num_trimmed = trim_prompt_cache(cache, 4)
        self.assertEqual(num_trimmed, 3)

        # Can't trim mamba cache
        cache = [MambaCache() for _ in range(2)]
        for c in cache:
            c.state = mx.zeros((5, 5))
        num_trimmed = trim_prompt_cache(cache, 7)
        self.assertEqual(num_trimmed, 0)

        # All caches have to be trimmable
        cache = [MambaCache(), KVCache()]
        cache[0].state = mx.zeros((5, 5))
        x = mx.random.uniform(shape=(1, 8, 10, 4))
        cache[1].update_and_fetch(x, x)
        num_trimmed = trim_prompt_cache(cache, 1)
        self.assertEqual(num_trimmed, 0)

        cache = [RotatingKVCache(max_size=6) for _ in range(2)]
        for c in cache:
            x = mx.random.uniform(shape=(1, 8, 5, 4))
            c.update_and_fetch(x, x)

        num_trimmed = trim_prompt_cache(cache, 4)
        self.assertEqual(num_trimmed, 4)

        # Can't trim fixed-size KV cache after processing
        # more than max_kv_size tokens
        for c in cache:
            x = mx.random.uniform(shape=(1, 8, 10, 4))
            c.update_and_fetch(x, x)

        num_trimmed = trim_prompt_cache(cache, 4)
        self.assertEqual(num_trimmed, 0)
def test_trim_cache_with_generate(self):
|
||||||
|
model, tokenizer = load(HF_MODEL_PATH)
|
||||||
|
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
|
||||||
|
|
||||||
|
prompt_cache = make_prompt_cache(model)
|
||||||
|
|
||||||
|
# Generate one token so we process the full prompt
|
||||||
|
last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache))
|
||||||
|
last_tok = mx.array([last_tok])
|
||||||
|
|
||||||
|
# Generate two more tokens
|
||||||
|
results = zip(
|
||||||
|
range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
|
||||||
|
)
|
||||||
|
toks, all_logits = zip(*(r[1] for r in results))
|
||||||
|
|
||||||
|
# To get back to the cache just after processing the prompt,
|
||||||
|
# trim by 3 tokens
|
||||||
|
trim_prompt_cache(prompt_cache, 3)
|
||||||
|
|
||||||
|
# Generate the same thing again
|
||||||
|
results = zip(
|
||||||
|
range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
|
||||||
|
)
|
||||||
|
second_toks, second_all_logits = zip(*(r[1] for r in results))
|
||||||
|
self.assertEqual(toks, second_toks)
|
||||||
|
self.assertTrue(
|
||||||
|
all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
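
Taken together, the tests above exercise the full prompt-cache workflow. The
following is a minimal usage sketch, not part of the diff, assuming only the
APIs used in these tests (`load`, `make_prompt_cache`, `generate_step`,
`save_prompt_cache`, `load_prompt_cache`); the file name and metadata values
are illustrative:

```
# Minimal sketch: fill a prompt cache, persist it, and resume generation from
# the saved state. Assumes the APIs behave as exercised in the tests above.
import mlx.core as mx
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
from mlx_lm.utils import generate_step, load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]

# Process the prompt and a few tokens, filling the per-layer cache.
prompt_cache = make_prompt_cache(model)
for _, (token, _logits) in zip(
    range(4), generate_step(prompt, model, prompt_cache=prompt_cache)
):
    pass

# Persist the cache (with optional string metadata) and reload it later.
save_prompt_cache("prompt_cache.safetensors", prompt_cache, {"note": "example"})
prompt_cache = load_prompt_cache("prompt_cache.safetensors")

# Continue generating from the cached state using only the last token.
next_token, _ = next(
    generate_step(mx.array([token]), model, prompt_cache=prompt_cache)
)
```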