From 82e333898707eb57235f408aa6907beca095f759 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Mon, 4 Nov 2024 22:06:34 +0800
Subject: [PATCH 1/3] chore(mlx-lm): add max token arg for mlx_lm.chat (#1089)
* chore(mlx-lm): add max token arg for mlx_lm.chat
* chore: update the default max token value
---
llms/mlx_lm/chat.py | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py
index ea1a99c7..85d32d5f 100644
--- a/llms/mlx_lm/chat.py
+++ b/llms/mlx_lm/chat.py
@@ -11,6 +11,7 @@ from .utils import load, stream_generate
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
+DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
@@ -41,6 +42,13 @@ def setup_arg_parser():
help="Set the maximum key-value cache size",
default=None,
)
+ parser.add_argument(
+ "--max-tokens",
+ "-m",
+ type=int,
+ default=DEFAULT_MAX_TOKENS,
+ help="Maximum number of tokens to generate",
+ )
return parser
@@ -70,6 +78,7 @@ def main():
model,
tokenizer,
prompt,
+ args.max_tokens,
temp=args.temp,
top_p=args.top_p,
prompt_cache=prompt_cache,
From 3b526f0aa1219fae662a86f012dbda82045f4fb0 Mon Sep 17 00:00:00 2001
From: ilyasch2 <104485953+ilyasch2@users.noreply.github.com>
Date: Tue, 5 Nov 2024 00:23:30 +0400
Subject: [PATCH 2/3] Add support for falcon-mamba (#1074)
* Add support for falcon-mamba
* nits
* nit
---------
Co-authored-by: Awni Hannun
---
llms/README.md | 1 +
llms/mlx_lm/models/mamba.py | 11 +++++++++++
llms/mlx_lm/utils.py | 1 +
3 files changed, 13 insertions(+)
diff --git a/llms/README.md b/llms/README.md
index f539988a..0e7dc7fb 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -221,6 +221,7 @@ Here are a few examples of Hugging Face models that work with this example:
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
+- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
Most
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index 84f498e9..f2414660 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs):
use_conv_bias: bool
time_step_rank: int
tie_word_embeddings: bool = True
+ use_bcdt_rms: bool = False
+ mixer_rms_eps: float = 1e-6
def __post_init__(self):
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
@@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs):
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
+ if self.model_type == "falcon_mamba":
+ self.use_bcdt_rms = True
class DepthWiseConv1d(nn.Module):
@@ -83,6 +87,11 @@ class MambaBlock(nn.Module):
self.intermediate_size = args.intermediate_size
self.time_step_rank = int(args.time_step_rank)
self.use_conv_bias = args.use_conv_bias
+ self.use_bcdt_rms = args.use_bcdt_rms
+ if self.use_bcdt_rms:
+ self.mixer_norm = lambda x: mx.fast.rms_norm(
+ x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
+ )
self.in_proj = nn.Linear(
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
@@ -126,6 +135,8 @@ class MambaBlock(nn.Module):
],
axis=-1,
)
+ if self.use_bcdt_rms:
+ delta, B, C = map(self.mixer_norm, (delta, B, C))
delta = nn.softplus(self.dt_proj(delta))
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
if state is not None:
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index b9fc202d..7b440db6 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -29,6 +29,7 @@ from .tuner.utils import load_adapters
MODEL_REMAPPING = {
"mistral": "llama", # mistral is compatible with llama
"phi-msft": "phixtral",
+ "falcon_mamba": "mamba",
}
MAX_FILE_SIZE_GB = 5
From 4394633ce0f9d96cbbdf571e077fa4fd78479b9f Mon Sep 17 00:00:00 2001
From: Anthony Wu <462072+anthonywu@users.noreply.github.com>
Date: Mon, 4 Nov 2024 14:02:13 -0800
Subject: [PATCH 3/3] mlx_whisper: add support for audio input from stdin
(#1012)
* add support for audio and input name from stdin
* refactored to stdin - arg, and output-name template
* fix bugs, add test coverage
* fix doc to match arg rename
* some nits
---------
Co-authored-by: Awni Hannun
---
whisper/README.md | 13 +++++++++++--
whisper/mlx_whisper/audio.py | 18 ++++++++++--------
whisper/mlx_whisper/cli.py | 34 +++++++++++++++++++++++++++-------
whisper/mlx_whisper/writers.py | 14 +++++---------
4 files changed, 53 insertions(+), 26 deletions(-)
diff --git a/whisper/README.md b/whisper/README.md
index ac6e95f6..cd3bc684 100644
--- a/whisper/README.md
+++ b/whisper/README.md
@@ -25,7 +25,7 @@ pip install mlx-whisper
At its simplest:
-```
+```sh
mlx_whisper audio_file.mp3
```
@@ -35,6 +35,15 @@ Use `-f` to specify the output format and `--model` to specify the model. There
are many other supported command line options. To see them all, run
`mlx_whisper -h`.
+To read audio piped from another program via stdin, pass `-` as the audio file:
+
+```sh
+some-process | mlx_whisper -
+```
+
+When reading from stdin, the default output file name is `content.*`. You can
+specify the name with the `--output-name` flag.
+
#### API
Transcribe audio with:
@@ -103,7 +112,7 @@ python convert.py --help
```
By default, the conversion script will make the directory `mlx_models`
-and save the converted `weights.npz` and `config.json` there.
+and save the converted `weights.npz` and `config.json` there.
Each time it is run, `convert.py` will overwrite any model in the provided
path. To save different models, make sure to set `--mlx-path` to a unique
diff --git a/whisper/mlx_whisper/audio.py b/whisper/mlx_whisper/audio.py
index e04309c1..c8cca07c 100644
--- a/whisper/mlx_whisper/audio.py
+++ b/whisper/mlx_whisper/audio.py
@@ -3,7 +3,7 @@
import os
from functools import lru_cache
from subprocess import CalledProcessError, run
-from typing import Union
+from typing import Optional, Union
import mlx.core as mx
import numpy as np
@@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
-def load_audio(file: str, sr: int = SAMPLE_RATE):
+def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_stdin=False):
"""
Open an audio file and read as mono waveform, resampling as necessary
@@ -39,19 +39,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
# This launches a subprocess to decode audio while down-mixing
- # and resampling as necessary. Requires the ffmpeg CLI in PATH.
+ # and resampling as necessary. Requires the ffmpeg CLI in PATH.
+ if from_stdin:
+ cmd = ["ffmpeg", "-i", "pipe:0"]
+ else:
+ cmd = ["ffmpeg", "-nostdin", "-i", file]
+
# fmt: off
- cmd = [
- "ffmpeg",
- "-nostdin",
+ cmd.extend([
"-threads", "0",
- "-i", file,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr),
"-"
- ]
+ ])
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
diff --git a/whisper/mlx_whisper/cli.py b/whisper/mlx_whisper/cli.py
index c2813338..7d08a043 100644
--- a/whisper/mlx_whisper/cli.py
+++ b/whisper/mlx_whisper/cli.py
@@ -2,9 +2,11 @@
import argparse
import os
+import pathlib
import traceback
import warnings
+from . import audio
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from .transcribe import transcribe
from .writers import get_writer
@@ -27,15 +29,24 @@ def build_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
- parser.add_argument(
- "audio", nargs="+", type=str, help="Audio file(s) to transcribe"
- )
+
+ parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
+
parser.add_argument(
"--model",
default="mlx-community/whisper-tiny",
type=str,
help="The model directory or hugging face repo",
)
+ parser.add_argument(
+ "--output-name",
+ type=str,
+ default=None,
+ help=(
+ "The name of transcription/translation output files before "
+ "--output-format extensions"
+ ),
+ )
parser.add_argument(
"--output-dir",
"-o",
@@ -200,6 +211,7 @@ def main():
path_or_hf_repo: str = args.pop("model")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
+ output_name: str = args.pop("output_name")
os.makedirs(output_dir, exist_ok=True)
writer = get_writer(output_format, output_dir)
@@ -219,17 +231,25 @@ def main():
warnings.warn("--max-line-count has no effect without --max-line-width")
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
warnings.warn("--max-words-per-line has no effect with --max-line-width")
- for audio_path in args.pop("audio"):
+
+ for audio_obj in args.pop("audio"):
+ if audio_obj == "-":
+ # receive the contents from stdin rather than read a file
+ audio_obj = audio.load_audio(from_stdin=True)
+
+ output_name = output_name or "content"
+ else:
+ output_name = output_name or pathlib.Path(audio_obj).stem
try:
result = transcribe(
- audio_path,
+ audio_obj,
path_or_hf_repo=path_or_hf_repo,
**args,
)
- writer(result, audio_path, **writer_args)
+ writer(result, output_name, **writer_args)
except Exception as e:
traceback.print_exc()
- print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
+ print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
if __name__ == "__main__":
diff --git a/whisper/mlx_whisper/writers.py b/whisper/mlx_whisper/writers.py
index 464ead18..cdb35063 100644
--- a/whisper/mlx_whisper/writers.py
+++ b/whisper/mlx_whisper/writers.py
@@ -1,10 +1,8 @@
# Copyright © 2024 Apple Inc.
import json
-import os
+import pathlib
import re
-import sys
-import zlib
from typing import Callable, List, Optional, TextIO
@@ -43,15 +41,13 @@ class ResultWriter:
self.output_dir = output_dir
def __call__(
- self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
+ self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs
):
- audio_basename = os.path.basename(audio_path)
- audio_basename = os.path.splitext(audio_basename)[0]
- output_path = os.path.join(
- self.output_dir, audio_basename + "." + self.extension
+ output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix(
+ f".{self.extension}"
)
- with open(output_path, "w", encoding="utf-8") as f:
+ with output_path.open("wt", encoding="utf-8") as f:
self.write_result(result, file=f, options=options, **kwargs)
def write_result(