Merge branch 'ml-explore:main' into completion_only

This commit is contained in:
Chime Ogbuji 2024-11-04 22:00:55 -05:00 committed by GitHub
commit a1fbc52cf2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 75 additions and 26 deletions

View File

@ -221,6 +221,7 @@ Here are a few examples of Hugging Face models that work with this example:
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct) - [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b) - [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b) - [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
Most Most
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending), [Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),

View File

@ -11,6 +11,7 @@ from .utils import load, stream_generate
DEFAULT_TEMP = 0.0 DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0 DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0 DEFAULT_SEED = 0
DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
@ -41,6 +42,13 @@ def setup_arg_parser():
help="Set the maximum key-value cache size", help="Set the maximum key-value cache size",
default=None, default=None,
) )
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=DEFAULT_MAX_TOKENS,
help="Maximum number of tokens to generate",
)
return parser return parser
@ -70,6 +78,7 @@ def main():
model, model,
tokenizer, tokenizer,
prompt, prompt,
args.max_tokens,
temp=args.temp, temp=args.temp,
top_p=args.top_p, top_p=args.top_p,
prompt_cache=prompt_cache, prompt_cache=prompt_cache,

View File

@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs):
use_conv_bias: bool use_conv_bias: bool
time_step_rank: int time_step_rank: int
tie_word_embeddings: bool = True tie_word_embeddings: bool = True
use_bcdt_rms: bool = False
mixer_rms_eps: float = 1e-6
def __post_init__(self): def __post_init__(self):
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"): if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs):
if self.time_step_rank == "auto": if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16) self.time_step_rank = math.ceil(self.hidden_size / 16)
if self.model_type == "falcon_mamba":
self.use_bcdt_rms = True
class DepthWiseConv1d(nn.Module): class DepthWiseConv1d(nn.Module):
@ -83,6 +87,11 @@ class MambaBlock(nn.Module):
self.intermediate_size = args.intermediate_size self.intermediate_size = args.intermediate_size
self.time_step_rank = int(args.time_step_rank) self.time_step_rank = int(args.time_step_rank)
self.use_conv_bias = args.use_conv_bias self.use_conv_bias = args.use_conv_bias
self.use_bcdt_rms = args.use_bcdt_rms
if self.use_bcdt_rms:
self.mixer_norm = lambda x: mx.fast.rms_norm(
x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
)
self.in_proj = nn.Linear( self.in_proj = nn.Linear(
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
@ -126,6 +135,8 @@ class MambaBlock(nn.Module):
], ],
axis=-1, axis=-1,
) )
if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C))
delta = nn.softplus(self.dt_proj(delta)) delta = nn.softplus(self.dt_proj(delta))
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1) new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
if state is not None: if state is not None:

View File

@ -29,6 +29,7 @@ from .tuner.utils import load_adapters
MODEL_REMAPPING = { MODEL_REMAPPING = {
"mistral": "llama", # mistral is compatible with llama "mistral": "llama", # mistral is compatible with llama
"phi-msft": "phixtral", "phi-msft": "phixtral",
"falcon_mamba": "mamba",
} }
MAX_FILE_SIZE_GB = 5 MAX_FILE_SIZE_GB = 5

View File

@ -25,7 +25,7 @@ pip install mlx-whisper
At its simplest: At its simplest:
``` ```sh
mlx_whisper audio_file.mp3 mlx_whisper audio_file.mp3
``` ```
@ -35,6 +35,15 @@ Use `-f` to specify the output format and `--model` to specify the model. There
are many other supported command line options. To see them all, run are many other supported command line options. To see them all, run
`mlx_whisper -h`. `mlx_whisper -h`.
You can also pipe the audio content of other programs via stdin:
```sh
some-process | mlx_whisper -
```
The default output file name will be `content.*`. You can specify the name with
the `--output-name` flag.
#### API #### API
Transcribe audio with: Transcribe audio with:
@ -103,7 +112,7 @@ python convert.py --help
``` ```
By default, the conversion script will make the directory `mlx_models` By default, the conversion script will make the directory `mlx_models`
and save the converted `weights.npz` and `config.json` there. and save the converted `weights.npz` and `config.json` there.
Each time it is run, `convert.py` will overwrite any model in the provided Each time it is run, `convert.py` will overwrite any model in the provided
path. To save different models, make sure to set `--mlx-path` to a unique path. To save different models, make sure to set `--mlx-path` to a unique

View File

@ -3,7 +3,7 @@
import os import os
from functools import lru_cache from functools import lru_cache
from subprocess import CalledProcessError, run from subprocess import CalledProcessError, run
from typing import Union from typing import Optional, Union
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE): def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_stdin=False):
""" """
Open an audio file and read as mono waveform, resampling as necessary Open an audio file and read as mono waveform, resampling as necessary
@ -39,19 +39,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
""" """
# This launches a subprocess to decode audio while down-mixing # This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH. # and resampling as necessary. Requires the ffmpeg CLI in PATH.
if from_stdin:
cmd = ["ffmpeg", "-i", "pipe:0"]
else:
cmd = ["ffmpeg", "-nostdin", "-i", file]
# fmt: off # fmt: off
cmd = [ cmd.extend([
"ffmpeg",
"-nostdin",
"-threads", "0", "-threads", "0",
"-i", file,
"-f", "s16le", "-f", "s16le",
"-ac", "1", "-ac", "1",
"-acodec", "pcm_s16le", "-acodec", "pcm_s16le",
"-ar", str(sr), "-ar", str(sr),
"-" "-"
] ])
# fmt: on # fmt: on
try: try:
out = run(cmd, capture_output=True, check=True).stdout out = run(cmd, capture_output=True, check=True).stdout

View File

@ -2,9 +2,11 @@
import argparse import argparse
import os import os
import pathlib
import traceback import traceback
import warnings import warnings
from . import audio
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from .transcribe import transcribe from .transcribe import transcribe
from .writers import get_writer from .writers import get_writer
@ -27,15 +29,24 @@ def build_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
) )
parser.add_argument(
"audio", nargs="+", type=str, help="Audio file(s) to transcribe" parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
)
parser.add_argument( parser.add_argument(
"--model", "--model",
default="mlx-community/whisper-tiny", default="mlx-community/whisper-tiny",
type=str, type=str,
help="The model directory or hugging face repo", help="The model directory or hugging face repo",
) )
parser.add_argument(
"--output-name",
type=str,
default=None,
help=(
"The name of transcription/translation output files before "
"--output-format extensions"
),
)
parser.add_argument( parser.add_argument(
"--output-dir", "--output-dir",
"-o", "-o",
@ -200,6 +211,7 @@ def main():
path_or_hf_repo: str = args.pop("model") path_or_hf_repo: str = args.pop("model")
output_dir: str = args.pop("output_dir") output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format") output_format: str = args.pop("output_format")
output_name: str = args.pop("output_name")
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
writer = get_writer(output_format, output_dir) writer = get_writer(output_format, output_dir)
@ -219,17 +231,25 @@ def main():
warnings.warn("--max-line-count has no effect without --max-line-width") warnings.warn("--max-line-count has no effect without --max-line-width")
if writer_args["max_words_per_line"] and writer_args["max_line_width"]: if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
warnings.warn("--max-words-per-line has no effect with --max-line-width") warnings.warn("--max-words-per-line has no effect with --max-line-width")
for audio_path in args.pop("audio"):
for audio_obj in args.pop("audio"):
if audio_obj == "-":
# receive the contents from stdin rather than read a file
audio_obj = audio.load_audio(from_stdin=True)
output_name = output_name or "content"
else:
output_name = output_name or pathlib.Path(audio_obj).stem
try: try:
result = transcribe( result = transcribe(
audio_path, audio_obj,
path_or_hf_repo=path_or_hf_repo, path_or_hf_repo=path_or_hf_repo,
**args, **args,
) )
writer(result, audio_path, **writer_args) writer(result, output_name, **writer_args)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}") print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,10 +1,8 @@
# Copyright © 2024 Apple Inc. # Copyright © 2024 Apple Inc.
import json import json
import os import pathlib
import re import re
import sys
import zlib
from typing import Callable, List, Optional, TextIO from typing import Callable, List, Optional, TextIO
@ -43,15 +41,13 @@ class ResultWriter:
self.output_dir = output_dir self.output_dir = output_dir
def __call__( def __call__(
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs
): ):
audio_basename = os.path.basename(audio_path) output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix(
audio_basename = os.path.splitext(audio_basename)[0] f".{self.extension}"
output_path = os.path.join(
self.output_dir, audio_basename + "." + self.extension
) )
with open(output_path, "w", encoding="utf-8") as f: with output_path.open("wt", encoding="utf-8") as f:
self.write_result(result, file=f, options=options, **kwargs) self.write_result(result, file=f, options=options, **kwargs)
def write_result( def write_result(