Merge branch 'ml-explore:main' into adding-support-for-mamba2

Gökdeniz Gülmez 2024-11-10 16:36:02 +01:00 committed by GitHub
commit 49d3f188f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
58 changed files with 1092 additions and 536 deletions

View File

@@ -188,7 +188,7 @@ The adapters are saved in `mlx_output` and can be used directly by the
 ```shell
 python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
-  --adapter mlx_output/0001200_adapters.safetensors \
+  --adapter mlx_output/final_adapters.safetensors \
   --fuse-adapter \
   --no-t5-padding \
   'A photo of an sks dog lying on the sand at a beach in Greece'

View File

@@ -13,7 +13,7 @@ from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten, tree_map, tree_reduce
 from PIL import Image

-from flux import FluxPipeline, Trainer, load_dataset
+from flux import FluxPipeline, Trainer, load_dataset, save_config


 def generate_progress_images(iteration, flux, args):
@@ -43,10 +43,10 @@ def generate_progress_images(iteration, flux, args):
     im.save(out_file)


-def save_adapters(iteration, flux, args):
+def save_adapters(adapter_name, flux, args):
     out_dir = Path(args.output_dir)
     out_dir.mkdir(parents=True, exist_ok=True)
-    out_file = out_dir / f"{iteration:07d}_adapters.safetensors"
+    out_file = out_dir / adapter_name
     print(f"Saving {str(out_file)}")

     mx.save_safetensors(
@@ -157,6 +157,10 @@ if __name__ == "__main__":
     parser = setup_arg_parser()
     args = parser.parse_args()

+    output_path = Path(args.output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+    save_config(vars(args), output_path / "adapter_config.json")
+
     # Load the model and set it up for LoRA training. We use the same random
     # state when creating the LoRA layers so all workers will have the same
     # initial weights.
@@ -278,8 +282,11 @@ if __name__ == "__main__":
             generate_progress_images(i + 1, flux, args)

         if (i + 1) % args.checkpoint_every == 0:
-            save_adapters(i + 1, flux, args)
+            save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)

         if (i + 1) % 10 == 0:
             losses = []
             tic = time.time()
+
+    save_adapters("final_adapters.safetensors", flux, args)
+    print(f"Training successful. Saved final weights to {args.adapter_file}.")

View File

@@ -12,4 +12,5 @@ from .utils import (
     load_flow_model,
     load_t5,
     load_t5_tokenizer,
+    save_config,
 )

View File

@@ -3,7 +3,8 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Optional
+from pathlib import Path
+from typing import Optional, Union

 import mlx.core as mx
 from huggingface_hub import hf_hub_download
@@ -207,3 +208,23 @@ def load_clip_tokenizer(name: str):
 def load_t5_tokenizer(name: str, pad: bool = True):
     model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
     return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
+
+
+def save_config(
+    config: dict,
+    config_path: Union[str, Path],
+) -> None:
+    """Save the model configuration to the ``config_path``.
+
+    The final configuration will be sorted before saving for better readability.
+
+    Args:
+        config (dict): The model configuration.
+        config_path (Union[str, Path]): Model configuration file path.
+    """
+    # Sort the config for better readability
+    config = dict(sorted(config.items()))
+
+    # Write the config to the provided file
+    with open(config_path, "w") as fid:
+        json.dump(config, fid, indent=4)
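
For reference, a minimal usage sketch of the new helper; the values in `config` are made up, and only the call pattern mirrors what the training script above does with `vars(args)`:

```python
import json
from pathlib import Path

from flux import save_config

out_dir = Path("mlx_output")
out_dir.mkdir(parents=True, exist_ok=True)

# Hypothetical training arguments; in dreambooth.py this is vars(args).
config = {"model": "dev", "lora_rank": 8, "iterations": 600}
save_config(config, out_dir / "adapter_config.json")

# The file is plain JSON with keys sorted alphabetically.
print(json.loads((out_dir / "adapter_config.json").read_text()))
```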

View File

@@ -101,7 +101,8 @@ To see a description of all the arguments you can do:
 #### Streaming

 For streaming generation, use the `stream_generate` function. This returns a
-generator object which streams the output text. For example,
+generator object which streams the output text, token, and log probabilities.
+For example,

 ```python
 from mlx_lm import load, stream_generate
@@ -116,7 +117,7 @@ prompt = tokenizer.apply_chat_template(
     messages, tokenize=False, add_generation_prompt=True
 )

-for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
+for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
     print(t, end="", flush=True)
 print()
 ```
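
Since `stream_generate` now yields the token and log probabilities alongside the text segment, callers that need them can unpack all three values. A minimal sketch reusing the `model`, `tokenizer`, and `prompt` from the example above:

```python
for text, token, logprobs in stream_generate(model, tokenizer, prompt, max_tokens=512):
    # `text` is the newly decoded segment, `token` the generated token id, and
    # `logprobs` the log probabilities for this step.
    print(text, end="", flush=True)
print()
```
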
@@ -221,6 +222,7 @@ Here are a few examples of Hugging Face models that work with this example:
 - [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
 - [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
 - [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
+- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)

 Most
 [Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
@@ -248,3 +250,28 @@ model, tokenizer = load(
     tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
 )
 ```
+
+### Large Models
+
+> [!NOTE]
+> This requires macOS 15.0 or higher to work.
+
+Models which are large relative to the total RAM available on the machine can
+be slow. `mlx-lm` will attempt to make them faster by wiring the memory
+occupied by the model and cache. This requires macOS 15 or higher to
+work.
+
+If you see the following warning message:
+
+> [WARNING] Generating with a model that requires ...
+
+then the model will likely be slow on the given machine. If the model fits in
+RAM then it can often be sped up by increasing the system wired memory limit.
+To increase the limit, set the following `sysctl`:
+
+```bash
+sudo sysctl iogpu.wired_limit_mb=N
+```
+
+The value `N` should be larger than the size of the model in megabytes but
+smaller than the memory size of the machine.
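
As a rough way to pick `N`, the model's size can be compared against the machine's memory from Python. This is only a sketch, not part of `mlx-lm`, and it assumes `mx.metal.device_info()` reports a `memory_size` entry:

```python
import mlx.core as mx
from mlx.utils import tree_flatten
from mlx_lm import load

model, _ = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# Bytes occupied by the model parameters, converted to megabytes.
model_mb = sum(p.nbytes for _, p in tree_flatten(model.parameters())) / 2**20
machine_mb = mx.metal.device_info()["memory_size"] / 2**20

print(f"model ~{model_mb:.0f} MB, machine ~{machine_mb:.0f} MB")
print(f"choose N between {int(model_mb)} and {int(machine_mb)}")
```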

View File

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.19.1"
+__version__ = "0.19.3"

View File

@@ -8,7 +8,9 @@ import time
 import mlx.core as mx

 from .models.cache import make_prompt_cache, save_prompt_cache
-from .utils import load
+from .utils import load, maybe_quantize_kv_cache
+
+DEFAULT_QUANTIZED_KV_START = 5000


 def setup_arg_parser():
@@ -70,6 +72,26 @@ def setup_arg_parser():
         required=True,
         help="Message to be processed by the model ('-' reads from stdin)",
     )
+    parser.add_argument(
+        "--kv-bits",
+        type=int,
+        help="Number of bits for KV cache quantization. "
+        "Defaults to no quantization.",
+        default=None,
+    )
+    parser.add_argument(
+        "--kv-group-size",
+        type=int,
+        help="Group size for KV cache quantization.",
+        default=64,
+    )
+    parser.add_argument(
+        "--quantized-kv-start",
+        help="When --kv-bits is set, start quantizing the KV cache "
+        "from this step onwards.",
+        type=int,
+        default=DEFAULT_QUANTIZED_KV_START,
+    )
     return parser
@@ -127,8 +149,10 @@ def main():
     start = time.time()
     max_msg_len = 0
     while y.size > 0:
         model(y[:step_size][None], cache=cache)
         mx.eval([c.state for c in cache])
+        mx.metal.clear_cache()
         processed += min(y.size, step_size)
         y = y[step_size:]
         current = time.time()
@@ -136,15 +160,19 @@ def main():
         msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
         max_msg_len = max(max_msg_len, len(msg))
         print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
+        maybe_quantize_kv_cache(
+            cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
+        )

     print()
-    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
+    print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")

     print("Saving...")
     metadata = {}
     metadata["model"] = args.model
     metadata["chat_template"] = tokenizer.chat_template
     metadata["tokenizer_config"] = json.dumps(tokenizer_config)
-    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
     save_prompt_cache(args.prompt_cache_file, cache, metadata)
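
`maybe_quantize_kv_cache` itself is not shown in this diff. Based on the new flags and the `to_quantized` method added to `KVCache` further down, its behaviour is roughly the following sketch; the argument order matches the call above, everything else is an assumption:

```python
from mlx_lm.models.cache import QuantizedKVCache


def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
    # Sketch: once the cache has grown past `quantized_kv_start` tokens, convert
    # each regular cache entry to a QuantizedKVCache; a no-op if --kv-bits is
    # unset or the cache is already quantized.
    if (
        kv_bits is not None
        and not isinstance(prompt_cache[0], QuantizedKVCache)
        and prompt_cache[0].offset > quantized_kv_start
    ):
        for i in range(len(prompt_cache)):
            prompt_cache[i] = prompt_cache[i].to_quantized(
                group_size=kv_group_size, bits=kv_bits
            )
```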

View File

@@ -11,6 +11,7 @@ from .utils import load, stream_generate
 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
+DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
@@ -41,6 +42,13 @@ def setup_arg_parser():
         help="Set the maximum key-value cache size",
         default=None,
     )
+    parser.add_argument(
+        "--max-tokens",
+        "-m",
+        type=int,
+        default=DEFAULT_MAX_TOKENS,
+        help="Maximum number of tokens to generate",
+    )
     return parser
@@ -56,7 +64,7 @@ def main():
         tokenizer_config={"trust_remote_code": True},
     )

-    print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.")
+    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
         query = input(">> ")
@@ -66,10 +74,11 @@ def main():
         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
-        for response in stream_generate(
+        for response, *_ in stream_generate(
             model,
             tokenizer,
             prompt,
+            args.max_tokens,
             temp=args.temp,
             top_p=args.top_p,
             prompt_cache=prompt_cache,
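
The same pattern is available from Python. A minimal sketch of a two-turn chat that reuses one prompt cache across turns; the model name and queries are illustrative:

```python
from mlx_lm import load, stream_generate
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
prompt_cache = make_prompt_cache(model)

# Reusing the same prompt_cache keeps the KV state of earlier turns.
for query in ["Hello!", "Summarize what I just said."]:
    messages = [{"role": "user", "content": query}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    for response, *_ in stream_generate(
        model, tokenizer, prompt, 256, prompt_cache=prompt_cache
    ):
        print(response, end="", flush=True)
    print()
```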

View File

@@ -6,15 +6,18 @@ import sys
 import mlx.core as mx

-from .models.cache import load_prompt_cache
+from .models.cache import QuantizedKVCache, load_prompt_cache
 from .utils import generate, load

 DEFAULT_PROMPT = "hello"
 DEFAULT_MAX_TOKENS = 100
 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
+DEFAULT_MIN_P = 0.0
+DEFAULT_MIN_TOKENS_TO_KEEP = 1
 DEFAULT_SEED = 0
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
+DEFAULT_QUANTIZED_KV_START = 5000


 def str2bool(string):
@@ -51,6 +54,7 @@ def setup_arg_parser():
     )
     parser.add_argument(
         "--prompt",
+        "-p",
         default=DEFAULT_PROMPT,
         help="Message to be processed by the model ('-' reads from stdin)",
     )
@@ -67,6 +71,15 @@ def setup_arg_parser():
     parser.add_argument(
         "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
     )
+    parser.add_argument(
+        "--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
+    )
+    parser.add_argument(
+        "--min-tokens-to-keep",
+        type=float,
+        default=DEFAULT_MIN_TOKENS_TO_KEEP,
+        help="Minimum tokens to keep for min-p sampling.",
+    )
     parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
     parser.add_argument(
         "--ignore-chat-template",
@@ -89,12 +102,6 @@ def setup_arg_parser():
         action="store_true",
         help="Colorize output based on T[0] probability",
     )
-    parser.add_argument(
-        "--cache-limit-gb",
-        type=int,
-        default=None,
-        help="Set the MLX cache limit in GB",
-    )
     parser.add_argument(
         "--max-kv-size",
         type=int,
@@ -107,6 +114,26 @@ def setup_arg_parser():
         default=None,
         help="A file containing saved KV caches to avoid recomputing them",
     )
+    parser.add_argument(
+        "--kv-bits",
+        type=int,
+        help="Number of bits for KV cache quantization. "
+        "Defaults to no quantization.",
+        default=None,
+    )
+    parser.add_argument(
+        "--kv-group-size",
+        type=int,
+        help="Group size for KV cache quantization.",
+        default=64,
+    )
+    parser.add_argument(
+        "--quantized-kv-start",
+        help="When --kv-bits is set, start quantizing the KV cache "
+        "from this step onwards.",
+        type=int,
+        default=DEFAULT_QUANTIZED_KV_START,
+    )
     return parser
@@ -143,14 +170,21 @@ def main():
     mx.random.seed(args.seed)

-    if args.cache_limit_gb is not None:
-        mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
-
     # Load the prompt cache and metadata if a cache file is provided
     using_cache = args.prompt_cache_file is not None
     if using_cache:
         prompt_cache, metadata = load_prompt_cache(
-            args.prompt_cache_file, return_metadata=True
-        )
+            args.prompt_cache_file,
+            return_metadata=True,
+        )
+        if isinstance(prompt_cache[0], QuantizedKVCache):
+            if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
+                raise ValueError(
+                    "--kv-bits does not match the kv cache loaded from --prompt-cache-file."
+                )
+            if args.kv_group_size != prompt_cache[0].group_size:
+                raise ValueError(
+                    "--kv-group-size does not match the kv cache loaded from --prompt-cache-file."
+                )

     # Building tokenizer_config
@@ -225,8 +259,13 @@ def main():
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
+        min_p=args.min_p,
+        min_tokens_to_keep=args.min_tokens_to_keep,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
+        kv_bits=args.kv_bits,
+        kv_group_size=args.kv_group_size,
+        quantized_kv_start=args.quantized_kv_start,
     )
     if not args.verbose:
         print(response)
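
The Python `generate` API accepts the same knobs. A sketch assuming the keyword names mirror the CLI flags above (the model name and all values are illustrative):

```python
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about autumn.",
    max_tokens=256,
    temp=0.7,
    min_p=0.05,
    min_tokens_to_keep=1,
    kv_bits=8,                # quantize the KV cache to 8 bits
    kv_group_size=64,
    quantized_kv_start=5000,  # only once the cache exceeds 5000 tokens
    verbose=True,
)
```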

View File

@@ -5,6 +5,9 @@ from dataclasses import dataclass
 from typing import Any, Optional

 import mlx.core as mx
+from mlx.utils import tree_map
+
+from .cache import QuantizedKVCache


 @dataclass
@@ -39,7 +42,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
     if cache is not None and cache[0] is not None:
         c = cache[0]
         if hasattr(c, "max_size"):
-            offset = min(c.max_size - 1, c.offset)
+            offset = min(c.max_size, c.offset)
             window_size = c.max_size
         else:
             offset = c.offset
@@ -48,3 +51,63 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
     else:
         mask = None
     return mask
+
+
+def quantized_scaled_dot_product_attention(
+    queries: mx.array,
+    q_keys: tuple[mx.array, mx.array, mx.array],
+    q_values: tuple[mx.array, mx.array, mx.array],
+    scale: float,
+    mask: Optional[mx.array],
+    group_size: int = 64,
+    bits: int = 8,
+) -> mx.array:
+    B, n_q_heads, L, D = queries.shape
+    n_kv_heads = q_keys[0].shape[-3]
+    n_repeats = n_q_heads // n_kv_heads
+
+    queries *= scale
+
+    if n_repeats > 1:
+        queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
+        q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
+        q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
+
+    scores = mx.quantized_matmul(
+        queries, *q_keys, transpose=True, group_size=group_size, bits=bits
+    )
+    if mask is not None:
+        scores += mask
+    scores = mx.softmax(scores, axis=-1, precise=True)
+    out = mx.quantized_matmul(
+        scores, *q_values, transpose=False, group_size=group_size, bits=bits
+    )
+
+    if n_repeats > 1:
+        out = mx.reshape(out, (B, n_q_heads, L, D))
+
+    return out
+
+
+def scaled_dot_product_attention(
+    queries,
+    keys,
+    values,
+    cache,
+    scale: float,
+    mask: Optional[mx.array],
+) -> mx.array:
+    if isinstance(cache, QuantizedKVCache):
+        return quantized_scaled_dot_product_attention(
+            queries,
+            keys,
+            values,
+            scale=scale,
+            mask=mask,
+            group_size=cache.group_size,
+            bits=cache.bits,
+        )
+    else:
+        return mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=scale, mask=mask
+        )
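
For intuition, a self-contained sketch of the `mx.quantize` / `mx.quantized_matmul` round trip that the quantized attention path relies on; shapes and settings are illustrative:

```python
import mlx.core as mx

x = mx.random.normal((1, 8, 128))   # e.g. queries for a few positions
w = mx.random.normal((256, 128))    # e.g. cached keys

# Quantize the "cached" tensor to 8 bits with groups of 64 elements.
w_q, scales, biases = mx.quantize(w, group_size=64, bits=8)

# Matmul directly against the packed representation; transpose=True computes x @ w.T.
scores = mx.quantized_matmul(
    x, w_q, scales, biases, transpose=True, group_size=64, bits=8
)

# The result closely tracks the full-precision product.
print(mx.abs(scores - x @ w.T).max())
```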

View File

@@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional
 import mlx.core as mx
 import mlx.nn as nn
-from mlx.utils import tree_flatten, tree_unflatten
+from mlx.utils import tree_flatten, tree_map, tree_unflatten


-def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
+def make_prompt_cache(
+    model: nn.Module,
+    max_kv_size: Optional[int] = None,
+) -> List[Any]:
     """
     Construct the model's cache for use when generating.
@@ -126,6 +129,88 @@ class _BaseCache:
         return False


+class QuantizedKVCache(_BaseCache):
+    def __init__(self, group_size: int = 64, bits: int = 8):
+        self.keys = None
+        self.values = None
+        self.offset = 0
+        self.step = 256
+        self.group_size = group_size
+        self.bits = bits
+
+    def update_and_fetch(self, keys, values):
+        B, n_kv_heads, num_steps, k_head_dim = keys.shape
+        v_head_dim = values.shape[-1]
+        prev = self.offset
+
+        if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
+            el_per_int = 8 * mx.uint32.size // self.bits
+            new_steps = (self.step + num_steps - 1) // self.step * self.step
+            shape = (B, n_kv_heads, new_steps)
+
+            def init_quant(dim):
+                return (
+                    mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32),
+                    mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
+                    mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
+                )
+
+            def expand_quant(x):
+                new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype)
+                return mx.concatenate([x, new_x], axis=-2)
+
+            if self.keys is not None:
+                if prev % self.step != 0:
+                    self.keys, self.values = tree_map(
+                        lambda x: x[..., :prev, :], (self.keys, self.values)
+                    )
+
+                self.keys, self.values = tree_map(
+                    expand_quant, (self.keys, self.values)
+                )
+            else:
+                self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim)
+
+        self.offset += num_steps
+
+        keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
+        values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
+        for i in range(len(self.keys)):
+            self.keys[i][..., prev : self.offset, :] = keys[i]
+            self.values[i][..., prev : self.offset, :] = values[i]
+
+        return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values))
+
+    @property
+    def state(self):
+        if self.offset == self.keys[0].shape[2]:
+            return self.keys, self.values
+        else:
+            return tree_map(
+                lambda x: x[..., : self.offset, :], (self.keys, self.values)
+            )
+
+    @state.setter
+    def state(self, v):
+        self.keys, self.values = v
+
+    @property
+    def meta_state(self):
+        return tuple(map(str, (self.step, self.offset, self.group_size, self.bits)))
+
+    @meta_state.setter
+    def meta_state(self, v):
+        self.step, self.offset, self.group_size, self.bits = map(int, v)
+
+    def is_trimmable(self):
+        return True
+
+    def trim(self, n):
+        n = min(self.offset, n)
+        self.offset -= n
+        return n
+
+
 class KVCache(_BaseCache):
     def __init__(self):
         self.keys = None
@@ -180,6 +265,16 @@ class KVCache(_BaseCache):
         self.offset -= n
         return n

+    def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
+        quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
+        quant_cache.offset = self.offset
+        if self.keys is not None:
+            quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
+            quant_cache.values = mx.quantize(
+                self.values, group_size=group_size, bits=bits
+            )
+        return quant_cache
+

 class RotatingKVCache(_BaseCache):
@@ -230,9 +325,9 @@ class RotatingKVCache(_BaseCache):
             self.keys = self._temporal_order(self.keys)
             self.values = self._temporal_order(self.values)

-            # The largest size is self.max_size + S - 1 to ensure
+            # The largest size is self.max_size + S to ensure
             # every token gets at least self.max_size context
-            trim_size = self._idx - self.max_size + 1
+            trim_size = self._idx - self.max_size
             self.keys = self._trim(trim_size, self.keys, keys)
             self.values = self._trim(trim_size, self.values, values)
         self.offset += keys.shape[2]
@@ -320,6 +415,9 @@ class RotatingKVCache(_BaseCache):
         self._idx -= n
         return n

+    def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
+        raise NotImplementedError("RotatingKVCache Quantization NYI")
+

 class MambaCache:
     def __init__(self):
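
A short standalone sketch of the new cache type; the shapes and the 8-bit setting are illustrative:

```python
import mlx.core as mx
from mlx_lm.models.cache import KVCache

B, n_kv_heads, T, head_dim = 1, 8, 300, 64
keys = mx.random.normal((B, n_kv_heads, T, head_dim))
values = mx.random.normal((B, n_kv_heads, T, head_dim))

# Fill a regular cache, then convert it to the quantized representation.
cache = KVCache()
cache.update_and_fetch(keys, values)
q_cache = cache.to_quantized(group_size=64, bits=8)

# Quantized keys/values are stored as (packed uint32 data, scales, biases).
packed, scales, biases = q_cache.keys
print(packed.dtype, packed.shape, q_cache.offset)

# They can be dequantized back for a rough accuracy check.
recon = mx.dequantize(packed, scales, biases, group_size=64, bits=8)
print(mx.abs(recon[..., :T, :] - keys).max())
```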

View File

@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -93,8 +93,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -7,7 +7,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -74,8 +74,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(output) return self.out_proj(output)

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU from .switch_layers import SwitchGLU
@ -97,8 +97,8 @@ class DeepseekAttention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU from .switch_layers import SwitchGLU
@ -235,8 +235,8 @@ class DeepseekV2Attention(nn.Module):
queries = mx.concatenate([q_nope, q_pe], axis=-1) queries = mx.concatenate([q_nope, q_pe], axis=-1)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -79,8 +79,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -7,7 +7,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -61,8 +61,8 @@ class Attention(nn.Module):
if cache is not None: if cache is not None:
keys, values = cache.update_and_fetch(keys, values) keys, values = cache.update_and_fetch(keys, values)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -7,7 +7,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -74,8 +74,8 @@ class Attention(nn.Module):
if cache is not None: if cache is not None:
keys, values = cache.update_and_fetch(keys, values) keys, values = cache.update_and_fetch(keys, values)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(output) return self.c_proj(output)

View File

@ -7,7 +7,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
# Based on the transformers implementation at: # Based on the transformers implementation at:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@ -79,8 +79,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -141,8 +141,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.wo(output) return self.wo(output)

View File

@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -190,9 +190,10 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs):
use_conv_bias: bool use_conv_bias: bool
time_step_rank: int time_step_rank: int
tie_word_embeddings: bool = True tie_word_embeddings: bool = True
use_bcdt_rms: bool = False
mixer_rms_eps: float = 1e-6
def __post_init__(self): def __post_init__(self):
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"): if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs):
if self.time_step_rank == "auto": if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16) self.time_step_rank = math.ceil(self.hidden_size / 16)
if self.model_type == "falcon_mamba":
self.use_bcdt_rms = True
class DepthWiseConv1d(nn.Module): class DepthWiseConv1d(nn.Module):
@ -83,6 +87,11 @@ class MambaBlock(nn.Module):
self.intermediate_size = args.intermediate_size self.intermediate_size = args.intermediate_size
self.time_step_rank = int(args.time_step_rank) self.time_step_rank = int(args.time_step_rank)
self.use_conv_bias = args.use_conv_bias self.use_conv_bias = args.use_conv_bias
self.use_bcdt_rms = args.use_bcdt_rms
if self.use_bcdt_rms:
self.mixer_norm = lambda x: mx.fast.rms_norm(
x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
)
self.in_proj = nn.Linear( self.in_proj = nn.Linear(
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
@ -126,6 +135,8 @@ class MambaBlock(nn.Module):
], ],
axis=-1, axis=-1,
) )
if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C))
delta = nn.softplus(self.dt_proj(delta)) delta = nn.softplus(self.dt_proj(delta))
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1) new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
if state is not None: if state is not None:

View File

@ -7,7 +7,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -105,8 +105,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
attn_output = mx.fast.scaled_dot_product_attention( attn_output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1) attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU from .switch_layers import SwitchGLU
@ -87,8 +87,8 @@ class MixtralAttention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -113,8 +113,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -107,8 +107,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -7,7 +7,7 @@ from typing import Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -93,8 +93,13 @@ class PhiAttention(nn.Module):
keys = self.rope(keys) keys = self.rope(keys)
scale = math.sqrt(1 / queries.shape[-1]) scale = math.sqrt(1 / queries.shape[-1])
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries.astype(mx.float32), keys, values, scale=scale, mask=mask queries.astype(mx.float32),
keys,
values,
cache=cache,
scale=scale,
mask=mask,
).astype(values.dtype) ).astype(values.dtype)
output = output.moveaxis(2, 1).reshape(B, L, -1) output = output.moveaxis(2, 1).reshape(B, L, -1)

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .su_rope import SuScaledRotaryEmbedding from .su_rope import SuScaledRotaryEmbedding
@ -107,8 +107,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -8,7 +8,7 @@ from typing import Any, Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -188,8 +188,8 @@ class Attention(nn.Module):
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, scale=self.scale, mask=mask
) )
else: else:
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.dense(output) return self.dense(output)

View File

@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .su_rope import SuScaledRotaryEmbedding from .su_rope import SuScaledRotaryEmbedding
from .switch_layers import SwitchGLU from .switch_layers import SwitchGLU
@ -79,8 +79,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -8,7 +8,7 @@ from typing import Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import create_attention_mask from .base import create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchMLP from .switch_layers import SwitchMLP
@ -71,8 +71,13 @@ class RoPEAttention(nn.Module):
# Finally perform the attention computation # Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1]) scale = math.sqrt(1 / queries.shape[-1])
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries.astype(mx.float32), keys, values, scale=scale, mask=mask queries.astype(mx.float32),
keys,
values,
cache=cache,
scale=scale,
mask=mask,
).astype(values.dtype) ).astype(values.dtype)
output = output.moveaxis(2, 1).reshape(B, L, -1) output = output.moveaxis(2, 1).reshape(B, L, -1)

View File

@ -7,7 +7,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -92,10 +92,11 @@ class Attention(nn.Module):
keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1]) keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
values = mx.tile(values, [1, self.config.n_shared_head, 1, 1]) values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, queries,
keys, keys,
values, values,
cache=cache,
scale=self.scale, scale=self.scale,
mask=attention_mask, mask=attention_mask,
) )

View File

@ -5,7 +5,7 @@ from dataclasses import dataclass
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -64,8 +64,8 @@ class Attention(nn.Module):
queries = self.rotary_emb(queries) queries = self.rotary_emb(queries)
keys = self.rotary_emb(keys) keys = self.rotary_emb(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -89,8 +89,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU from .switch_layers import SwitchGLU
@ -89,8 +89,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -7,7 +7,7 @@ from typing import List, Literal, Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import MambaCache, RotatingKVCache from .cache import MambaCache, RotatingKVCache
@ -263,8 +263,8 @@ class LocalAttentionBlock(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -6,7 +6,7 @@ from dataclasses import dataclass
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -120,8 +120,8 @@ class Attention(nn.Module):
# Finally perform the attention computation # Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1]) scale = math.sqrt(1 / queries.shape[-1])
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=scale, mask=mask queries, keys, values, cache=cache, scale=scale, mask=mask
).astype(values.dtype) ).astype(values.dtype)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output) return self.o_proj(output)

View File

@ -6,7 +6,7 @@ from typing import Any, Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass @dataclass
@ -64,8 +64,8 @@ class Attention(nn.Module):
queries = self.rope(queries) queries = self.rope(queries)
keys = self.rope(keys) keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention( output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask queries, keys, values, cache=cache, scale=self.scale, mask=mask
) )
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

View File

@ -1,4 +1,4 @@
mlx>=0.17.0 mlx>=0.19.2
numpy numpy
transformers[sentencepiece]>=4.39.3 transformers[sentencepiece]>=4.39.3
protobuf protobuf

View File

@@ -1,10 +1,83 @@
 # Copyright © 2023-2024 Apple Inc.

 from functools import partial
+from typing import Callable, Dict, Optional

 import mlx.core as mx


+def make_sampler(
+    temp: float = 0.0,
+    top_p: float = 0.0,
+    min_p: float = 0.0,
+    min_tokens_to_keep: int = 1,
+) -> Callable[mx.array, mx.array]:
+    """
+    Make a sampler function for use with ``generate_step``.
+
+    Args:
+        temp (float): The temperature for sampling, if 0 the argmax is used.
+          Default: ``0``.
+        top_p (float, optional): Nulceus sampling, higher means model considers
+          more less likely words.
+        min_p (float, optional): The minimum value (scaled by the top token's
+          probability) that a token probability must have to be considered.
+        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
+          be filtered by min_p sampling.
+
+    Returns:
+        Callable[mx.array, mx.array]:
+            A sampler which takes log-probabilities and returns tokens.
+    """
+    if temp == 0:
+        return lambda x: mx.argmax(x, axis=-1)
+    elif top_p > 0 and top_p < 1.0:
+        return lambda x: top_p_sampling(x, top_p, temp)
+    elif min_p != 0.0:
+        return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
+    else:
+        return lambda x: categorical_sampling(x, temp)
+
+
+def make_logits_processors(
+    logit_bias: Optional[Dict[int, float]] = None,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = 20,
+):
+    """
+    Make logits processors for use with ``generate_step``.
+
+    Args:
+        repetition_penalty (float, optional): The penalty factor for repeating
+          tokens.
+        repetition_context_size (int, optional): The number of tokens to
+          consider for repetition penalty. Default: ``20``.
+        logit_bias (dictionary, optional): Additive logit bias.
+
+    Returns:
+        List[Callable[[mx.array, mx.array], mx.array]]:
+            A list of logits processors. Each processor in the list is a
+            callable which takes an array of tokens and an array of logits
+            and returns the updated logits.
+    """
+    logits_processors = []
+    if logit_bias:
+        indices = mx.array(list(logit_bias.keys()))
+        values = mx.array(list(logit_bias.values()))
+
+        def logit_bias_processor(_, logits):
+            logits[:, indices] += values
+            return logits
+
+        logits_processors.append(logit_bias_processor)
+
+    if repetition_penalty and repetition_penalty != 0.0:
+        logits_processors.append(
+            make_repetition_penalty(repetition_penalty, repetition_context_size)
+        )
+    return logits_processors
+
+
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
 def min_p_sampling(
     logits: mx.array,
@@ -100,3 +173,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
 def categorical_sampling(logits, temp):
     return mx.random.categorical(logits * (1 / temp))
+
+
+def make_repetition_penalty(penalty: float, context_size: int = 20):
+    """
+    Make repetition penalty processor.
+
+    Paper: https://arxiv.org/abs/1909.05858
+
+    Args:
+        penalty (float): The repetition penalty factor to be applied.
+        context_size (int): The number of previous tokens to use.
+            Default: ``20``.
+
+    Returns:
+        Callable[[mx.array, List[int]], mx.array]:
+            The repetition penalty processor.
+    """
+    if penalty < 0 or not isinstance(penalty, float):
+        raise ValueError(f"penalty must be a non-negative float, got {penalty}")
+
+    def repetition_penalty_processor(tokens, logits):
+        if len(tokens) > 0:
+            tokens = tokens[-context_size:]
+            selected_logits = logits[:, tokens]
+            selected_logits = mx.where(
+                selected_logits < 0,
+                selected_logits * penalty,
+                selected_logits / penalty,
+            )
+            logits[:, tokens] = selected_logits
+        return logits
+
+    return repetition_penalty_processor
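
A brief usage sketch of the two new factories; the token ids, vocabulary size, and bias value are illustrative. With `repetition_penalty=1.3`, a previously generated token's positive logit is divided by 1.3 and a negative one multiplied by 1.3, making repeats less likely:

```python
import mlx.core as mx
from mlx_lm.sample_utils import make_logits_processors, make_sampler

# Temperature plus min-p sampling; with temp == 0 the sampler is a plain argmax.
sampler = make_sampler(temp=0.8, min_p=0.05, min_tokens_to_keep=1)

# Penalize repeats over the last 20 tokens and bias token id 42 upward.
processors = make_logits_processors(
    logit_bias={42: 5.0}, repetition_penalty=1.3, repetition_context_size=20
)

logits = mx.random.normal((1, 32000))  # hypothetical 32k-token vocabulary
tokens = mx.array([1, 5, 42])          # previously generated token ids
for proc in processors:
    logits = proc(tokens, logits)

# The sampler expects log-probabilities and returns a token id.
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
print(sampler(logprobs).item())
```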

View File

@@ -27,7 +27,7 @@ from huggingface_hub import scan_cache_dir

 from ._version import __version__
 from .models.cache import make_prompt_cache
-from .utils import generate_step, load
+from .utils import load, stream_generate


 def get_system_fingerprint():
@@ -64,7 +64,7 @@ def stopping_criteria(
         end if it has (`trim_length`).
     """
     if tokens and tokens[-1] == eos_token_id:
-        return StopCondition(stop_met=True, trim_length=1)
+        return StopCondition(stop_met=True, trim_length=0)

     for stop_ids in stop_id_sequences:
         if len(tokens) >= len(stop_ids):
@@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler):
         self.max_tokens = self.body.get("max_completion_tokens", None)
         if self.max_tokens is None:
             self.max_tokens = self.body.get("max_tokens", 512)
-        self.temperature = self.body.get("temperature", 1.0)
+        self.temperature = self.body.get("temperature", 0.0)
         self.top_p = self.body.get("top_p", 1.0)
         self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
         self.repetition_context_size = self.body.get("repetition_context_size", 20)
@@ -290,10 +290,7 @@ class APIHandler(BaseHTTPRequestHandler):
         # Call endpoint specific method
         prompt = endpoints[self.path]()
-
-        # Call method based on response type
-        method = self.handle_stream if self.stream else self.handle_completion
-        method(prompt, stop_id_sequences)
+        self.handle_completion(prompt, stop_id_sequences)

     def validate_model_parameters(self):
         """
@ -452,32 +449,40 @@ class APIHandler(BaseHTTPRequestHandler):
stop_id_sequences (List[List[int]]): A list of stop words passed stop_id_sequences (List[List[int]]): A list of stop words passed
to the stopping_criteria function to the stopping_criteria function
""" """
detokenizer = self.tokenizer.detokenizer
detokenizer.reset()
tokens = [] tokens = []
finish_reason = "length" finish_reason = "length"
stop_sequence_suffix = None stop_sequence_suffix = None
if self.stream:
self.end_headers()
logging.debug(f"Starting stream:")
else:
logging.debug(f"Starting completion:") logging.debug(f"Starting completion:")
token_logprobs = [] token_logprobs = []
top_tokens = [] top_tokens = []
prompt = self.get_prompt_cache(prompt) prompt = self.get_prompt_cache(prompt)
for _, (token, logprobs) in zip( text = ""
range(self.max_tokens), tic = time.perf_counter()
generate_step( for n, (segment, token, logprobs) in enumerate(
prompt=mx.array(prompt), stream_generate(
model=self.model, model=self.model,
tokenizer=self.tokenizer,
prompt=prompt,
max_tokens=self.max_tokens,
temp=self.temperature, temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty, repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size, repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias, logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache, prompt_cache=self.prompt_cache.cache,
), ),
): ):
detokenizer.add_token(token) if n == 0:
logging.debug(detokenizer.text) prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
text += segment
logging.debug(text)
tokens.append(token) tokens.append(token)
if self.logprobs > 0: if self.logprobs > 0:
@ -498,15 +503,45 @@ class APIHandler(BaseHTTPRequestHandler):
stop_sequence_suffix = self.tokenizer.decode( stop_sequence_suffix = self.tokenizer.decode(
tokens[-stop_condition.trim_length :] tokens[-stop_condition.trim_length :]
) )
text = text[: -len(stop_sequence_suffix)]
break break
self.prompt_cache.tokens.extend(tokens) if self.stream:
detokenizer.finalize() # If the end of tokens overlaps with a stop sequence, generate new
text = ( # tokens until we know if the stop sequence is hit or not
detokenizer.text if any(
if stop_sequence_suffix is None (
else detokenizer.text[: -len(stop_sequence_suffix)] sequence_overlap(tokens, sequence)
for sequence in stop_id_sequences
) )
):
continue
elif segment:
response = self.generate_response(segment, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
self.prompt_cache.tokens.extend(tokens)
gen_time = time.perf_counter() - tic
prompt_tps = len(prompt) / prompt_time
gen_tps = len(tokens) / gen_time
peak_mem = mx.metal.get_peak_memory() / 1e9
logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
logging.debug(f"Peak memory: {peak_mem:.3f} GB")
if self.stream:
response = self.generate_response(segment, finish_reason)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
if self.stream_options is not None and self.stream_options["include_usage"]:
response = self.completion_usage_response(len(prompt), len(tokens))
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
self.wfile.write("data: [DONE]\n\n".encode())
self.wfile.flush()
else:
response = self.generate_response( response = self.generate_response(
text, text,
finish_reason, finish_reason,
@ -516,7 +551,6 @@ class APIHandler(BaseHTTPRequestHandler):
top_tokens=top_tokens, top_tokens=top_tokens,
tokens=tokens, tokens=tokens,
) )
response_json = json.dumps(response).encode() response_json = json.dumps(response).encode()
indent = "\t" # Backslashes can't be inside of f-strings indent = "\t" # Backslashes can't be inside of f-strings
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
@ -524,96 +558,9 @@ class APIHandler(BaseHTTPRequestHandler):
# Send an additional Content-Length header when it is known # Send an additional Content-Length header when it is known
self.send_header("Content-Length", str(len(response_json))) self.send_header("Content-Length", str(len(response_json)))
self.end_headers() self.end_headers()
self.wfile.write(response_json) self.wfile.write(response_json)
self.wfile.flush() self.wfile.flush()
def handle_stream(
self,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
Generate response to prompt and forward it to the client using a Server
Sent Events (SSE) stream.
Args:
prompt (mx.array): The tokenized prompt
stop_id_sequences (List[List[int]]): A list of stop words passed to
the stopping_criteria function
"""
# No additional headers are needed, call end_headers
self.end_headers()
detokenizer = self.tokenizer.detokenizer
detokenizer.reset()
tokens = []
stop_sequence_suffix = None
logging.debug(f"Starting stream:")
prompt = self.get_prompt_cache(prompt)
for _, (token, _) in zip(
range(self.max_tokens),
generate_step(
prompt=mx.array(prompt),
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
prompt_cache=self.prompt_cache.cache,
),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
tokens.append(token)
stop_condition = stopping_criteria(
tokens,
stop_id_sequences,
self.tokenizer.eos_token_id,
)
if stop_condition.stop_met:
if stop_condition.trim_length:
stop_sequence_suffix = self.tokenizer.decode(
tokens[-stop_condition.trim_length :]
)
break
# If the end of tokens overlaps with a stop sequence, generate new
# tokens until we know if the stop sequence is hit or not
if any(
(sequence_overlap(tokens, sequence) for sequence in stop_id_sequences)
):
continue
new_text = detokenizer.last_segment
if new_text:
response = self.generate_response(new_text, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
self.prompt_cache.tokens.extend(tokens)
# check if there is any remaining text to send
detokenizer.finalize()
last_segment = detokenizer.last_segment
if last_segment:
if stop_sequence_suffix is not None:
last_segment = last_segment[: -len(stop_sequence_suffix)]
response = self.generate_response(last_segment, "length")
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
if self.stream_options is not None and self.stream_options["include_usage"]:
response = self.completion_usage_response(len(prompt), len(tokens))
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.write("data: [DONE]\n\n".encode())
self.wfile.flush()
def completion_usage_response( def completion_usage_response(
self, self,
prompt_token_count: Optional[int] = None, prompt_token_count: Optional[int] = None,

View File

@ -6,12 +6,6 @@ from transformers import AutoTokenizer
REPLACEMENT_CHAR = "\ufffd" REPLACEMENT_CHAR = "\ufffd"
def _remove_space(x):
if x and x[0] == " ":
return x[1:]
return x
class StreamingDetokenizer: class StreamingDetokenizer:
"""The streaming detokenizer interface so that we can detokenize one token at a time. """The streaming detokenizer interface so that we can detokenize one token at a time.
@ -123,42 +117,42 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
def __init__(self, tokenizer, trim_space=True): def __init__(self, tokenizer, trim_space=True):
self.trim_space = trim_space self.trim_space = trim_space
self._sep = "\u2581".encode()
# Extract the tokens in a list from id to text # Extract the tokens in a list from id to text
self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1) self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
for value, tokenid in tokenizer.vocab.items(): for value, tokenid in tokenizer.vocab.items():
self.tokenmap[tokenid] = value if value.startswith("<0x"):
# Replace bytes with their value # Replace bytes with their value
for i in range(len(self.tokenmap)): self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
if self.tokenmap[i].startswith("<0x"): else:
self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16)) self.tokenmap[tokenid] = value.encode()
self.reset() self.reset()
def reset(self): def reset(self):
self.offset = 0 self.offset = 0
self._unflushed = "" self._unflushed = b""
self.text = "" self.text = ""
self.tokens = [] self.tokens = []
def _flush(self):
text = self._unflushed.replace(self._sep, b" ").decode("utf-8")
if not self.text and self.trim_space and text and text[0] == " ":
text = text[1:]
self.text += text
def add_token(self, token): def add_token(self, token):
v = self.tokenmap[token] v = self.tokenmap[token]
if v[0] == "\u2581": if v.startswith(self._sep):
if self.text or not self.trim_space: self._flush()
self.text += self._unflushed.replace("\u2581", " ")
else:
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
self._unflushed = v self._unflushed = v
else: else:
self._unflushed += v self._unflushed += v
def finalize(self): def finalize(self):
if self.text or not self.trim_space: self._flush()
self.text += self._unflushed.replace("\u2581", " ") self._unflushed = b""
else:
self.text = _remove_space(self._unflushed.replace("\u2581", " "))
self._unflushed = ""
class BPEStreamingDetokenizer(StreamingDetokenizer): class BPEStreamingDetokenizer(StreamingDetokenizer):
@ -186,6 +180,8 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
# https://github.com/openai/gpt-2/blob/master/src/encoder.py # https://github.com/openai/gpt-2/blob/master/src/encoder.py
self.make_byte_decoder() self.make_byte_decoder()
self._added_ids = set(tokenizer.added_tokens_decoder.keys())
def reset(self): def reset(self):
self.offset = 0 self.offset = 0
self._unflushed = "" self._unflushed = ""
@ -205,11 +201,16 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
def add_token(self, token): def add_token(self, token):
v = self.tokenmap[token] v = self.tokenmap[token]
if self._byte_decoder[v[0]] == 32: is_added = token in self._added_ids
if is_added or self._byte_decoder[v[0]] == 32:
current_text = bytearray( current_text = bytearray(
self._byte_decoder[c] for c in self._unflushed self._byte_decoder[c] for c in self._unflushed
).decode("utf-8") ).decode("utf-8")
self.text += self._maybe_trim_space(current_text) self.text += self._maybe_trim_space(current_text)
if is_added:
self.text += v
self._unflushed = ""
else:
self._unflushed = v self._unflushed = v
else: else:
self._unflushed += v self._unflushed += v

View File

@ -10,6 +10,7 @@ from typing import Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten from mlx.utils import tree_flatten
@ -84,9 +85,16 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
f" examples but only has {len(dataset)}." f" examples but only has {len(dataset)}."
) )
# If running in distributed mode (N machines) then each one should skip N-1
# samples
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
# Make the batches: # Make the batches:
batch_idx = [ batch_idx = [
idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size) idx[i : i + batch_size : step]
for i in range(0, len(idx) - batch_size + 1, batch_size)
] ]
while True: while True:
@ -112,9 +120,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to) max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length) max_length_in_batch = min(max_length_in_batch, max_seq_length)
batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32) batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
for j in range(batch_size): for j in range(batch_size // step):
truncated_length = min(lengths[j], max_seq_length) truncated_length = min(lengths[j], max_seq_length)
batch_arr[j, :truncated_length] = batch[j][:truncated_length] batch_arr[j, :truncated_length] = batch[j][:truncated_length]
lengths[j] = ( lengths[j] = (
@ -138,7 +146,7 @@ def evaluate(
loss: callable = default_loss, loss: callable = default_loss,
iterate_batches: callable = iterate_batches, iterate_batches: callable = iterate_batches,
): ):
all_losses = [] all_losses = 0
ntokens = 0 ntokens = 0
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@ -153,10 +161,14 @@ def evaluate(
), ),
): ):
losses, toks = loss(model, *batch) losses, toks = loss(model, *batch)
all_losses.append((losses * toks).item()) all_losses += losses * toks
ntokens += toks.item() ntokens += toks
mx.eval(all_losses, ntokens)
return np.sum(all_losses) / ntokens all_losses = mx.distributed.all_sum(all_losses)
ntokens = mx.distributed.all_sum(ntokens)
return (all_losses / ntokens).item()
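# Numeric sketch of the reduction above (figures are hypothetical): with two
# workers where rank 0 accumulates losses*toks = 25.0 over 10 tokens and
# rank 1 accumulates 60.0 over 30 tokens, all_sum yields 85.0 and 40, so every
# rank returns the same token-weighted average 85.0 / 40 = 2.125.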
class TrainingCallback: class TrainingCallback:
@ -182,6 +194,11 @@ def train(
training_callback: TrainingCallback = None, training_callback: TrainingCallback = None,
): ):
print(f"Starting training..., iters: {args.iters}") print(f"Starting training..., iters: {args.iters}")
world = mx.distributed.init()
world_size = world.size()
rank = world.rank()
if world_size > 1:
print(f"Node {rank} of {world_size}")
if args.grad_checkpoint: if args.grad_checkpoint:
grad_checkpoint(model.layers[0]) grad_checkpoint(model.layers[0])
@ -192,6 +209,9 @@ def train(
# Forward and backward pass # Forward and backward pass
(lvalue, toks), grad = loss_value_and_grad(model, *batch) (lvalue, toks), grad = loss_value_and_grad(model, *batch)
# All reduce the gradients if running in distributed mode
grad = average_gradients(grad)
# Model update # Model update
optimizer.update(model, grad) optimizer.update(model, grad)
@ -199,8 +219,9 @@ def train(
loss_value_and_grad = nn.value_and_grad(model, loss) loss_value_and_grad = nn.value_and_grad(model, loss)
losses = [] losses = 0
n_tokens = 0 n_tokens = 0
steps = 0
trained_tokens = 0 trained_tokens = 0
# Main training loop # Main training loop
start = time.perf_counter() start = time.perf_counter()
@ -229,8 +250,12 @@ def train(
iterate_batches=iterate_batches, iterate_batches=iterate_batches,
) )
val_time = time.perf_counter() - stop val_time = time.perf_counter() - stop
if rank == 0:
print( print(
f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s" f"Iter {it}: "
f"Val loss {val_loss:.3f}, "
f"Val took {val_time:.3f}s",
flush=True,
) )
if training_callback is not None: if training_callback is not None:
@ -244,29 +269,32 @@ def train(
start = time.perf_counter() start = time.perf_counter()
lvalue, toks = step(batch) lvalue, toks = step(batch)
mx.eval(state, lvalue, toks) losses += lvalue
n_tokens += toks
# Record loss steps += 1
losses.append(lvalue.item()) mx.eval(state, losses, n_tokens)
n_tokens += toks.item()
# Report training loss if needed # Report training loss if needed
if it % args.steps_per_report == 0 or it == args.iters: if it % args.steps_per_report == 0 or it == args.iters:
stop = time.perf_counter() stop = time.perf_counter()
train_loss = np.mean(losses) train_loss = mx.distributed.all_sum(losses).item()
train_loss /= steps * mx.distributed.init().size()
n_tokens = mx.distributed.all_sum(n_tokens).item()
learning_rate = optimizer.learning_rate.item() learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start) it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start) tokens_sec = float(n_tokens) / (stop - start)
trained_tokens += n_tokens trained_tokens += n_tokens
peak_mem = mx.metal.get_peak_memory() / 2**30 peak_mem = mx.metal.get_peak_memory() / 1e9
if rank == 0:
print( print(
f"Iter {it}: Train loss {train_loss:.3f}, " f"Iter {it}: Train loss {train_loss:.3f}, "
f"Learning Rate {learning_rate:.3e}, " f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, " f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, " f"Tokens/sec {tokens_sec:.3f}, "
f"Trained Tokens {trained_tokens}, " f"Trained Tokens {trained_tokens}, "
f"Peak mem {peak_mem:.3f} GB" f"Peak mem {peak_mem:.3f} GB",
flush=True,
) )
if training_callback is not None: if training_callback is not None:
@ -281,8 +309,9 @@ def train(
} }
training_callback.on_train_loss_report(train_info) training_callback.on_train_loss_report(train_info)
losses = [] losses = 0
n_tokens = 0 n_tokens = 0
steps = 0
start = time.perf_counter() start = time.perf_counter()
# Save adapter weights # Save adapter weights

View File

@ -1,5 +1,6 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
import contextlib
import copy import copy
import glob import glob
import importlib import importlib
@ -14,12 +15,12 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
# Local imports # Local imports
from .models import base, cache from .models import cache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model from .tuner.utils import dequantize as dequantize_model
from .tuner.utils import load_adapters from .tuner.utils import load_adapters
@ -28,10 +29,14 @@ from .tuner.utils import load_adapters
MODEL_REMAPPING = { MODEL_REMAPPING = {
"mistral": "llama", # mistral is compatible with llama "mistral": "llama", # mistral is compatible with llama
"phi-msft": "phixtral", "phi-msft": "phixtral",
"falcon_mamba": "mamba",
} }
MAX_FILE_SIZE_GB = 5 MAX_FILE_SIZE_GB = 5
# A stream on the default device just for generation
generation_stream = mx.new_stream(mx.default_device())
class ModelNotFoundError(Exception): class ModelNotFoundError(Exception):
def __init__(self, message): def __init__(self, message):
@ -39,6 +44,40 @@ class ModelNotFoundError(Exception):
super().__init__(self.message) super().__init__(self.message)
@contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
"""
A context manager to temporarily change the wired limit.
Note, the wired limit should not be changed during an async eval. If an
async eval could be running, pass in the streams to synchronize with prior
to exiting the context manager.
"""
model_bytes = tree_reduce(
lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
)
max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
if model_bytes > 0.9 * max_rec_size:
model_mb = model_bytes // 2**20
max_rec_mb = max_rec_size // 2**20
print(
"[WARNING] Generating with a model that requires {model_mb} MB "
"which is close to the maximum recommended size of {max_rec_mb} "
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models"
)
old_limit = mx.metal.set_wired_limit(max_rec_size)
try:
yield None
finally:
if streams is not None:
for s in streams:
mx.synchronize(s)
else:
mx.synchronize()
mx.metal.set_wired_limit(old_limit)
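# Sketch of how the context manager above is meant to be used (it mirrors the
# `with wired_limit(...)` blocks in generate()/stream_generate() further down;
# `model` is assumed to be an already-loaded nn.Module):
#
#     with wired_limit(model, [generation_stream]):
#         for token, logprobs in generate_step(prompt, model):
#             ...  # generation runs with the model wired in memory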
def _get_classes(config: dict): def _get_classes(config: dict):
""" """
Retrieve the model and model args classes based on the configuration. Retrieve the model and model args classes based on the configuration.
@ -101,27 +140,16 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
return model_path return model_path
def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float): def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
""" if (
Apply repetition penalty to specific logits based on the given context. kv_bits is not None
and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
Paper: https://arxiv.org/abs/1909.05858 and prompt_cache[0].offset > quantized_kv_start
):
Args: for i in range(len(prompt_cache)):
logits (mx.array): The logits produced by the language model. prompt_cache[i] = prompt_cache[i].to_quantized(
tokens (mx.array): A list of N previous tokens. group_size=kv_group_size, bits=kv_bits
penalty (float): The repetition penalty factor to be applied.
Returns:
logits (mx.array): Logits with repetition penalty applied to generated tokens.
"""
if len(tokens) > 0:
selected_logits = logits[:, tokens]
selected_logits = mx.where(
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
) )
logits[:, tokens] = selected_logits
return logits
def generate_step( def generate_step(
@ -137,7 +165,10 @@ def generate_step(
max_kv_size: Optional[int] = None, max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None, prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None, logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]: ) -> Generator[Tuple[mx.array, mx.array], None, None]:
""" """
A generator producing token ids based on the given prompt from the model. A generator producing token ids based on the given prompt from the model.
@ -163,79 +194,55 @@ def generate_step(
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place. provided, the cache will be updated in place.
logit_bias (dictionary, optional): Additive logit bias. logit_bias (dictionary, optional): Additive logit bias.
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed A list of functions that take tokens and logits and return the processed
logits. Default: ``None``. logits. Default: ``None``.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache.
when ``kv_bits`` is non-None. Default: ``0``.
Yields: Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
one token and a vector of log probabilities.
""" """
def sample(logits: mx.array) -> Tuple[mx.array, float]:
logprobs = logits - mx.logsumexp(logits)
if temp == 0:
token = mx.argmax(logits, axis=-1)
else:
if top_p > 0 and top_p < 1.0:
token = top_p_sampling(logits, top_p, temp)
elif min_p != 0.0:
token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
else:
token = categorical_sampling(logits, temp)
return token, logprobs
if repetition_penalty and (
repetition_penalty < 0 or not isinstance(repetition_penalty, float)
):
raise ValueError(
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
)
logits_processor = logits_processor or []
if repetition_penalty:
def repetition_penalty_processor(tokens, logits):
return apply_repetition_penalty(
logits, tokens[-repetition_context_size:], repetition_penalty
)
logits_processor.append(repetition_penalty_processor)
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
def logit_bias_processor(_, logits):
logits[:, indices] += values
return logits
logits_processor.append(logit_bias_processor)
y = prompt y = prompt
tokens = None tokens = None
# Create the KV cache for generation # Create the KV cache for generation
if prompt_cache is None: if prompt_cache is None:
prompt_cache = cache.make_prompt_cache(model, max_kv_size) prompt_cache = cache.make_prompt_cache(
model,
max_kv_size=max_kv_size,
)
elif len(prompt_cache) != len(model.layers): elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.") raise ValueError("Wrong number of layers in the prompt cache.")
sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
logits_processors = logits_processors or []
logits_processors.extend(
make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
)
def _step(y): def _step(y):
with mx.stream(generation_stream):
logits = model(y[None], cache=prompt_cache) logits = model(y[None], cache=prompt_cache)
logits = logits[:, -1, :] logits = logits[:, -1, :]
if logits_processor: if logits_processors:
nonlocal tokens nonlocal tokens
tokens = mx.concat([tokens, y]) if tokens is not None else y tokens = mx.concat([tokens, y]) if tokens is not None else y
for processor in logits_processor: for processor in logits_processors:
logits = processor(tokens, logits) logits = processor(tokens, logits)
y, logprobs = sample(logits) maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = sampler(logprobs)
return y, logprobs.squeeze(0) return y, logprobs.squeeze(0)
while y.size > prefill_step_size: while y.size > prefill_step_size:
@ -247,53 +254,65 @@ def generate_step(
y, logprobs = _step(y) y, logprobs = _step(y)
mx.async_eval(y, logprobs) mx.async_eval(y, logprobs)
n = 0
while True: while True:
next_y, next_logprobs = _step(y) next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs) mx.async_eval(next_y, next_logprobs)
yield y.item(), logprobs yield y.item(), logprobs
if n % 256 == 0:
mx.metal.clear_cache()
n += 1
y, logprobs = next_y, next_logprobs y, logprobs = next_y, next_logprobs
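# Illustrative call of generate_step with the new KV-cache quantization
# options (the specific values below are illustrative, not recommendations):
#
#     for token, logprobs in generate_step(
#         prompt_tokens,
#         model,
#         kv_bits=4,               # quantize keys/values to 4 bits
#         kv_group_size=64,        # quantization group size
#         quantized_kv_start=512,  # quantize once the cache offset exceeds 512
#     ):
#         ...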
def stream_generate( def stream_generate(
model: nn.Module, model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str, prompt: Union[str, List[int]],
max_tokens: int = 100, max_tokens: int = 100,
**kwargs, **kwargs,
) -> Union[str, Generator[str, None, None]]: ) -> Generator[Tuple[str, int, mx.array], None, None]:
""" """
A generator producing text based on the given prompt from the model. A generator producing text based on the given prompt from the model.
Args: Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation. model (nn.Module): The model to use for generation.
max_tokens (int): The ma tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
max_tokens (int): The maximum number of tokens. Default: ``100``.
kwargs: The remaining options get passed to :func:`generate_step`. kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details. See :func:`generate_step` for more details.
Yields: Yields:
Generator[Tuple[mx.array, mx.array]]: A generator producing text. Tuple[str, int, mx.array]:
The next text segment, token, and vector of log probabilities.
""" """
if not isinstance(tokenizer, TokenizerWrapper): if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer) tokenizer = TokenizerWrapper(tokenizer)
prompt_tokens = mx.array(tokenizer.encode(prompt)) prompt_tokens = mx.array(
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
)
detokenizer = tokenizer.detokenizer detokenizer = tokenizer.detokenizer
with wired_limit(model, [generation_stream]):
detokenizer.reset() detokenizer.reset()
for n, (token, _) in zip( for n, (token, logits) in zip(
range(max_tokens), range(max_tokens),
generate_step(prompt_tokens, model, **kwargs), generate_step(prompt_tokens, model, **kwargs),
): ):
if token == tokenizer.eos_token_id: if token == tokenizer.eos_token_id:
break break
detokenizer.add_token(token) detokenizer.add_token(token)
# Yield the last segment if streaming if n == (max_tokens - 1):
yield detokenizer.last_segment break
yield detokenizer.last_segment, token, logits
detokenizer.finalize() detokenizer.finalize()
yield detokenizer.last_segment yield detokenizer.last_segment, token, logits
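# Usage sketch for stream_generate (model and tokenizer assumed to come from
# mlx_lm's `load`): each iteration yields the next text segment, the token id,
# and the vector of log probabilities.
#
#     for segment, token, logprobs in stream_generate(
#         model, tokenizer, "Write a haiku about MLX", max_tokens=64
#     ):
#         print(segment, end="", flush=True)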
def generate( def generate(
@ -304,7 +323,7 @@ def generate(
verbose: bool = False, verbose: bool = False,
formatter: Optional[Callable] = None, formatter: Optional[Callable] = None,
**kwargs, **kwargs,
) -> Union[str, Generator[str, None, None]]: ) -> str:
""" """
Generate a complete response from the model. Generate a complete response from the model.
@ -330,9 +349,9 @@ def generate(
prompt_tokens = mx.array(tokenizer.encode(prompt)) prompt_tokens = mx.array(tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer detokenizer = tokenizer.detokenizer
with wired_limit(model, [generation_stream]):
tic = time.perf_counter() tic = time.perf_counter()
detokenizer.reset() detokenizer.reset()
for n, (token, logprobs) in zip( for n, (token, logprobs) in zip(
range(max_tokens), range(max_tokens),
generate_step(prompt_tokens, model, **kwargs), generate_step(prompt_tokens, model, **kwargs),
@ -348,7 +367,6 @@ def generate(
if formatter: if formatter:
# We have to finalize so that the prob corresponds to the last segment # We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize() detokenizer.finalize()
with mx.stream(mx.cpu):
prob = mx.exp(logprobs[token]).item() prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob) formatter(detokenizer.last_segment, prob)
else: else:
@ -366,9 +384,11 @@ def generate(
return return
prompt_tps = prompt_tokens.size / prompt_time prompt_tps = prompt_tokens.size / prompt_time
gen_tps = (token_count - 1) / gen_time gen_tps = (token_count - 1) / gen_time
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec") print(
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
)
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 2**30 peak_mem = mx.metal.get_peak_memory() / 1e9
print(f"Peak memory: {peak_mem:.3f} GB") print(f"Peak memory: {peak_mem:.3f} GB")
return detokenizer.text return detokenizer.text
@ -553,7 +573,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
f""" f"""
# {upload_repo} # {upload_repo}
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**. The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path})
using mlx-lm version **{__version__}**.
## Use with mlx ## Use with mlx

View File

@ -3,6 +3,7 @@
import math import math
import sys import sys
import unittest import unittest
from contextlib import contextmanager
from io import StringIO from io import StringIO
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -17,6 +18,14 @@ from mlx_lm.tuner.trainer import evaluate
from mlx_lm.tuner.utils import build_schedule from mlx_lm.tuner.utils import build_schedule
@contextmanager
def swapped_with_identity(obj, func):
old_func = getattr(obj, func)
setattr(obj, func, lambda x: x)
yield
setattr(obj, func, old_func)
class TestLora(unittest.TestCase): class TestLora(unittest.TestCase):
def setUp(self): def setUp(self):
self.capturedOutput = StringIO() self.capturedOutput = StringIO()
@ -374,6 +383,7 @@ class TestScheduleConfig(unittest.TestCase):
(MagicMock(return_value=0.4), MagicMock(return_value=180)), (MagicMock(return_value=0.4), MagicMock(return_value=180)),
(MagicMock(return_value=0.6), MagicMock(return_value=120)), (MagicMock(return_value=0.6), MagicMock(return_value=120)),
] ]
with swapped_with_identity(mx.distributed, "all_sum"):
evaluate( evaluate(
model=mock_model, model=mock_model,
dataset=mock_dataset, dataset=mock_dataset,
@ -412,6 +422,7 @@ class TestScheduleConfig(unittest.TestCase):
(MagicMock(return_value=0.2), MagicMock(return_value=150)), (MagicMock(return_value=0.2), MagicMock(return_value=150)),
] ]
with swapped_with_identity(mx.distributed, "all_sum"):
evaluate( evaluate(
model=mock_model, model=mock_model,
dataset=mock_dataset, dataset=mock_dataset,

View File

@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase):
"hello", "hello",
max_tokens=5, max_tokens=5,
verbose=False, verbose=False,
logits_processor=[logits_processor], logits_processors=[logits_processor],
) )
self.assertEqual(len(all_toks), len(init_toks) + 5) self.assertEqual(len(all_toks), len(init_toks) + 5)

View File

@ -9,6 +9,7 @@ import mlx.core as mx
from mlx_lm.models.cache import ( from mlx_lm.models.cache import (
KVCache, KVCache,
MambaCache, MambaCache,
QuantizedKVCache,
RotatingKVCache, RotatingKVCache,
load_prompt_cache, load_prompt_cache,
make_prompt_cache, make_prompt_cache,
@ -186,6 +187,18 @@ class TestPromptCache(unittest.TestCase):
num_trimmed = trim_prompt_cache(cache, 4) num_trimmed = trim_prompt_cache(cache, 4)
self.assertEqual(num_trimmed, 0) self.assertEqual(num_trimmed, 0)
cache = [QuantizedKVCache() for _ in range(2)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 64))
c.update_and_fetch(x, x)
num_trimmed = trim_prompt_cache(cache, 7)
self.assertEqual(num_trimmed, 7)
# Trim more tokens than remain
num_trimmed = trim_prompt_cache(cache, 4)
self.assertEqual(num_trimmed, 3)
def test_trim_cache_with_generate(self): def test_trim_cache_with_generate(self):
model, tokenizer = load(HF_MODEL_PATH) model, tokenizer = load(HF_MODEL_PATH)
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
@ -238,6 +251,56 @@ class TestPromptCache(unittest.TestCase):
self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y)) self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z)) self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
def test_save_load_quantized_cache(self):
cache = [QuantizedKVCache(bits=4, group_size=32) for _ in range(4)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 32))
c.update_and_fetch(x, x)
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
save_prompt_cache(cache_file, cache)
loaded_cache = load_prompt_cache(cache_file)
self.assertTrue(loaded_cache[0].bits == cache[0].bits)
self.assertTrue(loaded_cache[0].group_size == cache[0].group_size)
self.assertEqual(len(cache), len(loaded_cache))
for c, lc in zip(cache, loaded_cache):
self.assertEqual(c.offset, lc.offset)
# Loop over quantized tuple
for i in range(3):
self.assertTrue(mx.array_equal(c.state[0][i], lc.state[0][i]))
self.assertTrue(mx.array_equal(c.state[1][i], lc.state[1][i]))
# Test with metadata
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
metadata = {"a": "b", "c": "d"}
save_prompt_cache(cache_file, cache, metadata)
_, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True)
self.assertEqual(metadata, loaded_metadata)
def test_cache_to_quantized(self):
model, tokenizer = load(HF_MODEL_PATH)
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
results = zip(range(4), generate_step(prompt, model))
toks, all_logits = zip(*(r[1] for r in results))
prompt_cache = make_prompt_cache(model)
i = 0
for _, (tok, logits) in zip(
range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
):
self.assertEqual(tok, toks[i])
self.assertTrue(mx.allclose(logits, all_logits[i]))
i += 1
prompt_cache = [c.to_quantized(bits=8, group_size=32) for c in prompt_cache]
for _, (tok, logits) in zip(
range(1),
generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
):
i += 1
self.assertEqual(tok, toks[i])
self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -42,6 +42,9 @@ class TestTokenizers(unittest.TestCase):
text += detokenizer.last_segment text += detokenizer.last_segment
self.assertEqual(text, expected_text) self.assertEqual(text, expected_text)
tokens = tokenizer.encode("こんにちは私の名前はAI")
check(tokens)
tokens = tokenizer.encode("a ,b") tokens = tokenizer.encode("a ,b")
check(tokens) check(tokens)
@ -74,6 +77,17 @@ class TestTokenizers(unittest.TestCase):
tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer) tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
self.check_tokenizer(tokenizer) self.check_tokenizer(tokenizer)
def test_special_tokens(self):
tokenizer_repo = "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx"
tokenizer = self.download_tokenizer(tokenizer_repo)
detokenizer = tokenizer.detokenizer
detokenizer.reset()
detokenizer.add_token(tokenizer.eos_token_id)
detokenizer.finalize()
self.assertEqual(detokenizer.last_segment, tokenizer.eos_token)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -25,7 +25,7 @@ pip install mlx-whisper
At its simplest: At its simplest:
``` ```sh
mlx_whisper audio_file.mp3 mlx_whisper audio_file.mp3
``` ```
@ -35,6 +35,15 @@ Use `-f` to specify the output format and `--model` to specify the model. There
are many other supported command line options. To see them all, run are many other supported command line options. To see them all, run
`mlx_whisper -h`. `mlx_whisper -h`.
You can also pipe the audio content of other programs via stdin:
```sh
some-process | mlx_whisper -
```
The default output file name will be `content.*`. You can specify the name with
the `--output-name` flag.
#### API #### API
Transcribe audio with: Transcribe audio with:
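For example, a minimal sketch (the `path_or_hf_repo` keyword mirrors the
`transcribe()` call used by the CLI below, and the model repo shown is the CLI
default):

```python
import mlx_whisper

# Returns a dict with the decoded text and per-segment details.
result = mlx_whisper.transcribe(
    "audio_file.mp3",
    path_or_hf_repo="mlx-community/whisper-tiny",
)
print(result["text"])
```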

View File

@ -181,7 +181,7 @@ def load_torch_weights_and_config(
) )
if name_or_path.endswith(".pt"): if name_or_path.endswith(".pt"):
checkpoint = torch.load(name_or_path, map_location="cpu") checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False)
weights, config = checkpoint["model_state_dict"], checkpoint["dims"] weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
else: else:
name_or_path = Path(name_or_path) name_or_path = Path(name_or_path)
@ -387,7 +387,7 @@ if __name__ == "__main__":
# Save weights # Save weights
print("[INFO] Saving") print("[INFO] Saving")
np.savez(str(mlx_path / "weights.npz"), **weights) mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
# Save config.json with model_type # Save config.json with model_type
with open(str(mlx_path / "config.json"), "w") as f: with open(str(mlx_path / "config.json"), "w") as f:

View File

@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
__version__ = "0.3.0" __version__ = "0.4.1"

View File

@ -3,7 +3,7 @@
import os import os
from functools import lru_cache from functools import lru_cache
from subprocess import CalledProcessError, run from subprocess import CalledProcessError, run
from typing import Union from typing import Optional, Union
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE): def load_audio(file: Optional[str] = None, sr: int = SAMPLE_RATE, from_stdin=False):
""" """
Open an audio file and read as mono waveform, resampling as necessary Open an audio file and read as mono waveform, resampling as necessary
@ -40,18 +40,20 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
# This launches a subprocess to decode audio while down-mixing # This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH. # and resampling as necessary. Requires the ffmpeg CLI in PATH.
if from_stdin:
cmd = ["ffmpeg", "-i", "pipe:0"]
else:
cmd = ["ffmpeg", "-nostdin", "-i", file]
# fmt: off # fmt: off
cmd = [ cmd.extend([
"ffmpeg",
"-nostdin",
"-threads", "0", "-threads", "0",
"-i", file,
"-f", "s16le", "-f", "s16le",
"-ac", "1", "-ac", "1",
"-acodec", "pcm_s16le", "-acodec", "pcm_s16le",
"-ar", str(sr), "-ar", str(sr),
"-" "-"
] ])
# fmt: on # fmt: on
try: try:
out = run(cmd, capture_output=True, check=True).stdout out = run(cmd, capture_output=True, check=True).stdout

View File

@ -2,9 +2,11 @@
import argparse import argparse
import os import os
import pathlib
import traceback import traceback
import warnings import warnings
from . import audio
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from .transcribe import transcribe from .transcribe import transcribe
from .writers import get_writer from .writers import get_writer
@ -27,15 +29,24 @@ def build_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
) )
parser.add_argument(
"audio", nargs="+", type=str, help="Audio file(s) to transcribe" parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
)
parser.add_argument( parser.add_argument(
"--model", "--model",
default="mlx-community/whisper-tiny", default="mlx-community/whisper-tiny",
type=str, type=str,
help="The model directory or hugging face repo", help="The model directory or hugging face repo",
) )
parser.add_argument(
"--output-name",
type=str,
default=None,
help=(
"The name of transcription/translation output files before "
"--output-format extensions"
),
)
parser.add_argument( parser.add_argument(
"--output-dir", "--output-dir",
"-o", "-o",
@ -200,6 +211,7 @@ def main():
path_or_hf_repo: str = args.pop("model") path_or_hf_repo: str = args.pop("model")
output_dir: str = args.pop("output_dir") output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format") output_format: str = args.pop("output_format")
output_name: str = args.pop("output_name")
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
writer = get_writer(output_format, output_dir) writer = get_writer(output_format, output_dir)
@ -219,17 +231,25 @@ def main():
warnings.warn("--max-line-count has no effect without --max-line-width") warnings.warn("--max-line-count has no effect without --max-line-width")
if writer_args["max_words_per_line"] and writer_args["max_line_width"]: if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
warnings.warn("--max-words-per-line has no effect with --max-line-width") warnings.warn("--max-words-per-line has no effect with --max-line-width")
for audio_path in args.pop("audio"):
for audio_obj in args.pop("audio"):
if audio_obj == "-":
# receive the contents from stdin rather than read a file
audio_obj = audio.load_audio(from_stdin=True)
output_name = output_name or "content"
else:
output_name = output_name or pathlib.Path(audio_obj).stem
try: try:
result = transcribe( result = transcribe(
audio_path, audio_obj,
path_or_hf_repo=path_or_hf_repo, path_or_hf_repo=path_or_hf_repo,
**args, **args,
) )
writer(result, audio_path, **writer_args) writer(result, output_name, **writer_args)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}") print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -58,11 +58,12 @@ def detect_language(
logits = model.logits(x, mel)[:, 0] logits = model.logits(x, mel)[:, 0]
# collect detected languages; suppress all non-language tokens # collect detected languages; suppress all non-language tokens
mask = np.full(logits.shape[-1], -np.inf, dtype=np.float32) mask = mx.full(logits.shape[-1], -mx.inf, dtype=mx.float32)
mask[list(tokenizer.all_language_tokens)] = 0.0 mask[list(tokenizer.all_language_tokens)] = 0.0
logits += mx.array(mask) logits += mask
language_tokens = mx.argmax(logits, axis=-1) language_tokens = mx.argmax(logits, axis=-1)
language_token_probs = mx.softmax(logits, axis=-1) language_token_probs = mx.softmax(logits, axis=-1)
language_token_probs = np.array(language_token_probs)
language_probs = [ language_probs = [
{ {
c: language_token_probs[i, j].item() c: language_token_probs[i, j].item()
@ -129,17 +130,12 @@ class DecodingResult:
class Inference: class Inference:
def __init__(self, model: "Whisper", initial_token_length: int): def __init__(self, model: "Whisper"):
self.model: "Whisper" = model self.model: "Whisper" = model
self.initial_token_length = initial_token_length
self.kv_cache = None self.kv_cache = None
def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array: def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array:
"""Perform a forward pass on the decoder and return per-token logits""" """Perform a forward pass on the decoder and return per-token logits"""
if tokens.shape[-1] > self.initial_token_length:
# only need to use the last token except in the first forward pass
tokens = tokens[:, -1:]
logits, self.kv_cache, _ = self.model.decoder( logits, self.kv_cache, _ = self.model.decoder(
tokens, audio_features, kv_cache=self.kv_cache tokens, audio_features, kv_cache=self.kv_cache
) )
@ -251,6 +247,11 @@ class TokenDecoder:
raise NotImplementedError raise NotImplementedError
@mx.compile
def categorical(logits, temp):
return mx.random.categorical(logits / temp)
class GreedyDecoder(TokenDecoder): class GreedyDecoder(TokenDecoder):
def __init__(self, temperature: float, eot: int): def __init__(self, temperature: float, eot: int):
self.temperature = temperature self.temperature = temperature
@ -262,10 +263,8 @@ class GreedyDecoder(TokenDecoder):
if self.temperature == 0: if self.temperature == 0:
next_tokens = logits.argmax(axis=-1) next_tokens = logits.argmax(axis=-1)
else: else:
next_tokens = mx.random.categorical(logits=logits / self.temperature) next_tokens = categorical(logits, self.temperature)
next_tokens = mx.argmax(logits, axis=-1)
logits = logits.astype(mx.float32)
logprobs = logits - mx.logsumexp(logits, axis=-1) logprobs = logits - mx.logsumexp(logits, axis=-1)
current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens] current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
@ -281,7 +280,7 @@ class GreedyDecoder(TokenDecoder):
def finalize(self, tokens: mx.array, sum_logprobs: mx.array): def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
# make sure each sequence has at least one EOT token at the end # make sure each sequence has at least one EOT token at the end
tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot) tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot)
return tokens, sum_logprobs.tolist() return tokens, sum_logprobs
class LogitFilter: class LogitFilter:
@ -340,10 +339,10 @@ class ApplyTimestampRules(LogitFilter):
if self.tokenizer.no_timestamps is not None: if self.tokenizer.no_timestamps is not None:
mask[:, self.tokenizer.no_timestamps] = -np.inf mask[:, self.tokenizer.no_timestamps] = -np.inf
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly ## timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
for k in range(tokens.shape[0]): tokens = tokens.tolist()
sampled_tokens = tokens[k, self.sample_begin :] for k in range(len(tokens)):
seq = sampled_tokens.tolist() seq = tokens[k][self.sample_begin :]
last_was_timestamp = ( last_was_timestamp = (
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
) )
@ -368,7 +367,7 @@ class ApplyTimestampRules(LogitFilter):
last_timestamp += 1 last_timestamp += 1
mask[k, self.tokenizer.timestamp_begin : last_timestamp] = -np.inf mask[k, self.tokenizer.timestamp_begin : last_timestamp] = -np.inf
if tokens.shape[1] == self.sample_begin: if len(tokens[0]) == self.sample_begin:
# suppress generating non-timestamp tokens at the beginning # suppress generating non-timestamp tokens at the beginning
mask[:, : self.tokenizer.timestamp_begin] = -np.inf mask[:, : self.tokenizer.timestamp_begin] = -np.inf
@ -380,16 +379,20 @@ class ApplyTimestampRules(LogitFilter):
mask[:, last_allowed + 1 :] = -np.inf mask[:, last_allowed + 1 :] = -np.inf
# if sum of probability over timestamps is above any other token, sample timestamp # if sum of probability over timestamps is above any other token, sample timestamp
mask = mx.array(mask)
logprobs = logits - mx.logsumexp(logits, axis=-1) logprobs = logits - mx.logsumexp(logits, axis=-1)
for k in range(tokens.shape[0]): timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp(
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp( axis=-1, keepdims=True
axis=-1
) )
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max() max_text_token_logprob = logprobs[:, : self.tokenizer.timestamp_begin].max(
if timestamp_logprob > max_text_token_logprob: axis=-1, keepdims=True
mask[k, : self.tokenizer.timestamp_begin] = -np.inf )
mask[:, : self.tokenizer.timestamp_begin] = mx.where(
return logits + mx.array(mask, logits.dtype) timestamp_logprob > max_text_token_logprob,
-mx.inf,
mask[:, : self.tokenizer.timestamp_begin],
)
return logits + mask
class DecodingTask: class DecodingTask:
@ -424,7 +427,7 @@ class DecodingTask:
self.sot_index: int = self.initial_tokens.index(tokenizer.sot) self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
# inference: implements the forward pass through the decoder, including kv caching # inference: implements the forward pass through the decoder, including kv caching
self.inference = Inference(model, len(self.initial_tokens)) self.inference = Inference(model)
# sequence ranker: implements how to rank a group of sampled sequences # sequence ranker: implements how to rank a group of sampled sequences
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty) self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
@ -432,9 +435,6 @@ class DecodingTask:
# decoder: implements how to select the next tokens, given the autoregressive distribution # decoder: implements how to select the next tokens, given the autoregressive distribution
if options.beam_size is not None: if options.beam_size is not None:
raise NotImplementedError("Beam search decoder is not yet implemented") raise NotImplementedError("Beam search decoder is not yet implemented")
# self.decoder = BeamSearchDecoder(
# options.beam_size, tokenizer.eot, self.inference, options.patience
# )
else: else:
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot) self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
@ -448,6 +448,7 @@ class DecodingTask:
self.logit_filters.append( self.logit_filters.append(
SuppressTokens(self._get_suppress_tokens(), model.dims.n_vocab) SuppressTokens(self._get_suppress_tokens(), model.dims.n_vocab)
) )
if not options.without_timestamps: if not options.without_timestamps:
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
max_initial_timestamp_index = None max_initial_timestamp_index = None
@ -570,23 +571,13 @@ class DecodingTask:
def _main_loop(self, audio_features: mx.array, tokens: mx.array): def _main_loop(self, audio_features: mx.array, tokens: mx.array):
n_batch = tokens.shape[0] n_batch = tokens.shape[0]
sum_logprobs: mx.array = mx.zeros(n_batch) sum_logprobs = mx.zeros(n_batch)
no_speech_probs = [np.nan] * n_batch
try: def _step(inputs, audio_features, tokens, sum_logprobs):
for i in range(self.sample_len): pre_logits = self.inference.logits(inputs, audio_features)
logits = self.inference.logits(tokens, audio_features)
if ( # consider the logits at the last token only
i == 0 and self.tokenizer.no_speech is not None logits = pre_logits[:, -1]
): # save no_speech_probs
probs_at_sot = mx.softmax(
logits[:, self.sot_index].astype(mx.float32), axis=-1
)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
# now we need to consider the logits at the last token only
logits = logits[:, -1]
# apply the logit filters, e.g. for suppressing or applying penalty to # apply the logit filters, e.g. for suppressing or applying penalty to
for logit_filter in self.logit_filters: for logit_filter in self.logit_filters:
@ -596,22 +587,43 @@ class DecodingTask:
tokens, completed, sum_logprobs = self.decoder.update( tokens, completed, sum_logprobs = self.decoder.update(
tokens, logits, sum_logprobs tokens, logits, sum_logprobs
) )
return tokens, completed, sum_logprobs, pre_logits
if completed or tokens.shape[-1] > self.n_ctx: tokens, completed, sum_logprobs, pre_logits = _step(
tokens, audio_features, tokens, sum_logprobs
)
if self.tokenizer.no_speech is not None: # compute no_speech_probs
probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
else:
no_speech_probs = mx.full(n_batch, mx.nan)
mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)
for i in range(1, self.sample_len):
inputs = tokens[:, -1:]
if tokens.shape[-1] > self.n_ctx:
break break
finally: next_tokens, next_completed, next_sum_logprobs, _ = _step(
self.inference.reset() inputs, audio_features, tokens, sum_logprobs
)
mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
if completed:
break
tokens = next_tokens
completed = next_completed
sum_logprobs = next_sum_logprobs
return tokens, sum_logprobs, no_speech_probs return tokens, sum_logprobs, no_speech_probs
def run(self, mel: mx.array) -> List[DecodingResult]: def run(self, mel: mx.array) -> List[DecodingResult]:
self.inference.reset()
self.decoder.reset() self.decoder.reset()
tokenizer: Tokenizer = self.tokenizer tokenizer: Tokenizer = self.tokenizer
n_audio: int = mel.shape[0] n_audio: int = mel.shape[0]
audio_features: mx.array = self._get_audio_features(mel) # encoder forward pass audio_features: mx.array = self._get_audio_features(mel) # encoder forward pass
tokens: np.array = np.array(self.initial_tokens) tokens: mx.array = mx.array(self.initial_tokens)
tokens = np.broadcast_to(tokens, (n_audio, len(self.initial_tokens))).copy() tokens = mx.broadcast_to(tokens, (n_audio, len(self.initial_tokens)))
# detect language if requested, overwriting the language token # detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(audio_features, tokens) languages, language_probs = self._detect_language(audio_features, tokens)
@ -626,7 +638,6 @@ class DecodingTask:
] ]
# repeat tokens by the group size, for beam search or best-of-n sampling # repeat tokens by the group size, for beam search or best-of-n sampling
tokens = mx.array(tokens)
if self.n_group > 1: if self.n_group > 1:
tokens = tokens[:, None, :] tokens = tokens[:, None, :]
tokens = mx.broadcast_to( tokens = mx.broadcast_to(
@ -649,7 +660,13 @@ class DecodingTask:
# get the final candidates for each group, and slice between the first sampled token and EOT # get the final candidates for each group, and slice between the first sampled token and EOT
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
tokens = tokens[..., self.sample_begin :].tolist() tokens = tokens[..., self.sample_begin :]
# eval and convert to list
mx.eval(tokens, sum_logprobs, no_speech_probs)
tokens = tokens.tolist()
sum_logprobs = sum_logprobs.tolist()
no_speech_probs = no_speech_probs.tolist()
tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens] tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens]
# select the top-ranked sample in each group # select the top-ranked sample in each group

View File

@ -26,7 +26,10 @@ def load_model(
model_args = whisper.ModelDimensions(**config) model_args = whisper.ModelDimensions(**config)
weights = mx.load(str(model_path / "weights.npz")) wf = model_path / "weights.safetensors"
if not wf.exists():
wf = model_path / "weights.npz"
weights = mx.load(str(wf))
model = whisper.Whisper(model_args, dtype) model = whisper.Whisper(model_args, dtype)

View File

@ -293,6 +293,7 @@ def transcribe(
decode_options["prompt"] = all_tokens[prompt_reset_since:] decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment) result: DecodingResult = decode_with_fallback(mel_segment)
tokens = np.array(result.tokens) tokens = np.array(result.tokens)
if no_speech_threshold is not None: if no_speech_threshold is not None:

View File

@ -80,12 +80,11 @@ class MultiHeadAttention(nn.Module):
qk = q @ k qk = q @ k
if mask is not None: if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx] qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.astype(mx.float32)
w = mx.softmax(qk, axis=-1).astype(q.dtype) w = mx.softmax(qk, axis=-1, precise=True)
out = (w @ v).transpose(0, 2, 1, 3) out = (w @ v).transpose(0, 2, 1, 3)
out = out.reshape(n_batch, n_ctx, n_state) out = out.reshape(n_batch, n_ctx, n_state)
return out, qk return out, qk.astype(mx.float32)
class ResidualAttentionBlock(nn.Module): class ResidualAttentionBlock(nn.Module):

View File

@ -1,10 +1,8 @@
# Copyright © 2024 Apple Inc. # Copyright © 2024 Apple Inc.
import json import json
import os import pathlib
import re import re
import sys
import zlib
from typing import Callable, List, Optional, TextIO from typing import Callable, List, Optional, TextIO
@ -43,15 +41,13 @@ class ResultWriter:
self.output_dir = output_dir self.output_dir = output_dir
def __call__( def __call__(
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs
): ):
audio_basename = os.path.basename(audio_path) output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix(
audio_basename = os.path.splitext(audio_basename)[0] f".{self.extension}"
output_path = os.path.join(
self.output_dir, audio_basename + "." + self.extension
) )
with open(output_path, "w", encoding="utf-8") as f: with output_path.open("wt", encoding="utf-8") as f:
self.write_result(result, file=f, options=options, **kwargs) self.write_result(result, file=f, options=options, **kwargs)
def write_result( def write_result(