From 82e333898707eb57235f408aa6907beca095f759 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Mon, 4 Nov 2024 22:06:34 +0800
Subject: [PATCH 1/6] chore(mlx-lm): add max token arg for mlx_lm.chat (#1089)
* chore(mlx-lm): add max token arg for mlx_lm.chat
* chore: update the default max token value
---
llms/mlx_lm/chat.py | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py
index ea1a99c7..85d32d5f 100644
--- a/llms/mlx_lm/chat.py
+++ b/llms/mlx_lm/chat.py
@@ -11,6 +11,7 @@ from .utils import load, stream_generate
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
+DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
@@ -41,6 +42,13 @@ def setup_arg_parser():
help="Set the maximum key-value cache size",
default=None,
)
+ parser.add_argument(
+ "--max-tokens",
+ "-m",
+ type=int,
+ default=DEFAULT_MAX_TOKENS,
+ help="Maximum number of tokens to generate",
+ )
return parser
@@ -70,6 +78,7 @@ def main():
model,
tokenizer,
prompt,
+ args.max_tokens,
temp=args.temp,
top_p=args.top_p,
prompt_cache=prompt_cache,
From 3b526f0aa1219fae662a86f012dbda82045f4fb0 Mon Sep 17 00:00:00 2001
From: ilyasch2 <104485953+ilyasch2@users.noreply.github.com>
Date: Tue, 5 Nov 2024 00:23:30 +0400
Subject: [PATCH 2/6] Add support for falcon-mamba (#1074)
* Add support for falcon-mamba
* nits
* nit
---------
Co-authored-by: Awni Hannun
---
llms/README.md | 1 +
llms/mlx_lm/models/mamba.py | 11 +++++++++++
llms/mlx_lm/utils.py | 1 +
3 files changed, 13 insertions(+)
diff --git a/llms/README.md b/llms/README.md
index f539988a..0e7dc7fb 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -221,6 +221,7 @@ Here are a few examples of Hugging Face models that work with this example:
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
+- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
Most
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index 84f498e9..f2414660 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs):
use_conv_bias: bool
time_step_rank: int
tie_word_embeddings: bool = True
+ use_bcdt_rms: bool = False
+ mixer_rms_eps: float = 1e-6
def __post_init__(self):
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
@@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs):
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
+ if self.model_type == "falcon_mamba":
+ self.use_bcdt_rms = True
class DepthWiseConv1d(nn.Module):
@@ -83,6 +87,11 @@ class MambaBlock(nn.Module):
self.intermediate_size = args.intermediate_size
self.time_step_rank = int(args.time_step_rank)
self.use_conv_bias = args.use_conv_bias
+ self.use_bcdt_rms = args.use_bcdt_rms
+ if self.use_bcdt_rms:
+ self.mixer_norm = lambda x: mx.fast.rms_norm(
+ x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
+ )
self.in_proj = nn.Linear(
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
@@ -126,6 +135,8 @@ class MambaBlock(nn.Module):
],
axis=-1,
)
+ if self.use_bcdt_rms:
+ delta, B, C = map(self.mixer_norm, (delta, B, C))
delta = nn.softplus(self.dt_proj(delta))
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
if state is not None:
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index b9fc202d..7b440db6 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -29,6 +29,7 @@ from .tuner.utils import load_adapters
MODEL_REMAPPING = {
"mistral": "llama", # mistral is compatible with llama
"phi-msft": "phixtral",
+ "falcon_mamba": "mamba",
}
MAX_FILE_SIZE_GB = 5
From 4394633ce0f9d96cbbdf571e077fa4fd78479b9f Mon Sep 17 00:00:00 2001
From: Anthony Wu <462072+anthonywu@users.noreply.github.com>
Date: Mon, 4 Nov 2024 14:02:13 -0800
Subject: [PATCH 3/6] mlx_whisper: add support for audio input from stdin
(#1012)
* add support for audio and input name from stdin
* refactored to stdin - arg, and output-name template
* fix bugs, add test coverage
* fix doc to match arg rename
* some nits
---------
Co-authored-by: Awni Hannun
---
whisper/README.md | 13 +++++++++++--
whisper/mlx_whisper/audio.py | 18 ++++++++++--------
whisper/mlx_whisper/cli.py | 34 +++++++++++++++++++++++++++-------
whisper/mlx_whisper/writers.py | 14 +++++---------
4 files changed, 53 insertions(+), 26 deletions(-)
diff --git a/whisper/README.md b/whisper/README.md
index ac6e95f6..cd3bc684 100644
--- a/whisper/README.md
+++ b/whisper/README.md
@@ -25,7 +25,7 @@ pip install mlx-whisper
At its simplest:
-```
+```sh
mlx_whisper audio_file.mp3
```
@@ -35,6 +35,15 @@ Use `-f` to specify the output format and `--model` to specify the model. There
are many other supported command line options. To see them all, run
`mlx_whisper -h`.
+You can also pipe the audio content of other programs via stdin:
+
+```sh
+some-process | mlx_whisper -
+```
+
+The default output file name will be `content.*`. You can specify the name with
+the `--output-name` flag.
+
#### API
Transcribe audio with:
@@ -103,7 +112,7 @@ python convert.py --help
```
By default, the conversion script will make the directory `mlx_models`
-and save the converted `weights.npz` and `config.json` there.
+and save the converted `weights.npz` and `config.json` there.
Each time it is run, `convert.py` will overwrite any model in the provided
path. To save different models, make sure to set `--mlx-path` to a unique
diff --git a/whisper/mlx_whisper/audio.py b/whisper/mlx_whisper/audio.py
index e04309c1..c8cca07c 100644
--- a/whisper/mlx_whisper/audio.py
+++ b/whisper/mlx_whisper/audio.py
@@ -3,7 +3,7 @@
import os
from functools import lru_cache
from subprocess import CalledProcessError, run
-from typing import Union
+from typing import Optional, Union
import mlx.core as mx
import numpy as np
@@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
-def load_audio(file: str, sr: int = SAMPLE_RATE):
+def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_stdin=False):
"""
Open an audio file and read as mono waveform, resampling as necessary
@@ -39,19 +39,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
# This launches a subprocess to decode audio while down-mixing
- # and resampling as necessary. Requires the ffmpeg CLI in PATH.
+ # and resampling as necessary. Requires the ffmpeg CLI in PATH.
+ if from_stdin:
+ cmd = ["ffmpeg", "-i", "pipe:0"]
+ else:
+ cmd = ["ffmpeg", "-nostdin", "-i", file]
+
# fmt: off
- cmd = [
- "ffmpeg",
- "-nostdin",
+ cmd.extend([
"-threads", "0",
- "-i", file,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr),
"-"
- ]
+ ])
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
diff --git a/whisper/mlx_whisper/cli.py b/whisper/mlx_whisper/cli.py
index c2813338..7d08a043 100644
--- a/whisper/mlx_whisper/cli.py
+++ b/whisper/mlx_whisper/cli.py
@@ -2,9 +2,11 @@
import argparse
import os
+import pathlib
import traceback
import warnings
+from . import audio
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from .transcribe import transcribe
from .writers import get_writer
@@ -27,15 +29,24 @@ def build_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
- parser.add_argument(
- "audio", nargs="+", type=str, help="Audio file(s) to transcribe"
- )
+
+ parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
+
parser.add_argument(
"--model",
default="mlx-community/whisper-tiny",
type=str,
help="The model directory or hugging face repo",
)
+ parser.add_argument(
+ "--output-name",
+ type=str,
+ default=None,
+ help=(
+ "The name of transcription/translation output files before "
+ "--output-format extensions"
+ ),
+ )
parser.add_argument(
"--output-dir",
"-o",
@@ -200,6 +211,7 @@ def main():
path_or_hf_repo: str = args.pop("model")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
+ output_name: str = args.pop("output_name")
os.makedirs(output_dir, exist_ok=True)
writer = get_writer(output_format, output_dir)
@@ -219,17 +231,25 @@ def main():
warnings.warn("--max-line-count has no effect without --max-line-width")
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
warnings.warn("--max-words-per-line has no effect with --max-line-width")
- for audio_path in args.pop("audio"):
+
+ for audio_obj in args.pop("audio"):
+ if audio_obj == "-":
+ # receive the contents from stdin rather than read a file
+ audio_obj = audio.load_audio(from_stdin=True)
+
+ output_name = output_name or "content"
+ else:
+ output_name = output_name or pathlib.Path(audio_obj).stem
try:
result = transcribe(
- audio_path,
+ audio_obj,
path_or_hf_repo=path_or_hf_repo,
**args,
)
- writer(result, audio_path, **writer_args)
+ writer(result, output_name, **writer_args)
except Exception as e:
traceback.print_exc()
- print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
+ print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
if __name__ == "__main__":
diff --git a/whisper/mlx_whisper/writers.py b/whisper/mlx_whisper/writers.py
index 464ead18..cdb35063 100644
--- a/whisper/mlx_whisper/writers.py
+++ b/whisper/mlx_whisper/writers.py
@@ -1,10 +1,8 @@
# Copyright © 2024 Apple Inc.
import json
-import os
+import pathlib
import re
-import sys
-import zlib
from typing import Callable, List, Optional, TextIO
@@ -43,15 +41,13 @@ class ResultWriter:
self.output_dir = output_dir
def __call__(
- self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
+ self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs
):
- audio_basename = os.path.basename(audio_path)
- audio_basename = os.path.splitext(audio_basename)[0]
- output_path = os.path.join(
- self.output_dir, audio_basename + "." + self.extension
+ output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix(
+ f".{self.extension}"
)
- with open(output_path, "w", encoding="utf-8") as f:
+ with output_path.open("wt", encoding="utf-8") as f:
self.write_result(result, file=f, options=options, **kwargs)
def write_result(
From 6fd1f70f7366a1e55f14e2b4cd885b86875ab56c Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 5 Nov 2024 06:06:26 -0800
Subject: [PATCH 4/6] fix spm decoder multi-byte (#1092)
---
llms/mlx_lm/tokenizer_utils.py | 40 +++++++++++++++-------------------
llms/tests/test_tokenizers.py | 3 +++
2 files changed, 20 insertions(+), 23 deletions(-)
diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
index 568a672d..9d390733 100644
--- a/llms/mlx_lm/tokenizer_utils.py
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -6,12 +6,6 @@ from transformers import AutoTokenizer
REPLACEMENT_CHAR = "\ufffd"
-def _remove_space(x):
- if x and x[0] == " ":
- return x[1:]
- return x
-
-
class StreamingDetokenizer:
"""The streaming detokenizer interface so that we can detokenize one token at a time.
@@ -123,42 +117,42 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
def __init__(self, tokenizer, trim_space=True):
self.trim_space = trim_space
+ self._sep = "\u2581".encode()
# Extract the tokens in a list from id to text
self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
for value, tokenid in tokenizer.vocab.items():
- self.tokenmap[tokenid] = value
-
- # Replace bytes with their value
- for i in range(len(self.tokenmap)):
- if self.tokenmap[i].startswith("<0x"):
- self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))
+ if value.startswith("<0x"):
+ # Replace bytes with their value
+ self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
+ else:
+ self.tokenmap[tokenid] = value.encode()
self.reset()
def reset(self):
self.offset = 0
- self._unflushed = ""
+ self._unflushed = b""
self.text = ""
self.tokens = []
+ def _flush(self):
+ text = self._unflushed.replace(self._sep, b" ").decode("utf-8")
+ if not self.text and self.trim_space and text and text[0] == " ":
+ text = text[1:]
+ self.text += text
+
def add_token(self, token):
v = self.tokenmap[token]
- if v[0] == "\u2581":
- if self.text or not self.trim_space:
- self.text += self._unflushed.replace("\u2581", " ")
- else:
- self.text = _remove_space(self._unflushed.replace("\u2581", " "))
+ if v.startswith(self._sep):
+ self._flush()
self._unflushed = v
else:
self._unflushed += v
def finalize(self):
- if self.text or not self.trim_space:
- self.text += self._unflushed.replace("\u2581", " ")
- else:
- self.text = _remove_space(self._unflushed.replace("\u2581", " "))
- self._unflushed = ""
+ self._flush()
+ self._unflushed = b""
class BPEStreamingDetokenizer(StreamingDetokenizer):
diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py
index 3c93fbe2..9c30d51e 100644
--- a/llms/tests/test_tokenizers.py
+++ b/llms/tests/test_tokenizers.py
@@ -42,6 +42,9 @@ class TestTokenizers(unittest.TestCase):
text += detokenizer.last_segment
self.assertEqual(text, expected_text)
+ tokens = tokenizer.encode("こんにちは!私の名前はAI")
+ check(tokens)
+
tokens = tokenizer.encode("a ,b")
check(tokens)
From ed9e81dd581a9505e677e12c025137d5326fe6df Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 5 Nov 2024 10:24:24 -0800
Subject: [PATCH 5/6] Fix rotating kv cache size (#1093)
---
llms/mlx_lm/models/base.py | 2 +-
llms/mlx_lm/models/cache.py | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index cda41c79..f02f49b1 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -42,7 +42,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
if cache is not None and cache[0] is not None:
c = cache[0]
if hasattr(c, "max_size"):
- offset = min(c.max_size - 1, c.offset)
+ offset = min(c.max_size, c.offset)
window_size = c.max_size
else:
offset = c.offset
diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index 1cd5289d..14026f0c 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -325,9 +325,9 @@ class RotatingKVCache(_BaseCache):
self.keys = self._temporal_order(self.keys)
self.values = self._temporal_order(self.values)
- # The largest size is self.max_size + S - 1 to ensure
+ # The largest size is self.max_size + S to ensure
# every token gets at least self.max_size context
- trim_size = self._idx - self.max_size + 1
+ trim_size = self._idx - self.max_size
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += keys.shape[2]
From 657b4cc0aa90af09ac9793168cb81d406db882c6 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 7 Nov 2024 16:15:24 -0800
Subject: [PATCH 6/6] [MLX LM] Sampler refactor + a few improvements (#1094)
* starting
* refactor sampler/processor and a few improvements
* fix stream
* fix stream generate
* fix eos handling in stream generate
---
llms/README.md | 5 +-
llms/mlx_lm/cache_prompt.py | 4 +-
llms/mlx_lm/chat.py | 2 +-
llms/mlx_lm/generate.py | 14 +++
llms/mlx_lm/sample_utils.py | 106 ++++++++++++++++++
llms/mlx_lm/server.py | 193 ++++++++++++--------------------
llms/mlx_lm/tuner/trainer.py | 2 +-
llms/mlx_lm/utils.py | 168 ++++++++++-----------------
llms/tests/test_generate.py | 2 +-
llms/tests/test_prompt_cache.py | 2 +-
10 files changed, 259 insertions(+), 239 deletions(-)
diff --git a/llms/README.md b/llms/README.md
index 0e7dc7fb..eeb3ed6a 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -101,7 +101,8 @@ To see a description of all the arguments you can do:
#### Streaming
For streaming generation, use the `stream_generate` function. This returns a
-generator object which streams the output text. For example,
+generator object which streams the output text, token, and log probabilities.
+For example,
```python
from mlx_lm import load, stream_generate
@@ -116,7 +117,7 @@ prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
-for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
+for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(t, end="", flush=True)
print()
```
diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py
index 7bb06411..987b640d 100644
--- a/llms/mlx_lm/cache_prompt.py
+++ b/llms/mlx_lm/cache_prompt.py
@@ -152,6 +152,7 @@ def main():
model(y[:step_size][None], cache=cache)
mx.eval([c.state for c in cache])
+ mx.metal.clear_cache()
processed += min(y.size, step_size)
y = y[step_size:]
current = time.time()
@@ -165,14 +166,13 @@ def main():
)
print()
- print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
+ print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
print("Saving...")
metadata = {}
metadata["model"] = args.model
metadata["chat_template"] = tokenizer.chat_template
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
- print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
save_prompt_cache(args.prompt_cache_file, cache, metadata)
diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py
index 85d32d5f..c03056a6 100644
--- a/llms/mlx_lm/chat.py
+++ b/llms/mlx_lm/chat.py
@@ -74,7 +74,7 @@ def main():
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
- for response in stream_generate(
+ for response, *_ in stream_generate(
model,
tokenizer,
prompt,
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 29976da2..51169def 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -13,6 +13,8 @@ DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
+DEFAULT_MIN_P = 0.0
+DEFAULT_MIN_TOKENS_TO_KEEP = 1
DEFAULT_SEED = 0
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
DEFAULT_QUANTIZED_KV_START = 5000
@@ -52,6 +54,7 @@ def setup_arg_parser():
)
parser.add_argument(
"--prompt",
+ "-p",
default=DEFAULT_PROMPT,
help="Message to be processed by the model ('-' reads from stdin)",
)
@@ -68,6 +71,15 @@ def setup_arg_parser():
parser.add_argument(
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
)
+ parser.add_argument(
+ "--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
+ )
+ parser.add_argument(
+ "--min-tokens-to-keep",
+ type=float,
+ default=DEFAULT_MIN_TOKENS_TO_KEEP,
+ help="Minimum tokens to keep for min-p sampling.",
+ )
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
parser.add_argument(
"--ignore-chat-template",
@@ -247,6 +259,8 @@ def main():
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
+ min_p=args.min_p,
+ min_tokens_to_keep=args.min_tokens_to_keep,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
kv_bits=args.kv_bits,
diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
index 20b008fa..c27b52d8 100644
--- a/llms/mlx_lm/sample_utils.py
+++ b/llms/mlx_lm/sample_utils.py
@@ -1,10 +1,83 @@
# Copyright © 2023-2024 Apple Inc.
from functools import partial
+from typing import Callable, Dict, Optional
import mlx.core as mx
+def make_sampler(
+ temp: float = 0.0,
+ top_p: float = 0.0,
+ min_p: float = 0.0,
+ min_tokens_to_keep: int = 1,
+) -> Callable[mx.array, mx.array]:
+ """
+ Make a sampler function for use with ``generate_step``.
+
+ Args:
+ temp (float): The temperature for sampling, if 0 the argmax is used.
+ Default: ``0``.
+ top_p (float, optional): Nulceus sampling, higher means model considers
+ more less likely words.
+ min_p (float, optional): The minimum value (scaled by the top token's
+ probability) that a token probability must have to be considered.
+ min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
+ be filtered by min_p sampling.
+
+ Returns:
+ Callable[mx.array, mx.array]:
+ A sampler which takes log-probabilities and returns tokens.
+ """
+ if temp == 0:
+ return lambda x: mx.argmax(x, axis=-1)
+ elif top_p > 0 and top_p < 1.0:
+ return lambda x: top_p_sampling(x, top_p, temp)
+ elif min_p != 0.0:
+ return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
+ else:
+ return lambda x: categorical_sampling(x, temp)
+
+
+def make_logits_processors(
+ logit_bias: Optional[Dict[int, float]] = None,
+ repetition_penalty: Optional[float] = None,
+ repetition_context_size: Optional[int] = 20,
+):
+ """
+ Make logits processors for use with ``generate_step``.
+
+ Args:
+ repetition_penalty (float, optional): The penalty factor for repeating
+ tokens.
+ repetition_context_size (int, optional): The number of tokens to
+ consider for repetition penalty. Default: ``20``.
+ logit_bias (dictionary, optional): Additive logit bias.
+
+ Returns:
+ List[Callable[[mx.array, mx.array], mx.array]]:
+ A list of logits processors. Each processor in the list is a
+ callable which takes an array of tokens and an array of logits
+ and returns the updated logits.
+ """
+ logits_processors = []
+ if logit_bias:
+ indices = mx.array(list(logit_bias.keys()))
+ values = mx.array(list(logit_bias.values()))
+
+ def logit_bias_processor(_, logits):
+ logits[:, indices] += values
+ return logits
+
+ logits_processors.append(logit_bias_processor)
+
+ if repetition_penalty and repetition_penalty != 0.0:
+ logits_processors.append(
+ make_repetition_penalty(repetition_penalty, repetition_context_size)
+ )
+ return logits_processors
+
+
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def min_p_sampling(
logits: mx.array,
@@ -100,3 +173,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def categorical_sampling(logits, temp):
return mx.random.categorical(logits * (1 / temp))
+
+
+def make_repetition_penalty(penalty: float, context_size: int = 20):
+ """
+ Make repetition penalty processor.
+
+ Paper: https://arxiv.org/abs/1909.05858
+
+ Args:
+ penalty (float): The repetition penalty factor to be applied.
+ context_size (int): The number of previous tokens to use.
+ Default: ``20``.
+
+ Returns:
+ Callable[[mx.array, List[int]], mx.array]:
+ The repetition penalty processor.
+ """
+ if penalty < 0 or not isinstance(penalty, float):
+ raise ValueError(f"penalty must be a non-negative float, got {penalty}")
+
+ def repetition_penalty_processor(tokens, logits):
+ if len(tokens) > 0:
+ tokens = tokens[-context_size:]
+ selected_logits = logits[:, tokens]
+ selected_logits = mx.where(
+ selected_logits < 0,
+ selected_logits * penalty,
+ selected_logits / penalty,
+ )
+ logits[:, tokens] = selected_logits
+ return logits
+
+ return repetition_penalty_processor
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index ec659969..c1365b36 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -27,7 +27,7 @@ from huggingface_hub import scan_cache_dir
from ._version import __version__
from .models.cache import make_prompt_cache
-from .utils import generate_step, load
+from .utils import load, stream_generate
def get_system_fingerprint():
@@ -64,7 +64,7 @@ def stopping_criteria(
end if it has (`trim_length`).
"""
if tokens and tokens[-1] == eos_token_id:
- return StopCondition(stop_met=True, trim_length=1)
+ return StopCondition(stop_met=True, trim_length=0)
for stop_ids in stop_id_sequences:
if len(tokens) >= len(stop_ids):
@@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler):
self.max_tokens = self.body.get("max_completion_tokens", None)
if self.max_tokens is None:
self.max_tokens = self.body.get("max_tokens", 512)
- self.temperature = self.body.get("temperature", 1.0)
+ self.temperature = self.body.get("temperature", 0.0)
self.top_p = self.body.get("top_p", 1.0)
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
self.repetition_context_size = self.body.get("repetition_context_size", 20)
@@ -290,10 +290,7 @@ class APIHandler(BaseHTTPRequestHandler):
# Call endpoint specific method
prompt = endpoints[self.path]()
-
- # Call method based on response type
- method = self.handle_stream if self.stream else self.handle_completion
- method(prompt, stop_id_sequences)
+ self.handle_completion(prompt, stop_id_sequences)
def validate_model_parameters(self):
"""
@@ -452,32 +449,40 @@ class APIHandler(BaseHTTPRequestHandler):
stop_id_sequences (List[List[int]]): A list of stop words passed
to the stopping_criteria function
"""
- detokenizer = self.tokenizer.detokenizer
- detokenizer.reset()
tokens = []
finish_reason = "length"
stop_sequence_suffix = None
- logging.debug(f"Starting completion:")
+ if self.stream:
+ self.end_headers()
+ logging.debug(f"Starting stream:")
+ else:
+ logging.debug(f"Starting completion:")
token_logprobs = []
top_tokens = []
prompt = self.get_prompt_cache(prompt)
- for _, (token, logprobs) in zip(
- range(self.max_tokens),
- generate_step(
- prompt=mx.array(prompt),
+ text = ""
+ tic = time.perf_counter()
+ for n, (segment, token, logprobs) in enumerate(
+ stream_generate(
model=self.model,
+ tokenizer=self.tokenizer,
+ prompt=prompt,
+ max_tokens=self.max_tokens,
temp=self.temperature,
- top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache,
),
):
- detokenizer.add_token(token)
- logging.debug(detokenizer.text)
+ if n == 0:
+ prompt_time = time.perf_counter() - tic
+ tic = time.perf_counter()
+
+ text += segment
+ logging.debug(text)
tokens.append(token)
if self.logprobs > 0:
@@ -498,121 +503,63 @@ class APIHandler(BaseHTTPRequestHandler):
stop_sequence_suffix = self.tokenizer.decode(
tokens[-stop_condition.trim_length :]
)
+ text = text[: -len(stop_sequence_suffix)]
break
- self.prompt_cache.tokens.extend(tokens)
- detokenizer.finalize()
- text = (
- detokenizer.text
- if stop_sequence_suffix is None
- else detokenizer.text[: -len(stop_sequence_suffix)]
- )
- response = self.generate_response(
- text,
- finish_reason,
- len(prompt),
- len(tokens),
- token_logprobs=token_logprobs,
- top_tokens=top_tokens,
- tokens=tokens,
- )
-
- response_json = json.dumps(response).encode()
- indent = "\t" # Backslashes can't be inside of f-strings
- logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
-
- # Send an additional Content-Length header when it is known
- self.send_header("Content-Length", str(len(response_json)))
- self.end_headers()
-
- self.wfile.write(response_json)
- self.wfile.flush()
-
- def handle_stream(
- self,
- prompt: List[int],
- stop_id_sequences: List[List[int]],
- ):
- """
- Generate response to prompt and foward it to the client using a Server
- Sent Events (SSE) stream.
-
- Args:
- prompt (mx.array): The tokenized prompt
- stop_id_sequences (List[List[int]]): A list of stop words passed to
- the stopping_criteria function
- """
- # No additional headers are needed, call end_headers
- self.end_headers()
-
- detokenizer = self.tokenizer.detokenizer
- detokenizer.reset()
- tokens = []
-
- stop_sequence_suffix = None
- logging.debug(f"Starting stream:")
-
- prompt = self.get_prompt_cache(prompt)
-
- for _, (token, _) in zip(
- range(self.max_tokens),
- generate_step(
- prompt=mx.array(prompt),
- model=self.model,
- temp=self.temperature,
- top_p=self.top_p,
- repetition_penalty=self.repetition_penalty,
- repetition_context_size=self.repetition_context_size,
- prompt_cache=self.prompt_cache.cache,
- ),
- ):
- detokenizer.add_token(token)
- logging.debug(detokenizer.text)
- tokens.append(token)
-
- stop_condition = stopping_criteria(
- tokens,
- stop_id_sequences,
- self.tokenizer.eos_token_id,
- )
- if stop_condition.stop_met:
- if stop_condition.trim_length:
- stop_sequence_suffix = self.tokenizer.decode(
- tokens[-stop_condition.trim_length :]
+ if self.stream:
+ # If the end of tokens overlaps with a stop sequence, generate new
+ # tokens until we know if the stop sequence is hit or not
+ if any(
+ (
+ sequence_overlap(tokens, sequence)
+ for sequence in stop_id_sequences
)
- break
-
- # If the end of tokens overlaps with a stop sequence, generate new
- # tokens until we know if the stop sequence is hit or not
- if any(
- (sequence_overlap(tokens, sequence) for sequence in stop_id_sequences)
- ):
- continue
-
- new_text = detokenizer.last_segment
- if new_text:
- response = self.generate_response(new_text, None)
- self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
- self.wfile.flush()
+ ):
+ continue
+ elif segment:
+ response = self.generate_response(segment, None)
+ self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+ self.wfile.flush()
self.prompt_cache.tokens.extend(tokens)
- # check is there any remaining text to send
- detokenizer.finalize()
- last_segment = detokenizer.last_segment
- if last_segment:
- if stop_sequence_suffix is not None:
- last_segment = last_segment[: -len(stop_sequence_suffix)]
- response = self.generate_response(last_segment, "length")
+ gen_time = time.perf_counter() - tic
+ prompt_tps = len(prompt) / prompt_time
+ gen_tps = len(tokens) / gen_time
+ peak_mem = mx.metal.get_peak_memory() / 1e9
+ logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
+ logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
+ logging.debug(f"Peak memory: {peak_mem:.3f} GB")
+
+ if self.stream:
+ response = self.generate_response(segment, finish_reason)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
+ if self.stream_options is not None and self.stream_options["include_usage"]:
+ response = self.completion_usage_response(len(prompt), len(tokens))
+ self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+ self.wfile.flush()
+ self.wfile.write("data: [DONE]\n\n".encode())
+ self.wfile.flush()
+ else:
+ response = self.generate_response(
+ text,
+ finish_reason,
+ len(prompt),
+ len(tokens),
+ token_logprobs=token_logprobs,
+ top_tokens=top_tokens,
+ tokens=tokens,
+ )
+ response_json = json.dumps(response).encode()
+ indent = "\t" # Backslashes can't be inside of f-strings
+ logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
- if self.stream_options is not None and self.stream_options["include_usage"]:
- response = self.completion_usage_response(len(prompt), len(tokens))
- self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-
- self.wfile.write("data: [DONE]\n\n".encode())
- self.wfile.flush()
+ # Send an additional Content-Length header when it is known
+ self.send_header("Content-Length", str(len(response_json)))
+ self.end_headers()
+ self.wfile.write(response_json)
+ self.wfile.flush()
def completion_usage_response(
self,
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 38619d95..21b1af18 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -285,7 +285,7 @@ def train(
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
trained_tokens += n_tokens
- peak_mem = mx.metal.get_peak_memory() / 2**30
+ peak_mem = mx.metal.get_peak_memory() / 1e9
if rank == 0:
print(
f"Iter {it}: Train loss {train_loss:.3f}, "
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 7b440db6..8893b570 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
# Local imports
from .models import cache
-from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
+from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model
from .tuner.utils import load_adapters
@@ -34,6 +34,9 @@ MODEL_REMAPPING = {
MAX_FILE_SIZE_GB = 5
+# A stream on the default device just for generation
+generation_stream = mx.new_stream(mx.default_device())
+
class ModelNotFoundError(Exception):
def __init__(self, message):
@@ -137,29 +140,6 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
return model_path
-def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
- """
- Apply repetition penalty to specific logits based on the given context.
-
- Paper: https://arxiv.org/abs/1909.05858
-
- Args:
- logits (mx.array): The logits produced by the language model.
- tokens (mx.array): A list of N previous tokens.
- penalty (float): The repetition penalty factor to be applied.
-
- Returns:
- logits (mx.array): Logits with repetition penalty applied to generated tokens.
- """
- if len(tokens) > 0:
- selected_logits = logits[:, tokens]
- selected_logits = mx.where(
- selected_logits < 0, selected_logits * penalty, selected_logits / penalty
- )
- logits[:, tokens] = selected_logits
- return logits
-
-
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
if (
kv_bits is not None
@@ -185,7 +165,7 @@ def generate_step(
max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None,
- logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+ logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
@@ -214,7 +194,7 @@ def generate_step(
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
logit_bias (dictionary, optional): Additive logit bias.
- logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
+ logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
@@ -224,53 +204,9 @@ def generate_step(
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
- Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
- one token and a vector of log probabilities.
+ Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
"""
- def sample(logits: mx.array) -> Tuple[mx.array, float]:
- logprobs = logits - mx.logsumexp(logits)
-
- if temp == 0:
- token = mx.argmax(logits, axis=-1)
- else:
- if top_p > 0 and top_p < 1.0:
- token = top_p_sampling(logits, top_p, temp)
- elif min_p != 0.0:
- token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
- else:
- token = categorical_sampling(logits, temp)
-
- return token, logprobs
-
- if repetition_penalty and (
- repetition_penalty < 0 or not isinstance(repetition_penalty, float)
- ):
- raise ValueError(
- f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
- )
-
- logits_processor = logits_processor or []
-
- if repetition_penalty:
-
- def repetition_penalty_processor(tokens, logits):
- return apply_repetition_penalty(
- logits, tokens[-repetition_context_size:], repetition_penalty
- )
-
- logits_processor.append(repetition_penalty_processor)
-
- if logit_bias:
- indices = mx.array(list(logit_bias.keys()))
- values = mx.array(list(logit_bias.values()))
-
- def logit_bias_processor(_, logits):
- logits[:, indices] += values
- return logits
-
- logits_processor.append(logit_bias_processor)
-
y = prompt
tokens = None
@@ -283,24 +219,31 @@ def generate_step(
elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.")
+ sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
+ logits_processors = logits_processors or []
+ logits_processors.extend(
+ make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
+ )
+
def _step(y):
+ with mx.stream(generation_stream):
+ logits = model(y[None], cache=prompt_cache)
+ logits = logits[:, -1, :]
- logits = model(y[None], cache=prompt_cache)
- logits = logits[:, -1, :]
+ if logits_processors:
+ nonlocal tokens
+ tokens = mx.concat([tokens, y]) if tokens is not None else y
- if logits_processor:
- nonlocal tokens
- tokens = mx.concat([tokens, y]) if tokens is not None else y
+ for processor in logits_processors:
+ logits = processor(tokens, logits)
- for processor in logits_processor:
- logits = processor(tokens, logits)
+ maybe_quantize_kv_cache(
+ prompt_cache, quantized_kv_start, kv_group_size, kv_bits
+ )
- maybe_quantize_kv_cache(
- prompt_cache, quantized_kv_start, kv_group_size, kv_bits
- )
-
- y, logprobs = sample(logits)
- return y, logprobs.squeeze(0)
+ logprobs = logits - mx.logsumexp(logits, keepdims=True)
+ y = sampler(logprobs)
+ return y, logprobs.squeeze(0)
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=prompt_cache)
@@ -325,43 +268,51 @@ def generate_step(
def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
- prompt: str,
+ prompt: Union[str, List[int]],
max_tokens: int = 100,
**kwargs,
-) -> Union[str, Generator[str, None, None]]:
+) -> Generator[Tuple[str, int, mx.array], None, None]:
"""
A generator producing text based on the given prompt from the model.
Args:
- prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
- max_tokens (int): The ma
+ tokenizer (PreTrainedTokenizer): The tokenizer.
+ prompt (Union[str, List[int]]): The input prompt string or integer tokens.
+ max_tokens (int): The maximum number of tokens. Default: ``100``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
Yields:
- Generator[Tuple[mx.array, mx.array]]: A generator producing text.
+ Tuple[str, int, mx.array]:
+ The next text segment, token, and vector of log probabilities.
"""
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
- prompt_tokens = mx.array(tokenizer.encode(prompt))
+ prompt_tokens = mx.array(
+ prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
+ )
detokenizer = tokenizer.detokenizer
- detokenizer.reset()
- for n, (token, _) in zip(
- range(max_tokens),
- generate_step(prompt_tokens, model, **kwargs),
- ):
- if token == tokenizer.eos_token_id:
- break
- detokenizer.add_token(token)
+ with wired_limit(model, [generation_stream]):
+ detokenizer.reset()
+ for n, (token, logits) in zip(
+ range(max_tokens),
+ generate_step(prompt_tokens, model, **kwargs),
+ ):
+ if token == tokenizer.eos_token_id:
+ break
- # Yield the last segment if streaming
- yield detokenizer.last_segment
+ detokenizer.add_token(token)
- detokenizer.finalize()
- yield detokenizer.last_segment
+ if n == (max_tokens - 1):
+ break
+
+ yield detokenizer.last_segment, token, logits
+
+ detokenizer.finalize()
+ yield detokenizer.last_segment, token, logits
def generate(
@@ -372,7 +323,7 @@ def generate(
verbose: bool = False,
formatter: Optional[Callable] = None,
**kwargs,
-) -> Union[str, Generator[str, None, None]]:
+) -> str:
"""
Generate a complete response from the model.
@@ -398,7 +349,7 @@ def generate(
prompt_tokens = mx.array(tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer
- with wired_limit(model):
+ with wired_limit(model, [generation_stream]):
tic = time.perf_counter()
detokenizer.reset()
for n, (token, logprobs) in zip(
@@ -416,8 +367,7 @@ def generate(
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
- with mx.stream(mx.cpu):
- prob = mx.exp(logprobs[token]).item()
+ prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob)
else:
print(detokenizer.last_segment, end="", flush=True)
@@ -438,7 +388,7 @@ def generate(
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
)
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
- peak_mem = mx.metal.get_peak_memory() / 2**30
+ peak_mem = mx.metal.get_peak_memory() / 1e9
print(f"Peak memory: {peak_mem:.3f} GB")
return detokenizer.text
@@ -623,7 +573,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
f"""
# {upload_repo}
- The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.
+ The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
+ converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path})
+ using mlx-lm version **{__version__}**.
## Use with mlx
diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py
index 68f1670b..e0a372a9 100644
--- a/llms/tests/test_generate.py
+++ b/llms/tests/test_generate.py
@@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase):
"hello",
max_tokens=5,
verbose=False,
- logits_processor=[logits_processor],
+ logits_processors=[logits_processor],
)
self.assertEqual(len(all_toks), len(init_toks) + 5)
diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py
index 1e57bd86..0867ab56 100644
--- a/llms/tests/test_prompt_cache.py
+++ b/llms/tests/test_prompt_cache.py
@@ -299,7 +299,7 @@ class TestPromptCache(unittest.TestCase):
):
i += 1
self.assertEqual(tok, toks[i])
- self.assertTrue(mx.allclose(logits, all_logits[i], rtol=1e-2))
+ self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))
if __name__ == "__main__":