Merge branch 'ml-explore:main' into completion_only

This commit is contained in:
Chime Ogbuji 2024-11-04 22:00:55 -05:00 committed by GitHub
commit a1fbc52cf2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 75 additions and 26 deletions

View File

@ -221,6 +221,7 @@ Here are a few examples of Hugging Face models that work with this example:
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
Most
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),

View File

@ -11,6 +11,7 @@ from .utils import load, stream_generate
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
@ -41,6 +42,13 @@ def setup_arg_parser():
help="Set the maximum key-value cache size",
default=None,
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=DEFAULT_MAX_TOKENS,
help="Maximum number of tokens to generate",
)
return parser
@ -70,6 +78,7 @@ def main():
model,
tokenizer,
prompt,
args.max_tokens,
temp=args.temp,
top_p=args.top_p,
prompt_cache=prompt_cache,

View File

@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs):
use_conv_bias: bool
time_step_rank: int
tie_word_embeddings: bool = True
use_bcdt_rms: bool = False
mixer_rms_eps: float = 1e-6
def __post_init__(self):
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs):
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
if self.model_type == "falcon_mamba":
self.use_bcdt_rms = True
class DepthWiseConv1d(nn.Module):
@ -83,6 +87,11 @@ class MambaBlock(nn.Module):
self.intermediate_size = args.intermediate_size
self.time_step_rank = int(args.time_step_rank)
self.use_conv_bias = args.use_conv_bias
self.use_bcdt_rms = args.use_bcdt_rms
if self.use_bcdt_rms:
self.mixer_norm = lambda x: mx.fast.rms_norm(
x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
)
self.in_proj = nn.Linear(
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
@ -126,6 +135,8 @@ class MambaBlock(nn.Module):
],
axis=-1,
)
if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C))
delta = nn.softplus(self.dt_proj(delta))
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
if state is not None:

View File

@ -29,6 +29,7 @@ from .tuner.utils import load_adapters
MODEL_REMAPPING = {
"mistral": "llama", # mistral is compatible with llama
"phi-msft": "phixtral",
"falcon_mamba": "mamba",
}
MAX_FILE_SIZE_GB = 5

View File

@ -25,7 +25,7 @@ pip install mlx-whisper
At its simplest:
```
```sh
mlx_whisper audio_file.mp3
```
@ -35,6 +35,15 @@ Use `-f` to specify the output format and `--model` to specify the model. There
are many other supported command line options. To see them all, run
`mlx_whisper -h`.
You can also pipe the audio content of other programs via stdin:
```sh
some-process | mlx_whisper -
```
The default output file name will be `content.*`. You can specify the name with
the `--output-name` flag.
#### API
Transcribe audio with:
@ -103,7 +112,7 @@ python convert.py --help
```
By default, the conversion script will make the directory `mlx_models`
and save the converted `weights.npz` and `config.json` there.
and save the converted `weights.npz` and `config.json` there.
Each time it is run, `convert.py` will overwrite any model in the provided
path. To save different models, make sure to set `--mlx-path` to a unique

View File

@ -3,7 +3,7 @@
import os
from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Union
from typing import Optional, Union
import mlx.core as mx
import numpy as np
@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE):
def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_stdin=False):
"""
Open an audio file and read as mono waveform, resampling as necessary
@ -39,19 +39,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
if from_stdin:
cmd = ["ffmpeg", "-i", "pipe:0"]
else:
cmd = ["ffmpeg", "-nostdin", "-i", file]
# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
cmd.extend([
"-threads", "0",
"-i", file,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr),
"-"
]
])
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout

View File

@ -2,9 +2,11 @@
import argparse
import os
import pathlib
import traceback
import warnings
from . import audio
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from .transcribe import transcribe
from .writers import get_writer
@ -27,15 +29,24 @@ def build_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"audio", nargs="+", type=str, help="Audio file(s) to transcribe"
)
parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
parser.add_argument(
"--model",
default="mlx-community/whisper-tiny",
type=str,
help="The model directory or hugging face repo",
)
parser.add_argument(
"--output-name",
type=str,
default=None,
help=(
"The name of transcription/translation output files before "
"--output-format extensions"
),
)
parser.add_argument(
"--output-dir",
"-o",
@ -200,6 +211,7 @@ def main():
path_or_hf_repo: str = args.pop("model")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
output_name: str = args.pop("output_name")
os.makedirs(output_dir, exist_ok=True)
writer = get_writer(output_format, output_dir)
@ -219,17 +231,25 @@ def main():
warnings.warn("--max-line-count has no effect without --max-line-width")
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
warnings.warn("--max-words-per-line has no effect with --max-line-width")
for audio_path in args.pop("audio"):
for audio_obj in args.pop("audio"):
if audio_obj == "-":
# receive the contents from stdin rather than read a file
audio_obj = audio.load_audio(from_stdin=True)
output_name = output_name or "content"
else:
output_name = output_name or pathlib.Path(audio_obj).stem
try:
result = transcribe(
audio_path,
audio_obj,
path_or_hf_repo=path_or_hf_repo,
**args,
)
writer(result, audio_path, **writer_args)
writer(result, output_name, **writer_args)
except Exception as e:
traceback.print_exc()
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
if __name__ == "__main__":

View File

@ -1,10 +1,8 @@
# Copyright © 2024 Apple Inc.
import json
import os
import pathlib
import re
import sys
import zlib
from typing import Callable, List, Optional, TextIO
@ -43,15 +41,13 @@ class ResultWriter:
self.output_dir = output_dir
def __call__(
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs
):
audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join(
self.output_dir, audio_basename + "." + self.extension
output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix(
f".{self.extension}"
)
with open(output_path, "w", encoding="utf-8") as f:
with output_path.open("wt", encoding="utf-8") as f:
self.write_result(result, file=f, options=options, **kwargs)
def write_result(