mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 10:41:18 +08:00
Merge branch 'ml-explore:main' into adding-support-for-mamba2
This commit is contained in:
commit
49d3f188f8
@ -188,7 +188,7 @@ The adapters are saved in `mlx_output` and can be used directly by the
|
|||||||
|
|
||||||
```shell
|
```shell
|
||||||
python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
|
python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
|
||||||
--adapter mlx_output/0001200_adapters.safetensors \
|
--adapter mlx_output/final_adapters.safetensors \
|
||||||
--fuse-adapter \
|
--fuse-adapter \
|
||||||
--no-t5-padding \
|
--no-t5-padding \
|
||||||
'A photo of an sks dog lying on the sand at a beach in Greece'
|
'A photo of an sks dog lying on the sand at a beach in Greece'
|
||||||
|
@ -13,7 +13,7 @@ from mlx.nn.utils import average_gradients
|
|||||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from flux import FluxPipeline, Trainer, load_dataset
|
from flux import FluxPipeline, Trainer, load_dataset, save_config
|
||||||
|
|
||||||
|
|
||||||
def generate_progress_images(iteration, flux, args):
|
def generate_progress_images(iteration, flux, args):
|
||||||
@ -43,10 +43,10 @@ def generate_progress_images(iteration, flux, args):
|
|||||||
im.save(out_file)
|
im.save(out_file)
|
||||||
|
|
||||||
|
|
||||||
def save_adapters(iteration, flux, args):
|
def save_adapters(adapter_name, flux, args):
|
||||||
out_dir = Path(args.output_dir)
|
out_dir = Path(args.output_dir)
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
out_file = out_dir / f"{iteration:07d}_adapters.safetensors"
|
out_file = out_dir / adapter_name
|
||||||
print(f"Saving {str(out_file)}")
|
print(f"Saving {str(out_file)}")
|
||||||
|
|
||||||
mx.save_safetensors(
|
mx.save_safetensors(
|
||||||
@ -157,6 +157,10 @@ if __name__ == "__main__":
|
|||||||
parser = setup_arg_parser()
|
parser = setup_arg_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
output_path = Path(args.output_dir)
|
||||||
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
save_config(vars(args), output_path / "adapter_config.json")
|
||||||
|
|
||||||
# Load the model and set it up for LoRA training. We use the same random
|
# Load the model and set it up for LoRA training. We use the same random
|
||||||
# state when creating the LoRA layers so all workers will have the same
|
# state when creating the LoRA layers so all workers will have the same
|
||||||
# initial weights.
|
# initial weights.
|
||||||
@ -278,8 +282,11 @@ if __name__ == "__main__":
|
|||||||
generate_progress_images(i + 1, flux, args)
|
generate_progress_images(i + 1, flux, args)
|
||||||
|
|
||||||
if (i + 1) % args.checkpoint_every == 0:
|
if (i + 1) % args.checkpoint_every == 0:
|
||||||
save_adapters(i + 1, flux, args)
|
save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
|
||||||
|
|
||||||
if (i + 1) % 10 == 0:
|
if (i + 1) % 10 == 0:
|
||||||
losses = []
|
losses = []
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
|
|
||||||
|
save_adapters("final_adapters.safetensors", flux, args)
|
||||||
|
print(f"Training successful. Saved final weights to {args.adapter_file}.")
|
||||||
|
@ -12,4 +12,5 @@ from .utils import (
|
|||||||
load_flow_model,
|
load_flow_model,
|
||||||
load_t5,
|
load_t5,
|
||||||
load_t5_tokenizer,
|
load_t5_tokenizer,
|
||||||
|
save_config,
|
||||||
)
|
)
|
||||||
|
@ -3,7 +3,8 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
@ -207,3 +208,23 @@ def load_clip_tokenizer(name: str):
|
|||||||
def load_t5_tokenizer(name: str, pad: bool = True):
|
def load_t5_tokenizer(name: str, pad: bool = True):
|
||||||
model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
|
model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
|
||||||
return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
|
return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
|
||||||
|
|
||||||
|
|
||||||
|
def save_config(
|
||||||
|
config: dict,
|
||||||
|
config_path: Union[str, Path],
|
||||||
|
) -> None:
|
||||||
|
"""Save the model configuration to the ``config_path``.
|
||||||
|
|
||||||
|
The final configuration will be sorted before saving for better readability.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (dict): The model configuration.
|
||||||
|
config_path (Union[str, Path]): Model configuration file path.
|
||||||
|
"""
|
||||||
|
# Sort the config for better readability
|
||||||
|
config = dict(sorted(config.items()))
|
||||||
|
|
||||||
|
# Write the config to the provided file
|
||||||
|
with open(config_path, "w") as fid:
|
||||||
|
json.dump(config, fid, indent=4)
|
||||||
|
@ -101,7 +101,8 @@ To see a description of all the arguments you can do:
|
|||||||
#### Streaming
|
#### Streaming
|
||||||
|
|
||||||
For streaming generation, use the `stream_generate` function. This returns a
|
For streaming generation, use the `stream_generate` function. This returns a
|
||||||
generator object which streams the output text. For example,
|
generator object which streams the output text, token, and log probabilities.
|
||||||
|
For example,
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mlx_lm import load, stream_generate
|
from mlx_lm import load, stream_generate
|
||||||
@ -116,7 +117,7 @@ prompt = tokenizer.apply_chat_template(
|
|||||||
messages, tokenize=False, add_generation_prompt=True
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
)
|
)
|
||||||
|
|
||||||
for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
||||||
print(t, end="", flush=True)
|
print(t, end="", flush=True)
|
||||||
print()
|
print()
|
||||||
```
|
```
|
||||||
@ -221,6 +222,7 @@ Here are a few examples of Hugging Face models that work with this example:
|
|||||||
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
|
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
|
||||||
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
|
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
|
||||||
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
|
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
|
||||||
|
- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
|
||||||
|
|
||||||
Most
|
Most
|
||||||
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
|
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
|
||||||
@ -248,3 +250,28 @@ model, tokenizer = load(
|
|||||||
tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
|
tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Large Models
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
This requires macOS 15.0 or higher to work.
|
||||||
|
|
||||||
|
Models which are large relative to the total RAM available on the machine can
|
||||||
|
be slow. `mlx-lm` will attempt to make them faster by wiring the memory
|
||||||
|
occupied by the model and cache. This requires macOS 15 or higher to
|
||||||
|
work.
|
||||||
|
|
||||||
|
If you see the following warning message:
|
||||||
|
|
||||||
|
> [WARNING] Generating with a model that requires ...
|
||||||
|
|
||||||
|
then the model will likely be slow on the given machine. If the model fits in
|
||||||
|
RAM then it can often be sped up by increasing the system wired memory limit.
|
||||||
|
To increase the limit, set the following `sysctl`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo sysctl iogpu.wired_limit_mb=N
|
||||||
|
```
|
||||||
|
|
||||||
|
The value `N` should be larger than the size of the model in megabytes but
|
||||||
|
smaller than the memory size of the machine.
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
__version__ = "0.19.1"
|
__version__ = "0.19.3"
|
||||||
|
@ -8,7 +8,9 @@ import time
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
from .models.cache import make_prompt_cache, save_prompt_cache
|
from .models.cache import make_prompt_cache, save_prompt_cache
|
||||||
from .utils import load
|
from .utils import load, maybe_quantize_kv_cache
|
||||||
|
|
||||||
|
DEFAULT_QUANTIZED_KV_START = 5000
|
||||||
|
|
||||||
|
|
||||||
def setup_arg_parser():
|
def setup_arg_parser():
|
||||||
@ -70,6 +72,26 @@ def setup_arg_parser():
|
|||||||
required=True,
|
required=True,
|
||||||
help="Message to be processed by the model ('-' reads from stdin)",
|
help="Message to be processed by the model ('-' reads from stdin)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-bits",
|
||||||
|
type=int,
|
||||||
|
help="Number of bits for KV cache quantization. "
|
||||||
|
"Defaults to no quantization.",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-group-size",
|
||||||
|
type=int,
|
||||||
|
help="Group size for KV cache quantization.",
|
||||||
|
default=64,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--quantized-kv-start",
|
||||||
|
help="When --kv-bits is set, start quantizing the KV cache "
|
||||||
|
"from this step onwards.",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_QUANTIZED_KV_START,
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -127,8 +149,10 @@ def main():
|
|||||||
start = time.time()
|
start = time.time()
|
||||||
max_msg_len = 0
|
max_msg_len = 0
|
||||||
while y.size > 0:
|
while y.size > 0:
|
||||||
|
|
||||||
model(y[:step_size][None], cache=cache)
|
model(y[:step_size][None], cache=cache)
|
||||||
mx.eval([c.state for c in cache])
|
mx.eval([c.state for c in cache])
|
||||||
|
mx.metal.clear_cache()
|
||||||
processed += min(y.size, step_size)
|
processed += min(y.size, step_size)
|
||||||
y = y[step_size:]
|
y = y[step_size:]
|
||||||
current = time.time()
|
current = time.time()
|
||||||
@ -136,15 +160,19 @@ def main():
|
|||||||
msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
|
msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
|
||||||
max_msg_len = max(max_msg_len, len(msg))
|
max_msg_len = max(max_msg_len, len(msg))
|
||||||
print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
|
print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
|
||||||
|
|
||||||
|
maybe_quantize_kv_cache(
|
||||||
|
cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
|
||||||
|
)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
|
print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
|
||||||
|
|
||||||
print("Saving...")
|
print("Saving...")
|
||||||
metadata = {}
|
metadata = {}
|
||||||
metadata["model"] = args.model
|
metadata["model"] = args.model
|
||||||
metadata["chat_template"] = tokenizer.chat_template
|
metadata["chat_template"] = tokenizer.chat_template
|
||||||
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
||||||
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
|
|
||||||
save_prompt_cache(args.prompt_cache_file, cache, metadata)
|
save_prompt_cache(args.prompt_cache_file, cache, metadata)
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ from .utils import load, stream_generate
|
|||||||
DEFAULT_TEMP = 0.0
|
DEFAULT_TEMP = 0.0
|
||||||
DEFAULT_TOP_P = 1.0
|
DEFAULT_TOP_P = 1.0
|
||||||
DEFAULT_SEED = 0
|
DEFAULT_SEED = 0
|
||||||
|
DEFAULT_MAX_TOKENS = 256
|
||||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||||
|
|
||||||
|
|
||||||
@ -41,6 +42,13 @@ def setup_arg_parser():
|
|||||||
help="Set the maximum key-value cache size",
|
help="Set the maximum key-value cache size",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-tokens",
|
||||||
|
"-m",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_MAX_TOKENS,
|
||||||
|
help="Maximum number of tokens to generate",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -56,7 +64,7 @@ def main():
|
|||||||
tokenizer_config={"trust_remote_code": True},
|
tokenizer_config={"trust_remote_code": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.")
|
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
|
||||||
prompt_cache = make_prompt_cache(model, args.max_kv_size)
|
prompt_cache = make_prompt_cache(model, args.max_kv_size)
|
||||||
while True:
|
while True:
|
||||||
query = input(">> ")
|
query = input(">> ")
|
||||||
@ -66,10 +74,11 @@ def main():
|
|||||||
prompt = tokenizer.apply_chat_template(
|
prompt = tokenizer.apply_chat_template(
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
)
|
)
|
||||||
for response in stream_generate(
|
for response, *_ in stream_generate(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
prompt,
|
prompt,
|
||||||
|
args.max_tokens,
|
||||||
temp=args.temp,
|
temp=args.temp,
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
prompt_cache=prompt_cache,
|
prompt_cache=prompt_cache,
|
||||||
|
@ -6,15 +6,18 @@ import sys
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
from .models.cache import load_prompt_cache
|
from .models.cache import QuantizedKVCache, load_prompt_cache
|
||||||
from .utils import generate, load
|
from .utils import generate, load
|
||||||
|
|
||||||
DEFAULT_PROMPT = "hello"
|
DEFAULT_PROMPT = "hello"
|
||||||
DEFAULT_MAX_TOKENS = 100
|
DEFAULT_MAX_TOKENS = 100
|
||||||
DEFAULT_TEMP = 0.0
|
DEFAULT_TEMP = 0.0
|
||||||
DEFAULT_TOP_P = 1.0
|
DEFAULT_TOP_P = 1.0
|
||||||
|
DEFAULT_MIN_P = 0.0
|
||||||
|
DEFAULT_MIN_TOKENS_TO_KEEP = 1
|
||||||
DEFAULT_SEED = 0
|
DEFAULT_SEED = 0
|
||||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||||
|
DEFAULT_QUANTIZED_KV_START = 5000
|
||||||
|
|
||||||
|
|
||||||
def str2bool(string):
|
def str2bool(string):
|
||||||
@ -51,6 +54,7 @@ def setup_arg_parser():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prompt",
|
"--prompt",
|
||||||
|
"-p",
|
||||||
default=DEFAULT_PROMPT,
|
default=DEFAULT_PROMPT,
|
||||||
help="Message to be processed by the model ('-' reads from stdin)",
|
help="Message to be processed by the model ('-' reads from stdin)",
|
||||||
)
|
)
|
||||||
@ -67,6 +71,15 @@ def setup_arg_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--min-tokens-to-keep",
|
||||||
|
type=float,
|
||||||
|
default=DEFAULT_MIN_TOKENS_TO_KEEP,
|
||||||
|
help="Minimum tokens to keep for min-p sampling.",
|
||||||
|
)
|
||||||
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ignore-chat-template",
|
"--ignore-chat-template",
|
||||||
@ -89,12 +102,6 @@ def setup_arg_parser():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Colorize output based on T[0] probability",
|
help="Colorize output based on T[0] probability",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--cache-limit-gb",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Set the MLX cache limit in GB",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-kv-size",
|
"--max-kv-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -107,6 +114,26 @@ def setup_arg_parser():
|
|||||||
default=None,
|
default=None,
|
||||||
help="A file containing saved KV caches to avoid recomputing them",
|
help="A file containing saved KV caches to avoid recomputing them",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-bits",
|
||||||
|
type=int,
|
||||||
|
help="Number of bits for KV cache quantization. "
|
||||||
|
"Defaults to no quantization.",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-group-size",
|
||||||
|
type=int,
|
||||||
|
help="Group size for KV cache quantization.",
|
||||||
|
default=64,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--quantized-kv-start",
|
||||||
|
help="When --kv-bits is set, start quantizing the KV cache "
|
||||||
|
"from this step onwards.",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_QUANTIZED_KV_START,
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -143,15 +170,22 @@ def main():
|
|||||||
|
|
||||||
mx.random.seed(args.seed)
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
if args.cache_limit_gb is not None:
|
|
||||||
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
|
|
||||||
|
|
||||||
# Load the prompt cache and metadata if a cache file is provided
|
# Load the prompt cache and metadata if a cache file is provided
|
||||||
using_cache = args.prompt_cache_file is not None
|
using_cache = args.prompt_cache_file is not None
|
||||||
if using_cache:
|
if using_cache:
|
||||||
prompt_cache, metadata = load_prompt_cache(
|
prompt_cache, metadata = load_prompt_cache(
|
||||||
args.prompt_cache_file, return_metadata=True
|
args.prompt_cache_file,
|
||||||
|
return_metadata=True,
|
||||||
)
|
)
|
||||||
|
if isinstance(prompt_cache[0], QuantizedKVCache):
|
||||||
|
if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
|
||||||
|
raise ValueError(
|
||||||
|
"--kv-bits does not match the kv cache loaded from --prompt-cache-file."
|
||||||
|
)
|
||||||
|
if args.kv_group_size != prompt_cache[0].group_size:
|
||||||
|
raise ValueError(
|
||||||
|
"--kv-group-size does not match the kv cache loaded from --prompt-cache-file."
|
||||||
|
)
|
||||||
|
|
||||||
# Building tokenizer_config
|
# Building tokenizer_config
|
||||||
tokenizer_config = (
|
tokenizer_config = (
|
||||||
@ -225,8 +259,13 @@ def main():
|
|||||||
formatter=formatter,
|
formatter=formatter,
|
||||||
temp=args.temp,
|
temp=args.temp,
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
|
min_p=args.min_p,
|
||||||
|
min_tokens_to_keep=args.min_tokens_to_keep,
|
||||||
max_kv_size=args.max_kv_size,
|
max_kv_size=args.max_kv_size,
|
||||||
prompt_cache=prompt_cache if using_cache else None,
|
prompt_cache=prompt_cache if using_cache else None,
|
||||||
|
kv_bits=args.kv_bits,
|
||||||
|
kv_group_size=args.kv_group_size,
|
||||||
|
quantized_kv_start=args.quantized_kv_start,
|
||||||
)
|
)
|
||||||
if not args.verbose:
|
if not args.verbose:
|
||||||
print(response)
|
print(response)
|
||||||
|
@ -5,6 +5,9 @@ from dataclasses import dataclass
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
from mlx.utils import tree_map
|
||||||
|
|
||||||
|
from .cache import QuantizedKVCache
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -39,7 +42,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
|||||||
if cache is not None and cache[0] is not None:
|
if cache is not None and cache[0] is not None:
|
||||||
c = cache[0]
|
c = cache[0]
|
||||||
if hasattr(c, "max_size"):
|
if hasattr(c, "max_size"):
|
||||||
offset = min(c.max_size - 1, c.offset)
|
offset = min(c.max_size, c.offset)
|
||||||
window_size = c.max_size
|
window_size = c.max_size
|
||||||
else:
|
else:
|
||||||
offset = c.offset
|
offset = c.offset
|
||||||
@ -48,3 +51,63 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
|||||||
else:
|
else:
|
||||||
mask = None
|
mask = None
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def quantized_scaled_dot_product_attention(
|
||||||
|
queries: mx.array,
|
||||||
|
q_keys: tuple[mx.array, mx.array, mx.array],
|
||||||
|
q_values: tuple[mx.array, mx.array, mx.array],
|
||||||
|
scale: float,
|
||||||
|
mask: Optional[mx.array],
|
||||||
|
group_size: int = 64,
|
||||||
|
bits: int = 8,
|
||||||
|
) -> mx.array:
|
||||||
|
B, n_q_heads, L, D = queries.shape
|
||||||
|
n_kv_heads = q_keys[0].shape[-3]
|
||||||
|
n_repeats = n_q_heads // n_kv_heads
|
||||||
|
|
||||||
|
queries *= scale
|
||||||
|
|
||||||
|
if n_repeats > 1:
|
||||||
|
queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
|
||||||
|
q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
|
||||||
|
q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
|
||||||
|
|
||||||
|
scores = mx.quantized_matmul(
|
||||||
|
queries, *q_keys, transpose=True, group_size=group_size, bits=bits
|
||||||
|
)
|
||||||
|
if mask is not None:
|
||||||
|
scores += mask
|
||||||
|
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||||
|
out = mx.quantized_matmul(
|
||||||
|
scores, *q_values, transpose=False, group_size=group_size, bits=bits
|
||||||
|
)
|
||||||
|
|
||||||
|
if n_repeats > 1:
|
||||||
|
out = mx.reshape(out, (B, n_q_heads, L, D))
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_dot_product_attention(
|
||||||
|
queries,
|
||||||
|
keys,
|
||||||
|
values,
|
||||||
|
cache,
|
||||||
|
scale: float,
|
||||||
|
mask: Optional[mx.array],
|
||||||
|
) -> mx.array:
|
||||||
|
if isinstance(cache, QuantizedKVCache):
|
||||||
|
return quantized_scaled_dot_product_attention(
|
||||||
|
queries,
|
||||||
|
keys,
|
||||||
|
values,
|
||||||
|
scale=scale,
|
||||||
|
mask=mask,
|
||||||
|
group_size=cache.group_size,
|
||||||
|
bits=cache.bits,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return mx.fast.scaled_dot_product_attention(
|
||||||
|
queries, keys, values, scale=scale, mask=mask
|
||||||
|
)
|
||||||
|
@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_flatten, tree_unflatten
|
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||||
|
|
||||||
|
|
||||||
def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
|
def make_prompt_cache(
|
||||||
|
model: nn.Module,
|
||||||
|
max_kv_size: Optional[int] = None,
|
||||||
|
) -> List[Any]:
|
||||||
"""
|
"""
|
||||||
Construct the model's cache for use when cgeneration.
|
Construct the model's cache for use when cgeneration.
|
||||||
|
|
||||||
@ -126,6 +129,88 @@ class _BaseCache:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizedKVCache(_BaseCache):
|
||||||
|
def __init__(self, group_size: int = 64, bits: int = 8):
|
||||||
|
self.keys = None
|
||||||
|
self.values = None
|
||||||
|
self.offset = 0
|
||||||
|
self.step = 256
|
||||||
|
self.group_size = group_size
|
||||||
|
self.bits = bits
|
||||||
|
|
||||||
|
def update_and_fetch(self, keys, values):
|
||||||
|
B, n_kv_heads, num_steps, k_head_dim = keys.shape
|
||||||
|
v_head_dim = values.shape[-1]
|
||||||
|
prev = self.offset
|
||||||
|
|
||||||
|
if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
|
||||||
|
el_per_int = 8 * mx.uint32.size // self.bits
|
||||||
|
new_steps = (self.step + num_steps - 1) // self.step * self.step
|
||||||
|
shape = (B, n_kv_heads, new_steps)
|
||||||
|
|
||||||
|
def init_quant(dim):
|
||||||
|
return (
|
||||||
|
mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32),
|
||||||
|
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
|
||||||
|
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
def expand_quant(x):
|
||||||
|
new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype)
|
||||||
|
return mx.concatenate([x, new_x], axis=-2)
|
||||||
|
|
||||||
|
if self.keys is not None:
|
||||||
|
if prev % self.step != 0:
|
||||||
|
self.keys, self.values = tree_map(
|
||||||
|
lambda x: x[..., :prev, :], (self.keys, self.values)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.keys, self.values = tree_map(
|
||||||
|
expand_quant, (self.keys, self.values)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim)
|
||||||
|
|
||||||
|
self.offset += num_steps
|
||||||
|
|
||||||
|
keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
|
||||||
|
values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
|
||||||
|
for i in range(len(self.keys)):
|
||||||
|
self.keys[i][..., prev : self.offset, :] = keys[i]
|
||||||
|
self.values[i][..., prev : self.offset, :] = values[i]
|
||||||
|
|
||||||
|
return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self):
|
||||||
|
if self.offset == self.keys[0].shape[2]:
|
||||||
|
return self.keys, self.values
|
||||||
|
else:
|
||||||
|
return tree_map(
|
||||||
|
lambda x: x[..., : self.offset, :], (self.keys, self.values)
|
||||||
|
)
|
||||||
|
|
||||||
|
@state.setter
|
||||||
|
def state(self, v):
|
||||||
|
self.keys, self.values = v
|
||||||
|
|
||||||
|
@property
|
||||||
|
def meta_state(self):
|
||||||
|
return tuple(map(str, (self.step, self.offset, self.group_size, self.bits)))
|
||||||
|
|
||||||
|
@meta_state.setter
|
||||||
|
def meta_state(self, v):
|
||||||
|
self.step, self.offset, self.group_size, self.bits = map(int, v)
|
||||||
|
|
||||||
|
def is_trimmable(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def trim(self, n):
|
||||||
|
n = min(self.offset, n)
|
||||||
|
self.offset -= n
|
||||||
|
return n
|
||||||
|
|
||||||
|
|
||||||
class KVCache(_BaseCache):
|
class KVCache(_BaseCache):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.keys = None
|
self.keys = None
|
||||||
@ -180,6 +265,16 @@ class KVCache(_BaseCache):
|
|||||||
self.offset -= n
|
self.offset -= n
|
||||||
return n
|
return n
|
||||||
|
|
||||||
|
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
||||||
|
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
|
||||||
|
quant_cache.offset = self.offset
|
||||||
|
if self.keys is not None:
|
||||||
|
quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
|
||||||
|
quant_cache.values = mx.quantize(
|
||||||
|
self.values, group_size=group_size, bits=bits
|
||||||
|
)
|
||||||
|
return quant_cache
|
||||||
|
|
||||||
|
|
||||||
class RotatingKVCache(_BaseCache):
|
class RotatingKVCache(_BaseCache):
|
||||||
|
|
||||||
@ -230,9 +325,9 @@ class RotatingKVCache(_BaseCache):
|
|||||||
self.keys = self._temporal_order(self.keys)
|
self.keys = self._temporal_order(self.keys)
|
||||||
self.values = self._temporal_order(self.values)
|
self.values = self._temporal_order(self.values)
|
||||||
|
|
||||||
# The largest size is self.max_size + S - 1 to ensure
|
# The largest size is self.max_size + S to ensure
|
||||||
# every token gets at least self.max_size context
|
# every token gets at least self.max_size context
|
||||||
trim_size = self._idx - self.max_size + 1
|
trim_size = self._idx - self.max_size
|
||||||
self.keys = self._trim(trim_size, self.keys, keys)
|
self.keys = self._trim(trim_size, self.keys, keys)
|
||||||
self.values = self._trim(trim_size, self.values, values)
|
self.values = self._trim(trim_size, self.values, values)
|
||||||
self.offset += keys.shape[2]
|
self.offset += keys.shape[2]
|
||||||
@ -320,6 +415,9 @@ class RotatingKVCache(_BaseCache):
|
|||||||
self._idx -= n
|
self._idx -= n
|
||||||
return n
|
return n
|
||||||
|
|
||||||
|
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
|
||||||
|
raise NotImplementedError("RotatingKVCache Quantization NYI")
|
||||||
|
|
||||||
|
|
||||||
class MambaCache:
|
class MambaCache:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -93,8 +93,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
@ -7,7 +7,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -74,8 +74,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.out_proj(output)
|
return self.out_proj(output)
|
||||||
|
@ -4,7 +4,7 @@ from typing import Any, Dict, Optional
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .switch_layers import SwitchGLU
|
from .switch_layers import SwitchGLU
|
||||||
|
|
||||||
|
|
||||||
@ -97,8 +97,8 @@ class DeepseekAttention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .switch_layers import SwitchGLU
|
from .switch_layers import SwitchGLU
|
||||||
|
|
||||||
|
|
||||||
@ -235,8 +235,8 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
|
|
||||||
queries = mx.concatenate([q_nope, q_pe], axis=-1)
|
queries = mx.concatenate([q_nope, q_pe], axis=-1)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -79,8 +79,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
@ -7,7 +7,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -61,8 +61,8 @@ class Attention(nn.Module):
|
|||||||
if cache is not None:
|
if cache is not None:
|
||||||
keys, values = cache.update_and_fetch(keys, values)
|
keys, values = cache.update_and_fetch(keys, values)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
@ -7,7 +7,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -74,8 +74,8 @@ class Attention(nn.Module):
|
|||||||
if cache is not None:
|
if cache is not None:
|
||||||
keys, values = cache.update_and_fetch(keys, values)
|
keys, values = cache.update_and_fetch(keys, values)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.c_proj(output)
|
return self.c_proj(output)
|
||||||
|
@ -7,7 +7,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
# Based on the transformers implementation at:
|
# Based on the transformers implementation at:
|
||||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||||
@ -79,8 +79,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Tuple, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -141,8 +141,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.wo(output)
|
return self.wo(output)
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -190,9 +190,10 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
|
||||||
|
@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs):
|
|||||||
use_conv_bias: bool
|
use_conv_bias: bool
|
||||||
time_step_rank: int
|
time_step_rank: int
|
||||||
tie_word_embeddings: bool = True
|
tie_word_embeddings: bool = True
|
||||||
|
use_bcdt_rms: bool = False
|
||||||
|
mixer_rms_eps: float = 1e-6
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
|
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
|
||||||
@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs):
|
|||||||
|
|
||||||
if self.time_step_rank == "auto":
|
if self.time_step_rank == "auto":
|
||||||
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
||||||
|
if self.model_type == "falcon_mamba":
|
||||||
|
self.use_bcdt_rms = True
|
||||||
|
|
||||||
|
|
||||||
class DepthWiseConv1d(nn.Module):
|
class DepthWiseConv1d(nn.Module):
|
||||||
@ -83,6 +87,11 @@ class MambaBlock(nn.Module):
|
|||||||
self.intermediate_size = args.intermediate_size
|
self.intermediate_size = args.intermediate_size
|
||||||
self.time_step_rank = int(args.time_step_rank)
|
self.time_step_rank = int(args.time_step_rank)
|
||||||
self.use_conv_bias = args.use_conv_bias
|
self.use_conv_bias = args.use_conv_bias
|
||||||
|
self.use_bcdt_rms = args.use_bcdt_rms
|
||||||
|
if self.use_bcdt_rms:
|
||||||
|
self.mixer_norm = lambda x: mx.fast.rms_norm(
|
||||||
|
x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
|
||||||
|
)
|
||||||
|
|
||||||
self.in_proj = nn.Linear(
|
self.in_proj = nn.Linear(
|
||||||
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
|
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
|
||||||
@ -126,6 +135,8 @@ class MambaBlock(nn.Module):
|
|||||||
],
|
],
|
||||||
axis=-1,
|
axis=-1,
|
||||||
)
|
)
|
||||||
|
if self.use_bcdt_rms:
|
||||||
|
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
||||||
delta = nn.softplus(self.dt_proj(delta))
|
delta = nn.softplus(self.dt_proj(delta))
|
||||||
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
|
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
|
||||||
if state is not None:
|
if state is not None:
|
||||||
|
@ -7,7 +7,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -105,8 +105,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
attn_output = mx.fast.scaled_dot_product_attention(
|
attn_output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .switch_layers import SwitchGLU
|
from .switch_layers import SwitchGLU
|
||||||
|
|
||||||
|
|
||||||
@ -87,8 +87,8 @@ class MixtralAttention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -113,8 +113,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -107,8 +107,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
@ -7,7 +7,7 @@ from typing import Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -93,8 +93,13 @@ class PhiAttention(nn.Module):
|
|||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
scale = math.sqrt(1 / queries.shape[-1])
|
scale = math.sqrt(1 / queries.shape[-1])
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
|
queries.astype(mx.float32),
|
||||||
|
keys,
|
||||||
|
values,
|
||||||
|
cache=cache,
|
||||||
|
scale=scale,
|
||||||
|
mask=mask,
|
||||||
).astype(values.dtype)
|
).astype(values.dtype)
|
||||||
|
|
||||||
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .su_rope import SuScaledRotaryEmbedding
|
from .su_rope import SuScaledRotaryEmbedding
|
||||||
|
|
||||||
|
|
||||||
@ -107,8 +107,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -8,7 +8,7 @@ from typing import Any, Optional
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -188,8 +188,8 @@ class Attention(nn.Module):
|
|||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.dense(output)
|
return self.dense(output)
|
||||||
|
@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .su_rope import SuScaledRotaryEmbedding
|
from .su_rope import SuScaledRotaryEmbedding
|
||||||
from .switch_layers import SwitchGLU
|
from .switch_layers import SwitchGLU
|
||||||
|
|
||||||
@ -79,8 +79,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -8,7 +8,7 @@ from typing import Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import create_attention_mask
|
from .base import create_attention_mask, scaled_dot_product_attention
|
||||||
from .switch_layers import SwitchMLP
|
from .switch_layers import SwitchMLP
|
||||||
|
|
||||||
|
|
||||||
@ -71,8 +71,13 @@ class RoPEAttention(nn.Module):
|
|||||||
# Finally perform the attention computation
|
# Finally perform the attention computation
|
||||||
scale = math.sqrt(1 / queries.shape[-1])
|
scale = math.sqrt(1 / queries.shape[-1])
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
|
queries.astype(mx.float32),
|
||||||
|
keys,
|
||||||
|
values,
|
||||||
|
cache=cache,
|
||||||
|
scale=scale,
|
||||||
|
mask=mask,
|
||||||
).astype(values.dtype)
|
).astype(values.dtype)
|
||||||
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -92,10 +92,11 @@ class Attention(nn.Module):
|
|||||||
keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
|
keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
|
||||||
values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
|
values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries,
|
queries,
|
||||||
keys,
|
keys,
|
||||||
values,
|
values,
|
||||||
|
cache=cache,
|
||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
mask=attention_mask,
|
mask=attention_mask,
|
||||||
)
|
)
|
||||||
|
@ -5,7 +5,7 @@ from dataclasses import dataclass
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -64,8 +64,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rotary_emb(queries)
|
queries = self.rotary_emb(queries)
|
||||||
keys = self.rotary_emb(keys)
|
keys = self.rotary_emb(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -89,8 +89,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .switch_layers import SwitchGLU
|
from .switch_layers import SwitchGLU
|
||||||
|
|
||||||
|
|
||||||
@ -89,8 +89,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -7,7 +7,7 @@ from typing import List, Literal, Optional
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .cache import MambaCache, RotatingKVCache
|
from .cache import MambaCache, RotatingKVCache
|
||||||
|
|
||||||
|
|
||||||
@ -263,8 +263,8 @@ class LocalAttentionBlock(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -6,7 +6,7 @@ from dataclasses import dataclass
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -120,8 +120,8 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
# Finally perform the attention computation
|
# Finally perform the attention computation
|
||||||
scale = math.sqrt(1 / queries.shape[-1])
|
scale = math.sqrt(1 / queries.shape[-1])
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=scale, mask=mask
|
queries, keys, values, cache=cache, scale=scale, mask=mask
|
||||||
).astype(values.dtype)
|
).astype(values.dtype)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Optional
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -64,8 +64,8 @@ class Attention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
output = mx.fast.scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries, keys, values, scale=self.scale, mask=mask
|
queries, keys, values, cache=cache, scale=self.scale, mask=mask
|
||||||
)
|
)
|
||||||
|
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
mlx>=0.17.0
|
mlx>=0.19.2
|
||||||
numpy
|
numpy
|
||||||
transformers[sentencepiece]>=4.39.3
|
transformers[sentencepiece]>=4.39.3
|
||||||
protobuf
|
protobuf
|
||||||
|
@ -1,10 +1,83 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import Callable, Dict, Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
|
def make_sampler(
|
||||||
|
temp: float = 0.0,
|
||||||
|
top_p: float = 0.0,
|
||||||
|
min_p: float = 0.0,
|
||||||
|
min_tokens_to_keep: int = 1,
|
||||||
|
) -> Callable[mx.array, mx.array]:
|
||||||
|
"""
|
||||||
|
Make a sampler function for use with ``generate_step``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||||
|
Default: ``0``.
|
||||||
|
top_p (float, optional): Nulceus sampling, higher means model considers
|
||||||
|
more less likely words.
|
||||||
|
min_p (float, optional): The minimum value (scaled by the top token's
|
||||||
|
probability) that a token probability must have to be considered.
|
||||||
|
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||||
|
be filtered by min_p sampling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable[mx.array, mx.array]:
|
||||||
|
A sampler which takes log-probabilities and returns tokens.
|
||||||
|
"""
|
||||||
|
if temp == 0:
|
||||||
|
return lambda x: mx.argmax(x, axis=-1)
|
||||||
|
elif top_p > 0 and top_p < 1.0:
|
||||||
|
return lambda x: top_p_sampling(x, top_p, temp)
|
||||||
|
elif min_p != 0.0:
|
||||||
|
return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
|
||||||
|
else:
|
||||||
|
return lambda x: categorical_sampling(x, temp)
|
||||||
|
|
||||||
|
|
||||||
|
def make_logits_processors(
|
||||||
|
logit_bias: Optional[Dict[int, float]] = None,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
repetition_context_size: Optional[int] = 20,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Make logits processors for use with ``generate_step``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repetition_penalty (float, optional): The penalty factor for repeating
|
||||||
|
tokens.
|
||||||
|
repetition_context_size (int, optional): The number of tokens to
|
||||||
|
consider for repetition penalty. Default: ``20``.
|
||||||
|
logit_bias (dictionary, optional): Additive logit bias.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Callable[[mx.array, mx.array], mx.array]]:
|
||||||
|
A list of logits processors. Each processor in the list is a
|
||||||
|
callable which takes an array of tokens and an array of logits
|
||||||
|
and returns the updated logits.
|
||||||
|
"""
|
||||||
|
logits_processors = []
|
||||||
|
if logit_bias:
|
||||||
|
indices = mx.array(list(logit_bias.keys()))
|
||||||
|
values = mx.array(list(logit_bias.values()))
|
||||||
|
|
||||||
|
def logit_bias_processor(_, logits):
|
||||||
|
logits[:, indices] += values
|
||||||
|
return logits
|
||||||
|
|
||||||
|
logits_processors.append(logit_bias_processor)
|
||||||
|
|
||||||
|
if repetition_penalty and repetition_penalty != 0.0:
|
||||||
|
logits_processors.append(
|
||||||
|
make_repetition_penalty(repetition_penalty, repetition_context_size)
|
||||||
|
)
|
||||||
|
return logits_processors
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
def min_p_sampling(
|
def min_p_sampling(
|
||||||
logits: mx.array,
|
logits: mx.array,
|
||||||
@ -100,3 +173,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
|||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
def categorical_sampling(logits, temp):
|
def categorical_sampling(logits, temp):
|
||||||
return mx.random.categorical(logits * (1 / temp))
|
return mx.random.categorical(logits * (1 / temp))
|
||||||
|
|
||||||
|
|
||||||
|
def make_repetition_penalty(penalty: float, context_size: int = 20):
|
||||||
|
"""
|
||||||
|
Make repetition penalty processor.
|
||||||
|
|
||||||
|
Paper: https://arxiv.org/abs/1909.05858
|
||||||
|
|
||||||
|
Args:
|
||||||
|
penalty (float): The repetition penalty factor to be applied.
|
||||||
|
context_size (int): The number of previous tokens to use.
|
||||||
|
Default: ``20``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable[[mx.array, List[int]], mx.array]:
|
||||||
|
The repetition penalty processor.
|
||||||
|
"""
|
||||||
|
if penalty < 0 or not isinstance(penalty, float):
|
||||||
|
raise ValueError(f"penalty must be a non-negative float, got {penalty}")
|
||||||
|
|
||||||
|
def repetition_penalty_processor(tokens, logits):
|
||||||
|
if len(tokens) > 0:
|
||||||
|
tokens = tokens[-context_size:]
|
||||||
|
selected_logits = logits[:, tokens]
|
||||||
|
selected_logits = mx.where(
|
||||||
|
selected_logits < 0,
|
||||||
|
selected_logits * penalty,
|
||||||
|
selected_logits / penalty,
|
||||||
|
)
|
||||||
|
logits[:, tokens] = selected_logits
|
||||||
|
return logits
|
||||||
|
|
||||||
|
return repetition_penalty_processor
|
||||||
|
@ -27,7 +27,7 @@ from huggingface_hub import scan_cache_dir
|
|||||||
|
|
||||||
from ._version import __version__
|
from ._version import __version__
|
||||||
from .models.cache import make_prompt_cache
|
from .models.cache import make_prompt_cache
|
||||||
from .utils import generate_step, load
|
from .utils import load, stream_generate
|
||||||
|
|
||||||
|
|
||||||
def get_system_fingerprint():
|
def get_system_fingerprint():
|
||||||
@ -64,7 +64,7 @@ def stopping_criteria(
|
|||||||
end if it has (`trim_length`).
|
end if it has (`trim_length`).
|
||||||
"""
|
"""
|
||||||
if tokens and tokens[-1] == eos_token_id:
|
if tokens and tokens[-1] == eos_token_id:
|
||||||
return StopCondition(stop_met=True, trim_length=1)
|
return StopCondition(stop_met=True, trim_length=0)
|
||||||
|
|
||||||
for stop_ids in stop_id_sequences:
|
for stop_ids in stop_id_sequences:
|
||||||
if len(tokens) >= len(stop_ids):
|
if len(tokens) >= len(stop_ids):
|
||||||
@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
self.max_tokens = self.body.get("max_completion_tokens", None)
|
self.max_tokens = self.body.get("max_completion_tokens", None)
|
||||||
if self.max_tokens is None:
|
if self.max_tokens is None:
|
||||||
self.max_tokens = self.body.get("max_tokens", 512)
|
self.max_tokens = self.body.get("max_tokens", 512)
|
||||||
self.temperature = self.body.get("temperature", 1.0)
|
self.temperature = self.body.get("temperature", 0.0)
|
||||||
self.top_p = self.body.get("top_p", 1.0)
|
self.top_p = self.body.get("top_p", 1.0)
|
||||||
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
|
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
|
||||||
self.repetition_context_size = self.body.get("repetition_context_size", 20)
|
self.repetition_context_size = self.body.get("repetition_context_size", 20)
|
||||||
@ -290,10 +290,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
# Call endpoint specific method
|
# Call endpoint specific method
|
||||||
prompt = endpoints[self.path]()
|
prompt = endpoints[self.path]()
|
||||||
|
self.handle_completion(prompt, stop_id_sequences)
|
||||||
# Call method based on response type
|
|
||||||
method = self.handle_stream if self.stream else self.handle_completion
|
|
||||||
method(prompt, stop_id_sequences)
|
|
||||||
|
|
||||||
def validate_model_parameters(self):
|
def validate_model_parameters(self):
|
||||||
"""
|
"""
|
||||||
@ -452,32 +449,40 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
stop_id_sequences (List[List[int]]): A list of stop words passed
|
stop_id_sequences (List[List[int]]): A list of stop words passed
|
||||||
to the stopping_criteria function
|
to the stopping_criteria function
|
||||||
"""
|
"""
|
||||||
detokenizer = self.tokenizer.detokenizer
|
|
||||||
detokenizer.reset()
|
|
||||||
tokens = []
|
tokens = []
|
||||||
finish_reason = "length"
|
finish_reason = "length"
|
||||||
stop_sequence_suffix = None
|
stop_sequence_suffix = None
|
||||||
logging.debug(f"Starting completion:")
|
if self.stream:
|
||||||
|
self.end_headers()
|
||||||
|
logging.debug(f"Starting stream:")
|
||||||
|
else:
|
||||||
|
logging.debug(f"Starting completion:")
|
||||||
token_logprobs = []
|
token_logprobs = []
|
||||||
top_tokens = []
|
top_tokens = []
|
||||||
|
|
||||||
prompt = self.get_prompt_cache(prompt)
|
prompt = self.get_prompt_cache(prompt)
|
||||||
|
|
||||||
for _, (token, logprobs) in zip(
|
text = ""
|
||||||
range(self.max_tokens),
|
tic = time.perf_counter()
|
||||||
generate_step(
|
for n, (segment, token, logprobs) in enumerate(
|
||||||
prompt=mx.array(prompt),
|
stream_generate(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
prompt=prompt,
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
temp=self.temperature,
|
temp=self.temperature,
|
||||||
top_p=self.top_p,
|
|
||||||
repetition_penalty=self.repetition_penalty,
|
repetition_penalty=self.repetition_penalty,
|
||||||
repetition_context_size=self.repetition_context_size,
|
repetition_context_size=self.repetition_context_size,
|
||||||
logit_bias=self.logit_bias,
|
logit_bias=self.logit_bias,
|
||||||
prompt_cache=self.prompt_cache.cache,
|
prompt_cache=self.prompt_cache.cache,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
detokenizer.add_token(token)
|
if n == 0:
|
||||||
logging.debug(detokenizer.text)
|
prompt_time = time.perf_counter() - tic
|
||||||
|
tic = time.perf_counter()
|
||||||
|
|
||||||
|
text += segment
|
||||||
|
logging.debug(text)
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
|
|
||||||
if self.logprobs > 0:
|
if self.logprobs > 0:
|
||||||
@ -498,121 +503,63 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
stop_sequence_suffix = self.tokenizer.decode(
|
stop_sequence_suffix = self.tokenizer.decode(
|
||||||
tokens[-stop_condition.trim_length :]
|
tokens[-stop_condition.trim_length :]
|
||||||
)
|
)
|
||||||
|
text = text[: -len(stop_sequence_suffix)]
|
||||||
break
|
break
|
||||||
|
|
||||||
self.prompt_cache.tokens.extend(tokens)
|
if self.stream:
|
||||||
detokenizer.finalize()
|
# If the end of tokens overlaps with a stop sequence, generate new
|
||||||
text = (
|
# tokens until we know if the stop sequence is hit or not
|
||||||
detokenizer.text
|
if any(
|
||||||
if stop_sequence_suffix is None
|
(
|
||||||
else detokenizer.text[: -len(stop_sequence_suffix)]
|
sequence_overlap(tokens, sequence)
|
||||||
)
|
for sequence in stop_id_sequences
|
||||||
response = self.generate_response(
|
|
||||||
text,
|
|
||||||
finish_reason,
|
|
||||||
len(prompt),
|
|
||||||
len(tokens),
|
|
||||||
token_logprobs=token_logprobs,
|
|
||||||
top_tokens=top_tokens,
|
|
||||||
tokens=tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
response_json = json.dumps(response).encode()
|
|
||||||
indent = "\t" # Backslashes can't be inside of f-strings
|
|
||||||
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
|
|
||||||
|
|
||||||
# Send an additional Content-Length header when it is known
|
|
||||||
self.send_header("Content-Length", str(len(response_json)))
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
self.wfile.write(response_json)
|
|
||||||
self.wfile.flush()
|
|
||||||
|
|
||||||
def handle_stream(
|
|
||||||
self,
|
|
||||||
prompt: List[int],
|
|
||||||
stop_id_sequences: List[List[int]],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Generate response to prompt and foward it to the client using a Server
|
|
||||||
Sent Events (SSE) stream.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (mx.array): The tokenized prompt
|
|
||||||
stop_id_sequences (List[List[int]]): A list of stop words passed to
|
|
||||||
the stopping_criteria function
|
|
||||||
"""
|
|
||||||
# No additional headers are needed, call end_headers
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
detokenizer = self.tokenizer.detokenizer
|
|
||||||
detokenizer.reset()
|
|
||||||
tokens = []
|
|
||||||
|
|
||||||
stop_sequence_suffix = None
|
|
||||||
logging.debug(f"Starting stream:")
|
|
||||||
|
|
||||||
prompt = self.get_prompt_cache(prompt)
|
|
||||||
|
|
||||||
for _, (token, _) in zip(
|
|
||||||
range(self.max_tokens),
|
|
||||||
generate_step(
|
|
||||||
prompt=mx.array(prompt),
|
|
||||||
model=self.model,
|
|
||||||
temp=self.temperature,
|
|
||||||
top_p=self.top_p,
|
|
||||||
repetition_penalty=self.repetition_penalty,
|
|
||||||
repetition_context_size=self.repetition_context_size,
|
|
||||||
prompt_cache=self.prompt_cache.cache,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
detokenizer.add_token(token)
|
|
||||||
logging.debug(detokenizer.text)
|
|
||||||
tokens.append(token)
|
|
||||||
|
|
||||||
stop_condition = stopping_criteria(
|
|
||||||
tokens,
|
|
||||||
stop_id_sequences,
|
|
||||||
self.tokenizer.eos_token_id,
|
|
||||||
)
|
|
||||||
if stop_condition.stop_met:
|
|
||||||
if stop_condition.trim_length:
|
|
||||||
stop_sequence_suffix = self.tokenizer.decode(
|
|
||||||
tokens[-stop_condition.trim_length :]
|
|
||||||
)
|
)
|
||||||
break
|
):
|
||||||
|
continue
|
||||||
# If the end of tokens overlaps with a stop sequence, generate new
|
elif segment:
|
||||||
# tokens until we know if the stop sequence is hit or not
|
response = self.generate_response(segment, None)
|
||||||
if any(
|
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||||
(sequence_overlap(tokens, sequence) for sequence in stop_id_sequences)
|
self.wfile.flush()
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
new_text = detokenizer.last_segment
|
|
||||||
if new_text:
|
|
||||||
response = self.generate_response(new_text, None)
|
|
||||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
|
||||||
self.wfile.flush()
|
|
||||||
|
|
||||||
self.prompt_cache.tokens.extend(tokens)
|
self.prompt_cache.tokens.extend(tokens)
|
||||||
|
|
||||||
# check is there any remaining text to send
|
gen_time = time.perf_counter() - tic
|
||||||
detokenizer.finalize()
|
prompt_tps = len(prompt) / prompt_time
|
||||||
last_segment = detokenizer.last_segment
|
gen_tps = len(tokens) / gen_time
|
||||||
if last_segment:
|
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||||
if stop_sequence_suffix is not None:
|
logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||||
last_segment = last_segment[: -len(stop_sequence_suffix)]
|
logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||||
response = self.generate_response(last_segment, "length")
|
logging.debug(f"Peak memory: {peak_mem:.3f} GB")
|
||||||
|
|
||||||
|
if self.stream:
|
||||||
|
response = self.generate_response(segment, finish_reason)
|
||||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
|
if self.stream_options is not None and self.stream_options["include_usage"]:
|
||||||
|
response = self.completion_usage_response(len(prompt), len(tokens))
|
||||||
|
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||||
|
self.wfile.flush()
|
||||||
|
self.wfile.write("data: [DONE]\n\n".encode())
|
||||||
|
self.wfile.flush()
|
||||||
|
else:
|
||||||
|
response = self.generate_response(
|
||||||
|
text,
|
||||||
|
finish_reason,
|
||||||
|
len(prompt),
|
||||||
|
len(tokens),
|
||||||
|
token_logprobs=token_logprobs,
|
||||||
|
top_tokens=top_tokens,
|
||||||
|
tokens=tokens,
|
||||||
|
)
|
||||||
|
response_json = json.dumps(response).encode()
|
||||||
|
indent = "\t" # Backslashes can't be inside of f-strings
|
||||||
|
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
|
||||||
|
|
||||||
if self.stream_options is not None and self.stream_options["include_usage"]:
|
# Send an additional Content-Length header when it is known
|
||||||
response = self.completion_usage_response(len(prompt), len(tokens))
|
self.send_header("Content-Length", str(len(response_json)))
|
||||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
self.end_headers()
|
||||||
|
self.wfile.write(response_json)
|
||||||
self.wfile.write("data: [DONE]\n\n".encode())
|
self.wfile.flush()
|
||||||
self.wfile.flush()
|
|
||||||
|
|
||||||
def completion_usage_response(
|
def completion_usage_response(
|
||||||
self,
|
self,
|
||||||
|
@ -6,12 +6,6 @@ from transformers import AutoTokenizer
|
|||||||
REPLACEMENT_CHAR = "\ufffd"
|
REPLACEMENT_CHAR = "\ufffd"
|
||||||
|
|
||||||
|
|
||||||
def _remove_space(x):
|
|
||||||
if x and x[0] == " ":
|
|
||||||
return x[1:]
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class StreamingDetokenizer:
|
class StreamingDetokenizer:
|
||||||
"""The streaming detokenizer interface so that we can detokenize one token at a time.
|
"""The streaming detokenizer interface so that we can detokenize one token at a time.
|
||||||
|
|
||||||
@ -123,42 +117,42 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
|
|||||||
|
|
||||||
def __init__(self, tokenizer, trim_space=True):
|
def __init__(self, tokenizer, trim_space=True):
|
||||||
self.trim_space = trim_space
|
self.trim_space = trim_space
|
||||||
|
self._sep = "\u2581".encode()
|
||||||
|
|
||||||
# Extract the tokens in a list from id to text
|
# Extract the tokens in a list from id to text
|
||||||
self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
|
self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
|
||||||
for value, tokenid in tokenizer.vocab.items():
|
for value, tokenid in tokenizer.vocab.items():
|
||||||
self.tokenmap[tokenid] = value
|
if value.startswith("<0x"):
|
||||||
|
# Replace bytes with their value
|
||||||
# Replace bytes with their value
|
self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
|
||||||
for i in range(len(self.tokenmap)):
|
else:
|
||||||
if self.tokenmap[i].startswith("<0x"):
|
self.tokenmap[tokenid] = value.encode()
|
||||||
self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))
|
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.offset = 0
|
self.offset = 0
|
||||||
self._unflushed = ""
|
self._unflushed = b""
|
||||||
self.text = ""
|
self.text = ""
|
||||||
self.tokens = []
|
self.tokens = []
|
||||||
|
|
||||||
|
def _flush(self):
|
||||||
|
text = self._unflushed.replace(self._sep, b" ").decode("utf-8")
|
||||||
|
if not self.text and self.trim_space and text and text[0] == " ":
|
||||||
|
text = text[1:]
|
||||||
|
self.text += text
|
||||||
|
|
||||||
def add_token(self, token):
|
def add_token(self, token):
|
||||||
v = self.tokenmap[token]
|
v = self.tokenmap[token]
|
||||||
if v[0] == "\u2581":
|
if v.startswith(self._sep):
|
||||||
if self.text or not self.trim_space:
|
self._flush()
|
||||||
self.text += self._unflushed.replace("\u2581", " ")
|
|
||||||
else:
|
|
||||||
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
|
|
||||||
self._unflushed = v
|
self._unflushed = v
|
||||||
else:
|
else:
|
||||||
self._unflushed += v
|
self._unflushed += v
|
||||||
|
|
||||||
def finalize(self):
|
def finalize(self):
|
||||||
if self.text or not self.trim_space:
|
self._flush()
|
||||||
self.text += self._unflushed.replace("\u2581", " ")
|
self._unflushed = b""
|
||||||
else:
|
|
||||||
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
|
|
||||||
self._unflushed = ""
|
|
||||||
|
|
||||||
|
|
||||||
class BPEStreamingDetokenizer(StreamingDetokenizer):
|
class BPEStreamingDetokenizer(StreamingDetokenizer):
|
||||||
@ -186,6 +180,8 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
|
|||||||
# https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
# https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
||||||
self.make_byte_decoder()
|
self.make_byte_decoder()
|
||||||
|
|
||||||
|
self._added_ids = set(tokenizer.added_tokens_decoder.keys())
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.offset = 0
|
self.offset = 0
|
||||||
self._unflushed = ""
|
self._unflushed = ""
|
||||||
@ -205,12 +201,17 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
|
|||||||
|
|
||||||
def add_token(self, token):
|
def add_token(self, token):
|
||||||
v = self.tokenmap[token]
|
v = self.tokenmap[token]
|
||||||
if self._byte_decoder[v[0]] == 32:
|
is_added = token in self._added_ids
|
||||||
|
if is_added or self._byte_decoder[v[0]] == 32:
|
||||||
current_text = bytearray(
|
current_text = bytearray(
|
||||||
self._byte_decoder[c] for c in self._unflushed
|
self._byte_decoder[c] for c in self._unflushed
|
||||||
).decode("utf-8")
|
).decode("utf-8")
|
||||||
self.text += self._maybe_trim_space(current_text)
|
self.text += self._maybe_trim_space(current_text)
|
||||||
self._unflushed = v
|
if is_added:
|
||||||
|
self.text += v
|
||||||
|
self._unflushed = ""
|
||||||
|
else:
|
||||||
|
self._unflushed = v
|
||||||
else:
|
else:
|
||||||
self._unflushed += v
|
self._unflushed += v
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ from typing import Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from mlx.nn.utils import average_gradients
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
|
|
||||||
@ -84,9 +85,16 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
|||||||
f" examples but only has {len(dataset)}."
|
f" examples but only has {len(dataset)}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If running in distributed mode (N machines) then each one should skip N-1
|
||||||
|
# samples
|
||||||
|
step = mx.distributed.init().size()
|
||||||
|
if batch_size % step != 0:
|
||||||
|
raise ValueError("The batch size must be divisible by the number of workers")
|
||||||
|
|
||||||
# Make the batches:
|
# Make the batches:
|
||||||
batch_idx = [
|
batch_idx = [
|
||||||
idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
|
idx[i : i + batch_size : step]
|
||||||
|
for i in range(0, len(idx) - batch_size + 1, batch_size)
|
||||||
]
|
]
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@ -112,9 +120,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
|||||||
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
|
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
|
||||||
max_length_in_batch = min(max_length_in_batch, max_seq_length)
|
max_length_in_batch = min(max_length_in_batch, max_seq_length)
|
||||||
|
|
||||||
batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
|
batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
|
||||||
|
|
||||||
for j in range(batch_size):
|
for j in range(batch_size // step):
|
||||||
truncated_length = min(lengths[j], max_seq_length)
|
truncated_length = min(lengths[j], max_seq_length)
|
||||||
batch_arr[j, :truncated_length] = batch[j][:truncated_length]
|
batch_arr[j, :truncated_length] = batch[j][:truncated_length]
|
||||||
lengths[j] = (
|
lengths[j] = (
|
||||||
@ -138,7 +146,7 @@ def evaluate(
|
|||||||
loss: callable = default_loss,
|
loss: callable = default_loss,
|
||||||
iterate_batches: callable = iterate_batches,
|
iterate_batches: callable = iterate_batches,
|
||||||
):
|
):
|
||||||
all_losses = []
|
all_losses = 0
|
||||||
ntokens = 0
|
ntokens = 0
|
||||||
|
|
||||||
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||||
@ -153,10 +161,14 @@ def evaluate(
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
losses, toks = loss(model, *batch)
|
losses, toks = loss(model, *batch)
|
||||||
all_losses.append((losses * toks).item())
|
all_losses += losses * toks
|
||||||
ntokens += toks.item()
|
ntokens += toks
|
||||||
|
mx.eval(all_losses, ntokens)
|
||||||
|
|
||||||
return np.sum(all_losses) / ntokens
|
all_losses = mx.distributed.all_sum(all_losses)
|
||||||
|
ntokens = mx.distributed.all_sum(ntokens)
|
||||||
|
|
||||||
|
return (all_losses / ntokens).item()
|
||||||
|
|
||||||
|
|
||||||
class TrainingCallback:
|
class TrainingCallback:
|
||||||
@ -182,6 +194,11 @@ def train(
|
|||||||
training_callback: TrainingCallback = None,
|
training_callback: TrainingCallback = None,
|
||||||
):
|
):
|
||||||
print(f"Starting training..., iters: {args.iters}")
|
print(f"Starting training..., iters: {args.iters}")
|
||||||
|
world = mx.distributed.init()
|
||||||
|
world_size = world.size()
|
||||||
|
rank = world.rank()
|
||||||
|
if world_size > 1:
|
||||||
|
print(f"Node {rank} of {world_size}")
|
||||||
|
|
||||||
if args.grad_checkpoint:
|
if args.grad_checkpoint:
|
||||||
grad_checkpoint(model.layers[0])
|
grad_checkpoint(model.layers[0])
|
||||||
@ -192,6 +209,9 @@ def train(
|
|||||||
# Forward and backward pass
|
# Forward and backward pass
|
||||||
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
|
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
|
||||||
|
|
||||||
|
# All reduce the gradients if running in distributed mode
|
||||||
|
grad = average_gradients(grad)
|
||||||
|
|
||||||
# Model update
|
# Model update
|
||||||
optimizer.update(model, grad)
|
optimizer.update(model, grad)
|
||||||
|
|
||||||
@ -199,8 +219,9 @@ def train(
|
|||||||
|
|
||||||
loss_value_and_grad = nn.value_and_grad(model, loss)
|
loss_value_and_grad = nn.value_and_grad(model, loss)
|
||||||
|
|
||||||
losses = []
|
losses = 0
|
||||||
n_tokens = 0
|
n_tokens = 0
|
||||||
|
steps = 0
|
||||||
trained_tokens = 0
|
trained_tokens = 0
|
||||||
# Main training loop
|
# Main training loop
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
@ -229,9 +250,13 @@ def train(
|
|||||||
iterate_batches=iterate_batches,
|
iterate_batches=iterate_batches,
|
||||||
)
|
)
|
||||||
val_time = time.perf_counter() - stop
|
val_time = time.perf_counter() - stop
|
||||||
print(
|
if rank == 0:
|
||||||
f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
|
print(
|
||||||
)
|
f"Iter {it}: "
|
||||||
|
f"Val loss {val_loss:.3f}, "
|
||||||
|
f"Val took {val_time:.3f}s",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
if training_callback is not None:
|
if training_callback is not None:
|
||||||
val_info = {
|
val_info = {
|
||||||
@ -244,30 +269,33 @@ def train(
|
|||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
|
|
||||||
lvalue, toks = step(batch)
|
lvalue, toks = step(batch)
|
||||||
mx.eval(state, lvalue, toks)
|
losses += lvalue
|
||||||
|
n_tokens += toks
|
||||||
# Record loss
|
steps += 1
|
||||||
losses.append(lvalue.item())
|
mx.eval(state, losses, n_tokens)
|
||||||
n_tokens += toks.item()
|
|
||||||
|
|
||||||
# Report training loss if needed
|
# Report training loss if needed
|
||||||
if it % args.steps_per_report == 0 or it == args.iters:
|
if it % args.steps_per_report == 0 or it == args.iters:
|
||||||
stop = time.perf_counter()
|
stop = time.perf_counter()
|
||||||
|
|
||||||
train_loss = np.mean(losses)
|
train_loss = mx.distributed.all_sum(losses).item()
|
||||||
|
train_loss /= steps * mx.distributed.init().size()
|
||||||
|
n_tokens = mx.distributed.all_sum(n_tokens).item()
|
||||||
learning_rate = optimizer.learning_rate.item()
|
learning_rate = optimizer.learning_rate.item()
|
||||||
it_sec = args.steps_per_report / (stop - start)
|
it_sec = args.steps_per_report / (stop - start)
|
||||||
tokens_sec = float(n_tokens) / (stop - start)
|
tokens_sec = float(n_tokens) / (stop - start)
|
||||||
trained_tokens += n_tokens
|
trained_tokens += n_tokens
|
||||||
peak_mem = mx.metal.get_peak_memory() / 2**30
|
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||||
print(
|
if rank == 0:
|
||||||
f"Iter {it}: Train loss {train_loss:.3f}, "
|
print(
|
||||||
f"Learning Rate {learning_rate:.3e}, "
|
f"Iter {it}: Train loss {train_loss:.3f}, "
|
||||||
f"It/sec {it_sec:.3f}, "
|
f"Learning Rate {learning_rate:.3e}, "
|
||||||
f"Tokens/sec {tokens_sec:.3f}, "
|
f"It/sec {it_sec:.3f}, "
|
||||||
f"Trained Tokens {trained_tokens}, "
|
f"Tokens/sec {tokens_sec:.3f}, "
|
||||||
f"Peak mem {peak_mem:.3f} GB"
|
f"Trained Tokens {trained_tokens}, "
|
||||||
)
|
f"Peak mem {peak_mem:.3f} GB",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
if training_callback is not None:
|
if training_callback is not None:
|
||||||
train_info = {
|
train_info = {
|
||||||
@ -281,8 +309,9 @@ def train(
|
|||||||
}
|
}
|
||||||
training_callback.on_train_loss_report(train_info)
|
training_callback.on_train_loss_report(train_info)
|
||||||
|
|
||||||
losses = []
|
losses = 0
|
||||||
n_tokens = 0
|
n_tokens = 0
|
||||||
|
steps = 0
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
|
|
||||||
# Save adapter weights
|
# Save adapter weights
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import glob
|
import glob
|
||||||
import importlib
|
import importlib
|
||||||
@ -14,12 +15,12 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten, tree_reduce
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
# Local imports
|
# Local imports
|
||||||
from .models import base, cache
|
from .models import cache
|
||||||
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
|
from .sample_utils import make_logits_processors, make_sampler
|
||||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||||
from .tuner.utils import dequantize as dequantize_model
|
from .tuner.utils import dequantize as dequantize_model
|
||||||
from .tuner.utils import load_adapters
|
from .tuner.utils import load_adapters
|
||||||
@ -28,10 +29,14 @@ from .tuner.utils import load_adapters
|
|||||||
MODEL_REMAPPING = {
|
MODEL_REMAPPING = {
|
||||||
"mistral": "llama", # mistral is compatible with llama
|
"mistral": "llama", # mistral is compatible with llama
|
||||||
"phi-msft": "phixtral",
|
"phi-msft": "phixtral",
|
||||||
|
"falcon_mamba": "mamba",
|
||||||
}
|
}
|
||||||
|
|
||||||
MAX_FILE_SIZE_GB = 5
|
MAX_FILE_SIZE_GB = 5
|
||||||
|
|
||||||
|
# A stream on the default device just for generation
|
||||||
|
generation_stream = mx.new_stream(mx.default_device())
|
||||||
|
|
||||||
|
|
||||||
class ModelNotFoundError(Exception):
|
class ModelNotFoundError(Exception):
|
||||||
def __init__(self, message):
|
def __init__(self, message):
|
||||||
@ -39,6 +44,40 @@ class ModelNotFoundError(Exception):
|
|||||||
super().__init__(self.message)
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
|
||||||
|
"""
|
||||||
|
A context manager to temporarily change the wired limit.
|
||||||
|
|
||||||
|
Note, the wired limit should not be changed during an async eval. If an
|
||||||
|
async eval could be running pass in the streams to synchronize with prior
|
||||||
|
to exiting the context manager.
|
||||||
|
"""
|
||||||
|
model_bytes = tree_reduce(
|
||||||
|
lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
|
||||||
|
)
|
||||||
|
max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
|
||||||
|
if model_bytes > 0.9 * max_rec_size:
|
||||||
|
model_mb = model_bytes // 2**20
|
||||||
|
max_rec_mb = max_rec_size // 2**20
|
||||||
|
print(
|
||||||
|
"[WARNING] Generating with a model that requires {model_mb} MB "
|
||||||
|
"which is close to the maximum recommended size of {max_rec_mb} "
|
||||||
|
"MB. This can be slow. See the documentation for possible work-arounds: "
|
||||||
|
"https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models"
|
||||||
|
)
|
||||||
|
old_limit = mx.metal.set_wired_limit(max_rec_size)
|
||||||
|
try:
|
||||||
|
yield None
|
||||||
|
finally:
|
||||||
|
if streams is not None:
|
||||||
|
for s in streams:
|
||||||
|
mx.synchronize(s)
|
||||||
|
else:
|
||||||
|
mx.synchronize()
|
||||||
|
mx.metal.set_wired_limit(old_limit)
|
||||||
|
|
||||||
|
|
||||||
def _get_classes(config: dict):
|
def _get_classes(config: dict):
|
||||||
"""
|
"""
|
||||||
Retrieve the model and model args classes based on the configuration.
|
Retrieve the model and model args classes based on the configuration.
|
||||||
@ -101,27 +140,16 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
|
|||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
|
||||||
def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
|
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
|
||||||
"""
|
if (
|
||||||
Apply repetition penalty to specific logits based on the given context.
|
kv_bits is not None
|
||||||
|
and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
|
||||||
Paper: https://arxiv.org/abs/1909.05858
|
and prompt_cache[0].offset > quantized_kv_start
|
||||||
|
):
|
||||||
Args:
|
for i in range(len(prompt_cache)):
|
||||||
logits (mx.array): The logits produced by the language model.
|
prompt_cache[i] = prompt_cache[i].to_quantized(
|
||||||
tokens (mx.array): A list of N previous tokens.
|
group_size=kv_group_size, bits=kv_bits
|
||||||
penalty (float): The repetition penalty factor to be applied.
|
)
|
||||||
|
|
||||||
Returns:
|
|
||||||
logits (mx.array): Logits with repetition penalty applied to generated tokens.
|
|
||||||
"""
|
|
||||||
if len(tokens) > 0:
|
|
||||||
selected_logits = logits[:, tokens]
|
|
||||||
selected_logits = mx.where(
|
|
||||||
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
|
|
||||||
)
|
|
||||||
logits[:, tokens] = selected_logits
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
def generate_step(
|
def generate_step(
|
||||||
@ -137,7 +165,10 @@ def generate_step(
|
|||||||
max_kv_size: Optional[int] = None,
|
max_kv_size: Optional[int] = None,
|
||||||
prompt_cache: Optional[Any] = None,
|
prompt_cache: Optional[Any] = None,
|
||||||
logit_bias: Optional[Dict[int, float]] = None,
|
logit_bias: Optional[Dict[int, float]] = None,
|
||||||
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||||
|
kv_bits: Optional[int] = None,
|
||||||
|
kv_group_size: int = 64,
|
||||||
|
quantized_kv_start: int = 0,
|
||||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing token ids based on the given prompt from the model.
|
A generator producing token ids based on the given prompt from the model.
|
||||||
@ -163,80 +194,56 @@ def generate_step(
|
|||||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||||
provided, the cache will be updated in place.
|
provided, the cache will be updated in place.
|
||||||
logit_bias (dictionary, optional): Additive logit bias.
|
logit_bias (dictionary, optional): Additive logit bias.
|
||||||
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||||
A list of functions that take tokens and logits and return the processed
|
A list of functions that take tokens and logits and return the processed
|
||||||
logits. Default: ``None``.
|
logits. Default: ``None``.
|
||||||
|
kv_bits (int, optional): Number of bits to use for KV cache quantization.
|
||||||
|
None implies no cache quantization. Default: ``None``.
|
||||||
|
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
|
||||||
|
quantized_kv_start (int): Step to begin using a quantized KV cache.
|
||||||
|
when ``kv_bits`` is non-None. Default: ``0``.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
|
||||||
one token and a vector of log probabilities.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
|
||||||
logprobs = logits - mx.logsumexp(logits)
|
|
||||||
|
|
||||||
if temp == 0:
|
|
||||||
token = mx.argmax(logits, axis=-1)
|
|
||||||
else:
|
|
||||||
if top_p > 0 and top_p < 1.0:
|
|
||||||
token = top_p_sampling(logits, top_p, temp)
|
|
||||||
elif min_p != 0.0:
|
|
||||||
token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
|
|
||||||
else:
|
|
||||||
token = categorical_sampling(logits, temp)
|
|
||||||
|
|
||||||
return token, logprobs
|
|
||||||
|
|
||||||
if repetition_penalty and (
|
|
||||||
repetition_penalty < 0 or not isinstance(repetition_penalty, float)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
|
|
||||||
)
|
|
||||||
|
|
||||||
logits_processor = logits_processor or []
|
|
||||||
|
|
||||||
if repetition_penalty:
|
|
||||||
|
|
||||||
def repetition_penalty_processor(tokens, logits):
|
|
||||||
return apply_repetition_penalty(
|
|
||||||
logits, tokens[-repetition_context_size:], repetition_penalty
|
|
||||||
)
|
|
||||||
|
|
||||||
logits_processor.append(repetition_penalty_processor)
|
|
||||||
|
|
||||||
if logit_bias:
|
|
||||||
indices = mx.array(list(logit_bias.keys()))
|
|
||||||
values = mx.array(list(logit_bias.values()))
|
|
||||||
|
|
||||||
def logit_bias_processor(_, logits):
|
|
||||||
logits[:, indices] += values
|
|
||||||
return logits
|
|
||||||
|
|
||||||
logits_processor.append(logit_bias_processor)
|
|
||||||
|
|
||||||
y = prompt
|
y = prompt
|
||||||
tokens = None
|
tokens = None
|
||||||
|
|
||||||
# Create the KV cache for generation
|
# Create the KV cache for generation
|
||||||
if prompt_cache is None:
|
if prompt_cache is None:
|
||||||
prompt_cache = cache.make_prompt_cache(model, max_kv_size)
|
prompt_cache = cache.make_prompt_cache(
|
||||||
|
model,
|
||||||
|
max_kv_size=max_kv_size,
|
||||||
|
)
|
||||||
elif len(prompt_cache) != len(model.layers):
|
elif len(prompt_cache) != len(model.layers):
|
||||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||||
|
|
||||||
|
sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
|
||||||
|
logits_processors = logits_processors or []
|
||||||
|
logits_processors.extend(
|
||||||
|
make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
|
||||||
|
)
|
||||||
|
|
||||||
def _step(y):
|
def _step(y):
|
||||||
logits = model(y[None], cache=prompt_cache)
|
with mx.stream(generation_stream):
|
||||||
logits = logits[:, -1, :]
|
logits = model(y[None], cache=prompt_cache)
|
||||||
|
logits = logits[:, -1, :]
|
||||||
|
|
||||||
if logits_processor:
|
if logits_processors:
|
||||||
nonlocal tokens
|
nonlocal tokens
|
||||||
tokens = mx.concat([tokens, y]) if tokens is not None else y
|
tokens = mx.concat([tokens, y]) if tokens is not None else y
|
||||||
|
|
||||||
for processor in logits_processor:
|
for processor in logits_processors:
|
||||||
logits = processor(tokens, logits)
|
logits = processor(tokens, logits)
|
||||||
|
|
||||||
y, logprobs = sample(logits)
|
maybe_quantize_kv_cache(
|
||||||
return y, logprobs.squeeze(0)
|
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
||||||
|
)
|
||||||
|
|
||||||
|
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||||
|
y = sampler(logprobs)
|
||||||
|
return y, logprobs.squeeze(0)
|
||||||
|
|
||||||
while y.size > prefill_step_size:
|
while y.size > prefill_step_size:
|
||||||
model(y[:prefill_step_size][None], cache=prompt_cache)
|
model(y[:prefill_step_size][None], cache=prompt_cache)
|
||||||
@ -247,53 +254,65 @@ def generate_step(
|
|||||||
y, logprobs = _step(y)
|
y, logprobs = _step(y)
|
||||||
|
|
||||||
mx.async_eval(y, logprobs)
|
mx.async_eval(y, logprobs)
|
||||||
|
n = 0
|
||||||
while True:
|
while True:
|
||||||
next_y, next_logprobs = _step(y)
|
next_y, next_logprobs = _step(y)
|
||||||
mx.async_eval(next_y, next_logprobs)
|
mx.async_eval(next_y, next_logprobs)
|
||||||
yield y.item(), logprobs
|
yield y.item(), logprobs
|
||||||
|
if n % 256 == 0:
|
||||||
|
mx.metal.clear_cache()
|
||||||
|
n += 1
|
||||||
y, logprobs = next_y, next_logprobs
|
y, logprobs = next_y, next_logprobs
|
||||||
|
|
||||||
|
|
||||||
def stream_generate(
|
def stream_generate(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||||
prompt: str,
|
prompt: Union[str, List[int]],
|
||||||
max_tokens: int = 100,
|
max_tokens: int = 100,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[str, Generator[str, None, None]]:
|
) -> Generator[Tuple[str, int, mx.array], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing text based on the given prompt from the model.
|
A generator producing text based on the given prompt from the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt (mx.array): The input prompt.
|
|
||||||
model (nn.Module): The model to use for generation.
|
model (nn.Module): The model to use for generation.
|
||||||
max_tokens (int): The ma
|
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||||
|
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
|
||||||
|
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
||||||
kwargs: The remaining options get passed to :func:`generate_step`.
|
kwargs: The remaining options get passed to :func:`generate_step`.
|
||||||
See :func:`generate_step` for more details.
|
See :func:`generate_step` for more details.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Generator[Tuple[mx.array, mx.array]]: A generator producing text.
|
Tuple[str, int, mx.array]:
|
||||||
|
The next text segment, token, and vector of log probabilities.
|
||||||
"""
|
"""
|
||||||
if not isinstance(tokenizer, TokenizerWrapper):
|
if not isinstance(tokenizer, TokenizerWrapper):
|
||||||
tokenizer = TokenizerWrapper(tokenizer)
|
tokenizer = TokenizerWrapper(tokenizer)
|
||||||
|
|
||||||
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
prompt_tokens = mx.array(
|
||||||
|
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
|
||||||
|
)
|
||||||
detokenizer = tokenizer.detokenizer
|
detokenizer = tokenizer.detokenizer
|
||||||
|
|
||||||
detokenizer.reset()
|
with wired_limit(model, [generation_stream]):
|
||||||
for n, (token, _) in zip(
|
detokenizer.reset()
|
||||||
range(max_tokens),
|
for n, (token, logits) in zip(
|
||||||
generate_step(prompt_tokens, model, **kwargs),
|
range(max_tokens),
|
||||||
):
|
generate_step(prompt_tokens, model, **kwargs),
|
||||||
if token == tokenizer.eos_token_id:
|
):
|
||||||
break
|
if token == tokenizer.eos_token_id:
|
||||||
detokenizer.add_token(token)
|
break
|
||||||
|
|
||||||
# Yield the last segment if streaming
|
detokenizer.add_token(token)
|
||||||
yield detokenizer.last_segment
|
|
||||||
|
|
||||||
detokenizer.finalize()
|
if n == (max_tokens - 1):
|
||||||
yield detokenizer.last_segment
|
break
|
||||||
|
|
||||||
|
yield detokenizer.last_segment, token, logits
|
||||||
|
|
||||||
|
detokenizer.finalize()
|
||||||
|
yield detokenizer.last_segment, token, logits
|
||||||
|
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
@ -304,7 +323,7 @@ def generate(
|
|||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
formatter: Optional[Callable] = None,
|
formatter: Optional[Callable] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[str, Generator[str, None, None]]:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a complete response from the model.
|
Generate a complete response from the model.
|
||||||
|
|
||||||
@ -330,48 +349,49 @@ def generate(
|
|||||||
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
||||||
detokenizer = tokenizer.detokenizer
|
detokenizer = tokenizer.detokenizer
|
||||||
|
|
||||||
tic = time.perf_counter()
|
with wired_limit(model, [generation_stream]):
|
||||||
detokenizer.reset()
|
tic = time.perf_counter()
|
||||||
|
detokenizer.reset()
|
||||||
|
for n, (token, logprobs) in zip(
|
||||||
|
range(max_tokens),
|
||||||
|
generate_step(prompt_tokens, model, **kwargs),
|
||||||
|
):
|
||||||
|
if n == 0:
|
||||||
|
prompt_time = time.perf_counter() - tic
|
||||||
|
tic = time.perf_counter()
|
||||||
|
if token == tokenizer.eos_token_id:
|
||||||
|
break
|
||||||
|
detokenizer.add_token(token)
|
||||||
|
|
||||||
for n, (token, logprobs) in zip(
|
if verbose:
|
||||||
range(max_tokens),
|
if formatter:
|
||||||
generate_step(prompt_tokens, model, **kwargs),
|
# We have to finalize so that the prob corresponds to the last segment
|
||||||
):
|
detokenizer.finalize()
|
||||||
if n == 0:
|
prob = mx.exp(logprobs[token]).item()
|
||||||
prompt_time = time.perf_counter() - tic
|
formatter(detokenizer.last_segment, prob)
|
||||||
tic = time.perf_counter()
|
else:
|
||||||
if token == tokenizer.eos_token_id:
|
print(detokenizer.last_segment, end="", flush=True)
|
||||||
break
|
|
||||||
detokenizer.add_token(token)
|
token_count = n + 1
|
||||||
|
detokenizer.finalize()
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
if formatter:
|
gen_time = time.perf_counter() - tic
|
||||||
# We have to finalize so that the prob corresponds to the last segment
|
print(detokenizer.last_segment, flush=True)
|
||||||
detokenizer.finalize()
|
print("=" * 10)
|
||||||
with mx.stream(mx.cpu):
|
if token_count == 0:
|
||||||
prob = mx.exp(logprobs[token]).item()
|
print("No tokens generated for this prompt")
|
||||||
formatter(detokenizer.last_segment, prob)
|
return
|
||||||
else:
|
prompt_tps = prompt_tokens.size / prompt_time
|
||||||
print(detokenizer.last_segment, end="", flush=True)
|
gen_tps = (token_count - 1) / gen_time
|
||||||
|
print(
|
||||||
|
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
|
||||||
|
)
|
||||||
|
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
|
||||||
|
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||||
|
print(f"Peak memory: {peak_mem:.3f} GB")
|
||||||
|
|
||||||
token_count = n + 1
|
return detokenizer.text
|
||||||
detokenizer.finalize()
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
gen_time = time.perf_counter() - tic
|
|
||||||
print(detokenizer.last_segment, flush=True)
|
|
||||||
print("=" * 10)
|
|
||||||
if token_count == 0:
|
|
||||||
print("No tokens generated for this prompt")
|
|
||||||
return
|
|
||||||
prompt_tps = prompt_tokens.size / prompt_time
|
|
||||||
gen_tps = (token_count - 1) / gen_time
|
|
||||||
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
|
|
||||||
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
|
|
||||||
peak_mem = mx.metal.get_peak_memory() / 2**30
|
|
||||||
print(f"Peak memory: {peak_mem:.3f} GB")
|
|
||||||
|
|
||||||
return detokenizer.text
|
|
||||||
|
|
||||||
|
|
||||||
def load_config(model_path: Path) -> dict:
|
def load_config(model_path: Path) -> dict:
|
||||||
@ -553,7 +573,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
|
|||||||
f"""
|
f"""
|
||||||
# {upload_repo}
|
# {upload_repo}
|
||||||
|
|
||||||
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.
|
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
|
||||||
|
converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path})
|
||||||
|
using mlx-lm version **{__version__}**.
|
||||||
|
|
||||||
## Use with mlx
|
## Use with mlx
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
from contextlib import contextmanager
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
@ -17,6 +18,14 @@ from mlx_lm.tuner.trainer import evaluate
|
|||||||
from mlx_lm.tuner.utils import build_schedule
|
from mlx_lm.tuner.utils import build_schedule
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def swapped_with_identity(obj, func):
|
||||||
|
old_func = getattr(obj, func)
|
||||||
|
setattr(obj, func, lambda x: x)
|
||||||
|
yield
|
||||||
|
setattr(obj, func, old_func)
|
||||||
|
|
||||||
|
|
||||||
class TestLora(unittest.TestCase):
|
class TestLora(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.capturedOutput = StringIO()
|
self.capturedOutput = StringIO()
|
||||||
@ -374,16 +383,17 @@ class TestScheduleConfig(unittest.TestCase):
|
|||||||
(MagicMock(return_value=0.4), MagicMock(return_value=180)),
|
(MagicMock(return_value=0.4), MagicMock(return_value=180)),
|
||||||
(MagicMock(return_value=0.6), MagicMock(return_value=120)),
|
(MagicMock(return_value=0.6), MagicMock(return_value=120)),
|
||||||
]
|
]
|
||||||
evaluate(
|
with swapped_with_identity(mx.distributed, "all_sum"):
|
||||||
model=mock_model,
|
evaluate(
|
||||||
dataset=mock_dataset,
|
model=mock_model,
|
||||||
tokenizer=mock_tokenizer,
|
dataset=mock_dataset,
|
||||||
batch_size=2,
|
tokenizer=mock_tokenizer,
|
||||||
num_batches=2,
|
batch_size=2,
|
||||||
max_seq_length=2048,
|
num_batches=2,
|
||||||
loss=mock_default_loss,
|
max_seq_length=2048,
|
||||||
iterate_batches=mock_iterate_batches,
|
loss=mock_default_loss,
|
||||||
)
|
iterate_batches=mock_iterate_batches,
|
||||||
|
)
|
||||||
|
|
||||||
mock_iterate_batches.assert_called_once_with(
|
mock_iterate_batches.assert_called_once_with(
|
||||||
dataset=mock_dataset,
|
dataset=mock_dataset,
|
||||||
@ -412,16 +422,17 @@ class TestScheduleConfig(unittest.TestCase):
|
|||||||
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
|
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
|
||||||
]
|
]
|
||||||
|
|
||||||
evaluate(
|
with swapped_with_identity(mx.distributed, "all_sum"):
|
||||||
model=mock_model,
|
evaluate(
|
||||||
dataset=mock_dataset,
|
model=mock_model,
|
||||||
tokenizer=mock_tokenizer,
|
dataset=mock_dataset,
|
||||||
batch_size=2,
|
tokenizer=mock_tokenizer,
|
||||||
num_batches=-1,
|
batch_size=2,
|
||||||
max_seq_length=2048,
|
num_batches=-1,
|
||||||
loss=mock_default_loss,
|
max_seq_length=2048,
|
||||||
iterate_batches=mock_iterate_batches,
|
loss=mock_default_loss,
|
||||||
)
|
iterate_batches=mock_iterate_batches,
|
||||||
|
)
|
||||||
|
|
||||||
mock_iterate_batches.assert_called_once_with(
|
mock_iterate_batches.assert_called_once_with(
|
||||||
dataset=mock_dataset,
|
dataset=mock_dataset,
|
||||||
|
@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase):
|
|||||||
"hello",
|
"hello",
|
||||||
max_tokens=5,
|
max_tokens=5,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
logits_processor=[logits_processor],
|
logits_processors=[logits_processor],
|
||||||
)
|
)
|
||||||
self.assertEqual(len(all_toks), len(init_toks) + 5)
|
self.assertEqual(len(all_toks), len(init_toks) + 5)
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@ import mlx.core as mx
|
|||||||
from mlx_lm.models.cache import (
|
from mlx_lm.models.cache import (
|
||||||
KVCache,
|
KVCache,
|
||||||
MambaCache,
|
MambaCache,
|
||||||
|
QuantizedKVCache,
|
||||||
RotatingKVCache,
|
RotatingKVCache,
|
||||||
load_prompt_cache,
|
load_prompt_cache,
|
||||||
make_prompt_cache,
|
make_prompt_cache,
|
||||||
@ -186,6 +187,18 @@ class TestPromptCache(unittest.TestCase):
|
|||||||
num_trimmed = trim_prompt_cache(cache, 4)
|
num_trimmed = trim_prompt_cache(cache, 4)
|
||||||
self.assertEqual(num_trimmed, 0)
|
self.assertEqual(num_trimmed, 0)
|
||||||
|
|
||||||
|
cache = [QuantizedKVCache() for _ in range(2)]
|
||||||
|
for c in cache:
|
||||||
|
x = mx.random.uniform(shape=(1, 8, 10, 64))
|
||||||
|
c.update_and_fetch(x, x)
|
||||||
|
|
||||||
|
num_trimmed = trim_prompt_cache(cache, 7)
|
||||||
|
self.assertEqual(num_trimmed, 7)
|
||||||
|
|
||||||
|
# Trim more tokens than remain
|
||||||
|
num_trimmed = trim_prompt_cache(cache, 4)
|
||||||
|
self.assertEqual(num_trimmed, 3)
|
||||||
|
|
||||||
def test_trim_cache_with_generate(self):
|
def test_trim_cache_with_generate(self):
|
||||||
model, tokenizer = load(HF_MODEL_PATH)
|
model, tokenizer = load(HF_MODEL_PATH)
|
||||||
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
|
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
|
||||||
@ -238,6 +251,56 @@ class TestPromptCache(unittest.TestCase):
|
|||||||
self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
|
self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
|
||||||
self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
|
self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
|
||||||
|
|
||||||
|
def test_save_load_quantized_cache(self):
|
||||||
|
cache = [QuantizedKVCache(bits=4, group_size=32) for _ in range(4)]
|
||||||
|
for c in cache:
|
||||||
|
x = mx.random.uniform(shape=(1, 8, 10, 32))
|
||||||
|
c.update_and_fetch(x, x)
|
||||||
|
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
|
||||||
|
save_prompt_cache(cache_file, cache)
|
||||||
|
loaded_cache = load_prompt_cache(cache_file)
|
||||||
|
self.assertTrue(loaded_cache[0].bits == cache[0].bits)
|
||||||
|
self.assertTrue(loaded_cache[0].group_size == cache[0].group_size)
|
||||||
|
self.assertTrue(len(cache), len(loaded_cache))
|
||||||
|
for c, lc in zip(cache, loaded_cache):
|
||||||
|
self.assertEqual(c.offset, lc.offset)
|
||||||
|
# Loop over quantized tuple
|
||||||
|
for i in range(3):
|
||||||
|
self.assertTrue(mx.array_equal(c.state[0][i], lc.state[0][i]))
|
||||||
|
self.assertTrue(mx.array_equal(c.state[1][i], lc.state[1][i]))
|
||||||
|
|
||||||
|
# Test with metadata
|
||||||
|
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
|
||||||
|
metadata = {"a": "b", "c": "d"}
|
||||||
|
save_prompt_cache(cache_file, cache, metadata)
|
||||||
|
_, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True)
|
||||||
|
self.assertEqual(metadata, loaded_metadata)
|
||||||
|
|
||||||
|
def test_cache_to_quantized(self):
|
||||||
|
model, tokenizer = load(HF_MODEL_PATH)
|
||||||
|
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
|
||||||
|
results = zip(range(4), generate_step(prompt, model))
|
||||||
|
toks, all_logits = zip(*(r[1] for r in results))
|
||||||
|
|
||||||
|
prompt_cache = make_prompt_cache(model)
|
||||||
|
i = 0
|
||||||
|
for _, (tok, logits) in zip(
|
||||||
|
range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
|
||||||
|
):
|
||||||
|
self.assertEqual(tok, toks[i])
|
||||||
|
self.assertTrue(mx.allclose(logits, all_logits[i]))
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
prompt_cache = [c.to_quantized(bits=8, group_size=32) for c in prompt_cache]
|
||||||
|
|
||||||
|
for _, (tok, logits) in zip(
|
||||||
|
range(1),
|
||||||
|
generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
|
||||||
|
):
|
||||||
|
i += 1
|
||||||
|
self.assertEqual(tok, toks[i])
|
||||||
|
self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -42,6 +42,9 @@ class TestTokenizers(unittest.TestCase):
|
|||||||
text += detokenizer.last_segment
|
text += detokenizer.last_segment
|
||||||
self.assertEqual(text, expected_text)
|
self.assertEqual(text, expected_text)
|
||||||
|
|
||||||
|
tokens = tokenizer.encode("こんにちは!私の名前はAI")
|
||||||
|
check(tokens)
|
||||||
|
|
||||||
tokens = tokenizer.encode("a ,b")
|
tokens = tokenizer.encode("a ,b")
|
||||||
check(tokens)
|
check(tokens)
|
||||||
|
|
||||||
@ -74,6 +77,17 @@ class TestTokenizers(unittest.TestCase):
|
|||||||
tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
|
tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
|
||||||
self.check_tokenizer(tokenizer)
|
self.check_tokenizer(tokenizer)
|
||||||
|
|
||||||
|
def test_special_tokens(self):
|
||||||
|
tokenizer_repo = "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx"
|
||||||
|
tokenizer = self.download_tokenizer(tokenizer_repo)
|
||||||
|
|
||||||
|
detokenizer = tokenizer.detokenizer
|
||||||
|
detokenizer.reset()
|
||||||
|
detokenizer.add_token(tokenizer.eos_token_id)
|
||||||
|
detokenizer.finalize()
|
||||||
|
|
||||||
|
self.assertEqual(detokenizer.last_segment, tokenizer.eos_token)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -25,7 +25,7 @@ pip install mlx-whisper
|
|||||||
|
|
||||||
At its simplest:
|
At its simplest:
|
||||||
|
|
||||||
```
|
```sh
|
||||||
mlx_whisper audio_file.mp3
|
mlx_whisper audio_file.mp3
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -35,6 +35,15 @@ Use `-f` to specify the output format and `--model` to specify the model. There
|
|||||||
are many other supported command line options. To see them all, run
|
are many other supported command line options. To see them all, run
|
||||||
`mlx_whisper -h`.
|
`mlx_whisper -h`.
|
||||||
|
|
||||||
|
You can also pipe the audio content of other programs via stdin:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
some-process | mlx_whisper -
|
||||||
|
```
|
||||||
|
|
||||||
|
The default output file name will be `content.*`. You can specify the name with
|
||||||
|
the `--output-name` flag.
|
||||||
|
|
||||||
#### API
|
#### API
|
||||||
|
|
||||||
Transcribe audio with:
|
Transcribe audio with:
|
||||||
|
@ -181,7 +181,7 @@ def load_torch_weights_and_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if name_or_path.endswith(".pt"):
|
if name_or_path.endswith(".pt"):
|
||||||
checkpoint = torch.load(name_or_path, map_location="cpu")
|
checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False)
|
||||||
weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
|
weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
|
||||||
else:
|
else:
|
||||||
name_or_path = Path(name_or_path)
|
name_or_path = Path(name_or_path)
|
||||||
@ -387,7 +387,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Save weights
|
# Save weights
|
||||||
print("[INFO] Saving")
|
print("[INFO] Saving")
|
||||||
np.savez(str(mlx_path / "weights.npz"), **weights)
|
mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
|
||||||
|
|
||||||
# Save config.json with model_type
|
# Save config.json with model_type
|
||||||
with open(str(mlx_path / "config.json"), "w") as f:
|
with open(str(mlx_path / "config.json"), "w") as f:
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
__version__ = "0.3.0"
|
__version__ = "0.4.1"
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
import os
|
import os
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from subprocess import CalledProcessError, run
|
from subprocess import CalledProcessError, run
|
||||||
from typing import Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
|
|||||||
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
|
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
|
||||||
|
|
||||||
|
|
||||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_stdin=False):
|
||||||
"""
|
"""
|
||||||
Open an audio file and read as mono waveform, resampling as necessary
|
Open an audio file and read as mono waveform, resampling as necessary
|
||||||
|
|
||||||
@ -39,19 +39,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# This launches a subprocess to decode audio while down-mixing
|
# This launches a subprocess to decode audio while down-mixing
|
||||||
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
||||||
|
if from_stdin:
|
||||||
|
cmd = ["ffmpeg", "-i", "pipe:0"]
|
||||||
|
else:
|
||||||
|
cmd = ["ffmpeg", "-nostdin", "-i", file]
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
cmd = [
|
cmd.extend([
|
||||||
"ffmpeg",
|
|
||||||
"-nostdin",
|
|
||||||
"-threads", "0",
|
"-threads", "0",
|
||||||
"-i", file,
|
|
||||||
"-f", "s16le",
|
"-f", "s16le",
|
||||||
"-ac", "1",
|
"-ac", "1",
|
||||||
"-acodec", "pcm_s16le",
|
"-acodec", "pcm_s16le",
|
||||||
"-ar", str(sr),
|
"-ar", str(sr),
|
||||||
"-"
|
"-"
|
||||||
]
|
])
|
||||||
# fmt: on
|
# fmt: on
|
||||||
try:
|
try:
|
||||||
out = run(cmd, capture_output=True, check=True).stdout
|
out = run(cmd, capture_output=True, check=True).stdout
|
||||||
|
@ -2,9 +2,11 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import pathlib
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
from . import audio
|
||||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||||||
from .transcribe import transcribe
|
from .transcribe import transcribe
|
||||||
from .writers import get_writer
|
from .writers import get_writer
|
||||||
@ -27,15 +29,24 @@ def build_parser():
|
|||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"audio", nargs="+", type=str, help="Audio file(s) to transcribe"
|
parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
default="mlx-community/whisper-tiny",
|
default="mlx-community/whisper-tiny",
|
||||||
type=str,
|
type=str,
|
||||||
help="The model directory or hugging face repo",
|
help="The model directory or hugging face repo",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"The name of transcription/translation output files before "
|
||||||
|
"--output-format extensions"
|
||||||
|
),
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-dir",
|
"--output-dir",
|
||||||
"-o",
|
"-o",
|
||||||
@ -200,6 +211,7 @@ def main():
|
|||||||
path_or_hf_repo: str = args.pop("model")
|
path_or_hf_repo: str = args.pop("model")
|
||||||
output_dir: str = args.pop("output_dir")
|
output_dir: str = args.pop("output_dir")
|
||||||
output_format: str = args.pop("output_format")
|
output_format: str = args.pop("output_format")
|
||||||
|
output_name: str = args.pop("output_name")
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
writer = get_writer(output_format, output_dir)
|
writer = get_writer(output_format, output_dir)
|
||||||
@ -219,17 +231,25 @@ def main():
|
|||||||
warnings.warn("--max-line-count has no effect without --max-line-width")
|
warnings.warn("--max-line-count has no effect without --max-line-width")
|
||||||
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
|
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
|
||||||
warnings.warn("--max-words-per-line has no effect with --max-line-width")
|
warnings.warn("--max-words-per-line has no effect with --max-line-width")
|
||||||
for audio_path in args.pop("audio"):
|
|
||||||
|
for audio_obj in args.pop("audio"):
|
||||||
|
if audio_obj == "-":
|
||||||
|
# receive the contents from stdin rather than read a file
|
||||||
|
audio_obj = audio.load_audio(from_stdin=True)
|
||||||
|
|
||||||
|
output_name = output_name or "content"
|
||||||
|
else:
|
||||||
|
output_name = output_name or pathlib.Path(audio_obj).stem
|
||||||
try:
|
try:
|
||||||
result = transcribe(
|
result = transcribe(
|
||||||
audio_path,
|
audio_obj,
|
||||||
path_or_hf_repo=path_or_hf_repo,
|
path_or_hf_repo=path_or_hf_repo,
|
||||||
**args,
|
**args,
|
||||||
)
|
)
|
||||||
writer(result, audio_path, **writer_args)
|
writer(result, output_name, **writer_args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
|
print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -58,11 +58,12 @@ def detect_language(
|
|||||||
logits = model.logits(x, mel)[:, 0]
|
logits = model.logits(x, mel)[:, 0]
|
||||||
|
|
||||||
# collect detected languages; suppress all non-language tokens
|
# collect detected languages; suppress all non-language tokens
|
||||||
mask = np.full(logits.shape[-1], -np.inf, dtype=np.float32)
|
mask = mx.full(logits.shape[-1], -mx.inf, dtype=mx.float32)
|
||||||
mask[list(tokenizer.all_language_tokens)] = 0.0
|
mask[list(tokenizer.all_language_tokens)] = 0.0
|
||||||
logits += mx.array(mask)
|
logits += mask
|
||||||
language_tokens = mx.argmax(logits, axis=-1)
|
language_tokens = mx.argmax(logits, axis=-1)
|
||||||
language_token_probs = mx.softmax(logits, axis=-1)
|
language_token_probs = mx.softmax(logits, axis=-1)
|
||||||
|
language_token_probs = np.array(language_token_probs)
|
||||||
language_probs = [
|
language_probs = [
|
||||||
{
|
{
|
||||||
c: language_token_probs[i, j].item()
|
c: language_token_probs[i, j].item()
|
||||||
@ -129,17 +130,12 @@ class DecodingResult:
|
|||||||
|
|
||||||
|
|
||||||
class Inference:
|
class Inference:
|
||||||
def __init__(self, model: "Whisper", initial_token_length: int):
|
def __init__(self, model: "Whisper"):
|
||||||
self.model: "Whisper" = model
|
self.model: "Whisper" = model
|
||||||
self.initial_token_length = initial_token_length
|
|
||||||
self.kv_cache = None
|
self.kv_cache = None
|
||||||
|
|
||||||
def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array:
|
def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array:
|
||||||
"""Perform a forward pass on the decoder and return per-token logits"""
|
"""Perform a forward pass on the decoder and return per-token logits"""
|
||||||
if tokens.shape[-1] > self.initial_token_length:
|
|
||||||
# only need to use the last token except in the first forward pass
|
|
||||||
tokens = tokens[:, -1:]
|
|
||||||
|
|
||||||
logits, self.kv_cache, _ = self.model.decoder(
|
logits, self.kv_cache, _ = self.model.decoder(
|
||||||
tokens, audio_features, kv_cache=self.kv_cache
|
tokens, audio_features, kv_cache=self.kv_cache
|
||||||
)
|
)
|
||||||
@ -251,6 +247,11 @@ class TokenDecoder:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def categorical(logits, temp):
|
||||||
|
return mx.random.categorical(logits / temp)
|
||||||
|
|
||||||
|
|
||||||
class GreedyDecoder(TokenDecoder):
|
class GreedyDecoder(TokenDecoder):
|
||||||
def __init__(self, temperature: float, eot: int):
|
def __init__(self, temperature: float, eot: int):
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
@ -262,10 +263,8 @@ class GreedyDecoder(TokenDecoder):
|
|||||||
if self.temperature == 0:
|
if self.temperature == 0:
|
||||||
next_tokens = logits.argmax(axis=-1)
|
next_tokens = logits.argmax(axis=-1)
|
||||||
else:
|
else:
|
||||||
next_tokens = mx.random.categorical(logits=logits / self.temperature)
|
next_tokens = categorical(logits, self.temperature)
|
||||||
|
|
||||||
next_tokens = mx.argmax(logits, axis=-1)
|
|
||||||
logits = logits.astype(mx.float32)
|
|
||||||
logprobs = logits - mx.logsumexp(logits, axis=-1)
|
logprobs = logits - mx.logsumexp(logits, axis=-1)
|
||||||
|
|
||||||
current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
|
current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
|
||||||
@ -281,7 +280,7 @@ class GreedyDecoder(TokenDecoder):
|
|||||||
def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
|
def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
|
||||||
# make sure each sequence has at least one EOT token at the end
|
# make sure each sequence has at least one EOT token at the end
|
||||||
tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot)
|
tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot)
|
||||||
return tokens, sum_logprobs.tolist()
|
return tokens, sum_logprobs
|
||||||
|
|
||||||
|
|
||||||
class LogitFilter:
|
class LogitFilter:
|
||||||
@ -340,10 +339,10 @@ class ApplyTimestampRules(LogitFilter):
|
|||||||
if self.tokenizer.no_timestamps is not None:
|
if self.tokenizer.no_timestamps is not None:
|
||||||
mask[:, self.tokenizer.no_timestamps] = -np.inf
|
mask[:, self.tokenizer.no_timestamps] = -np.inf
|
||||||
|
|
||||||
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
## timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
||||||
for k in range(tokens.shape[0]):
|
tokens = tokens.tolist()
|
||||||
sampled_tokens = tokens[k, self.sample_begin :]
|
for k in range(len(tokens)):
|
||||||
seq = sampled_tokens.tolist()
|
seq = tokens[k][self.sample_begin :]
|
||||||
last_was_timestamp = (
|
last_was_timestamp = (
|
||||||
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
||||||
)
|
)
|
||||||
@ -368,7 +367,7 @@ class ApplyTimestampRules(LogitFilter):
|
|||||||
last_timestamp += 1
|
last_timestamp += 1
|
||||||
mask[k, self.tokenizer.timestamp_begin : last_timestamp] = -np.inf
|
mask[k, self.tokenizer.timestamp_begin : last_timestamp] = -np.inf
|
||||||
|
|
||||||
if tokens.shape[1] == self.sample_begin:
|
if len(tokens[0]) == self.sample_begin:
|
||||||
# suppress generating non-timestamp tokens at the beginning
|
# suppress generating non-timestamp tokens at the beginning
|
||||||
mask[:, : self.tokenizer.timestamp_begin] = -np.inf
|
mask[:, : self.tokenizer.timestamp_begin] = -np.inf
|
||||||
|
|
||||||
@ -380,16 +379,20 @@ class ApplyTimestampRules(LogitFilter):
|
|||||||
mask[:, last_allowed + 1 :] = -np.inf
|
mask[:, last_allowed + 1 :] = -np.inf
|
||||||
|
|
||||||
# if sum of probability over timestamps is above any other token, sample timestamp
|
# if sum of probability over timestamps is above any other token, sample timestamp
|
||||||
|
mask = mx.array(mask)
|
||||||
logprobs = logits - mx.logsumexp(logits, axis=-1)
|
logprobs = logits - mx.logsumexp(logits, axis=-1)
|
||||||
for k in range(tokens.shape[0]):
|
timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp(
|
||||||
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
|
axis=-1, keepdims=True
|
||||||
axis=-1
|
)
|
||||||
)
|
max_text_token_logprob = logprobs[:, : self.tokenizer.timestamp_begin].max(
|
||||||
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
|
axis=-1, keepdims=True
|
||||||
if timestamp_logprob > max_text_token_logprob:
|
)
|
||||||
mask[k, : self.tokenizer.timestamp_begin] = -np.inf
|
mask[:, : self.tokenizer.timestamp_begin] = mx.where(
|
||||||
|
timestamp_logprob > max_text_token_logprob,
|
||||||
return logits + mx.array(mask, logits.dtype)
|
-mx.inf,
|
||||||
|
mask[:, : self.tokenizer.timestamp_begin],
|
||||||
|
)
|
||||||
|
return logits + mask
|
||||||
|
|
||||||
|
|
||||||
class DecodingTask:
|
class DecodingTask:
|
||||||
@ -424,7 +427,7 @@ class DecodingTask:
|
|||||||
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
||||||
|
|
||||||
# inference: implements the forward pass through the decoder, including kv caching
|
# inference: implements the forward pass through the decoder, including kv caching
|
||||||
self.inference = Inference(model, len(self.initial_tokens))
|
self.inference = Inference(model)
|
||||||
|
|
||||||
# sequence ranker: implements how to rank a group of sampled sequences
|
# sequence ranker: implements how to rank a group of sampled sequences
|
||||||
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
||||||
@ -432,9 +435,6 @@ class DecodingTask:
|
|||||||
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
||||||
if options.beam_size is not None:
|
if options.beam_size is not None:
|
||||||
raise NotImplementedError("Beam search decoder is not yet implemented")
|
raise NotImplementedError("Beam search decoder is not yet implemented")
|
||||||
# self.decoder = BeamSearchDecoder(
|
|
||||||
# options.beam_size, tokenizer.eot, self.inference, options.patience
|
|
||||||
# )
|
|
||||||
else:
|
else:
|
||||||
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
||||||
|
|
||||||
@ -448,6 +448,7 @@ class DecodingTask:
|
|||||||
self.logit_filters.append(
|
self.logit_filters.append(
|
||||||
SuppressTokens(self._get_suppress_tokens(), model.dims.n_vocab)
|
SuppressTokens(self._get_suppress_tokens(), model.dims.n_vocab)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not options.without_timestamps:
|
if not options.without_timestamps:
|
||||||
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
||||||
max_initial_timestamp_index = None
|
max_initial_timestamp_index = None
|
||||||
@ -570,48 +571,59 @@ class DecodingTask:
|
|||||||
|
|
||||||
def _main_loop(self, audio_features: mx.array, tokens: mx.array):
|
def _main_loop(self, audio_features: mx.array, tokens: mx.array):
|
||||||
n_batch = tokens.shape[0]
|
n_batch = tokens.shape[0]
|
||||||
sum_logprobs: mx.array = mx.zeros(n_batch)
|
sum_logprobs = mx.zeros(n_batch)
|
||||||
no_speech_probs = [np.nan] * n_batch
|
|
||||||
|
|
||||||
try:
|
def _step(inputs, audio_features, tokens, sum_logprobs):
|
||||||
for i in range(self.sample_len):
|
pre_logits = self.inference.logits(inputs, audio_features)
|
||||||
logits = self.inference.logits(tokens, audio_features)
|
|
||||||
|
|
||||||
if (
|
# consider the logits at the last token only
|
||||||
i == 0 and self.tokenizer.no_speech is not None
|
logits = pre_logits[:, -1]
|
||||||
): # save no_speech_probs
|
|
||||||
probs_at_sot = mx.softmax(
|
|
||||||
logits[:, self.sot_index].astype(mx.float32), axis=-1
|
|
||||||
)
|
|
||||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
|
||||||
|
|
||||||
# now we need to consider the logits at the last token only
|
# apply the logit filters, e.g. for suppressing or applying penalty to
|
||||||
logits = logits[:, -1]
|
for logit_filter in self.logit_filters:
|
||||||
|
logits = logit_filter.apply(logits, tokens)
|
||||||
|
|
||||||
# apply the logit filters, e.g. for suppressing or applying penalty to
|
# expand the tokens tensor with the selected next tokens
|
||||||
for logit_filter in self.logit_filters:
|
tokens, completed, sum_logprobs = self.decoder.update(
|
||||||
logits = logit_filter.apply(logits, tokens)
|
tokens, logits, sum_logprobs
|
||||||
|
)
|
||||||
|
return tokens, completed, sum_logprobs, pre_logits
|
||||||
|
|
||||||
# expand the tokens tensor with the selected next tokens
|
tokens, completed, sum_logprobs, pre_logits = _step(
|
||||||
tokens, completed, sum_logprobs = self.decoder.update(
|
tokens, audio_features, tokens, sum_logprobs
|
||||||
tokens, logits, sum_logprobs
|
)
|
||||||
)
|
if self.tokenizer.no_speech is not None: # compute no_speech_probs
|
||||||
|
probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
|
||||||
|
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
|
||||||
|
else:
|
||||||
|
no_speech_probs = mx.full(n_batch, mx.nan)
|
||||||
|
mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)
|
||||||
|
|
||||||
if completed or tokens.shape[-1] > self.n_ctx:
|
for i in range(1, self.sample_len):
|
||||||
break
|
inputs = tokens[:, -1:]
|
||||||
finally:
|
if tokens.shape[-1] > self.n_ctx:
|
||||||
self.inference.reset()
|
break
|
||||||
|
next_tokens, next_completed, next_sum_logprobs, _ = _step(
|
||||||
|
inputs, audio_features, tokens, sum_logprobs
|
||||||
|
)
|
||||||
|
mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
|
||||||
|
if completed:
|
||||||
|
break
|
||||||
|
tokens = next_tokens
|
||||||
|
completed = next_completed
|
||||||
|
sum_logprobs = next_sum_logprobs
|
||||||
|
|
||||||
return tokens, sum_logprobs, no_speech_probs
|
return tokens, sum_logprobs, no_speech_probs
|
||||||
|
|
||||||
def run(self, mel: mx.array) -> List[DecodingResult]:
|
def run(self, mel: mx.array) -> List[DecodingResult]:
|
||||||
|
self.inference.reset()
|
||||||
self.decoder.reset()
|
self.decoder.reset()
|
||||||
tokenizer: Tokenizer = self.tokenizer
|
tokenizer: Tokenizer = self.tokenizer
|
||||||
n_audio: int = mel.shape[0]
|
n_audio: int = mel.shape[0]
|
||||||
|
|
||||||
audio_features: mx.array = self._get_audio_features(mel) # encoder forward pass
|
audio_features: mx.array = self._get_audio_features(mel) # encoder forward pass
|
||||||
tokens: np.array = np.array(self.initial_tokens)
|
tokens: mx.array = mx.array(self.initial_tokens)
|
||||||
tokens = np.broadcast_to(tokens, (n_audio, len(self.initial_tokens))).copy()
|
tokens = mx.broadcast_to(tokens, (n_audio, len(self.initial_tokens)))
|
||||||
|
|
||||||
# detect language if requested, overwriting the language token
|
# detect language if requested, overwriting the language token
|
||||||
languages, language_probs = self._detect_language(audio_features, tokens)
|
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||||
@ -626,7 +638,6 @@ class DecodingTask:
|
|||||||
]
|
]
|
||||||
|
|
||||||
# repeat tokens by the group size, for beam search or best-of-n sampling
|
# repeat tokens by the group size, for beam search or best-of-n sampling
|
||||||
tokens = mx.array(tokens)
|
|
||||||
if self.n_group > 1:
|
if self.n_group > 1:
|
||||||
tokens = tokens[:, None, :]
|
tokens = tokens[:, None, :]
|
||||||
tokens = mx.broadcast_to(
|
tokens = mx.broadcast_to(
|
||||||
@ -649,7 +660,13 @@ class DecodingTask:
|
|||||||
|
|
||||||
# get the final candidates for each group, and slice between the first sampled token and EOT
|
# get the final candidates for each group, and slice between the first sampled token and EOT
|
||||||
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
||||||
tokens = tokens[..., self.sample_begin :].tolist()
|
tokens = tokens[..., self.sample_begin :]
|
||||||
|
|
||||||
|
# eval and convert to list
|
||||||
|
mx.eval(tokens, sum_logprobs, no_speech_probs)
|
||||||
|
tokens = tokens.tolist()
|
||||||
|
sum_logprobs = sum_logprobs.tolist()
|
||||||
|
no_speech_probs = no_speech_probs.tolist()
|
||||||
tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens]
|
tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens]
|
||||||
|
|
||||||
# select the top-ranked sample in each group
|
# select the top-ranked sample in each group
|
||||||
|
@ -26,7 +26,10 @@ def load_model(
|
|||||||
|
|
||||||
model_args = whisper.ModelDimensions(**config)
|
model_args = whisper.ModelDimensions(**config)
|
||||||
|
|
||||||
weights = mx.load(str(model_path / "weights.npz"))
|
wf = model_path / "weights.safetensors"
|
||||||
|
if not wf.exists():
|
||||||
|
wf = model_path / "weights.npz"
|
||||||
|
weights = mx.load(str(wf))
|
||||||
|
|
||||||
model = whisper.Whisper(model_args, dtype)
|
model = whisper.Whisper(model_args, dtype)
|
||||||
|
|
||||||
|
@ -293,6 +293,7 @@ def transcribe(
|
|||||||
|
|
||||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||||
|
|
||||||
tokens = np.array(result.tokens)
|
tokens = np.array(result.tokens)
|
||||||
|
|
||||||
if no_speech_threshold is not None:
|
if no_speech_threshold is not None:
|
||||||
|
@ -80,12 +80,11 @@ class MultiHeadAttention(nn.Module):
|
|||||||
qk = q @ k
|
qk = q @ k
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
qk = qk + mask[:n_ctx, :n_ctx]
|
qk = qk + mask[:n_ctx, :n_ctx]
|
||||||
qk = qk.astype(mx.float32)
|
|
||||||
|
|
||||||
w = mx.softmax(qk, axis=-1).astype(q.dtype)
|
w = mx.softmax(qk, axis=-1, precise=True)
|
||||||
out = (w @ v).transpose(0, 2, 1, 3)
|
out = (w @ v).transpose(0, 2, 1, 3)
|
||||||
out = out.reshape(n_batch, n_ctx, n_state)
|
out = out.reshape(n_batch, n_ctx, n_state)
|
||||||
return out, qk
|
return out, qk.astype(mx.float32)
|
||||||
|
|
||||||
|
|
||||||
class ResidualAttentionBlock(nn.Module):
|
class ResidualAttentionBlock(nn.Module):
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import pathlib
|
||||||
import re
|
import re
|
||||||
import sys
|
|
||||||
import zlib
|
|
||||||
from typing import Callable, List, Optional, TextIO
|
from typing import Callable, List, Optional, TextIO
|
||||||
|
|
||||||
|
|
||||||
@ -43,15 +41,13 @@ class ResultWriter:
|
|||||||
self.output_dir = output_dir
|
self.output_dir = output_dir
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs
|
||||||
):
|
):
|
||||||
audio_basename = os.path.basename(audio_path)
|
output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix(
|
||||||
audio_basename = os.path.splitext(audio_basename)[0]
|
f".{self.extension}"
|
||||||
output_path = os.path.join(
|
|
||||||
self.output_dir, audio_basename + "." + self.extension
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
with output_path.open("wt", encoding="utf-8") as f:
|
||||||
self.write_result(result, file=f, options=options, **kwargs)
|
self.write_result(result, file=f, options=options, **kwargs)
|
||||||
|
|
||||||
def write_result(
|
def write_result(
|
||||||
|
Loading…
Reference in New Issue
Block a user