mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Merge branch 'ml-explore:main' into completion_only
This commit is contained in:
commit
a1fbc52cf2
@ -221,6 +221,7 @@ Here are a few examples of Hugging Face models that work with this example:
|
|||||||
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
|
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
|
||||||
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
|
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
|
||||||
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
|
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
|
||||||
|
- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
|
||||||
|
|
||||||
Most
|
Most
|
||||||
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
|
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
|
||||||
|
@ -11,6 +11,7 @@ from .utils import load, stream_generate
|
|||||||
DEFAULT_TEMP = 0.0
|
DEFAULT_TEMP = 0.0
|
||||||
DEFAULT_TOP_P = 1.0
|
DEFAULT_TOP_P = 1.0
|
||||||
DEFAULT_SEED = 0
|
DEFAULT_SEED = 0
|
||||||
|
DEFAULT_MAX_TOKENS = 256
|
||||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||||
|
|
||||||
|
|
||||||
@ -41,6 +42,13 @@ def setup_arg_parser():
|
|||||||
help="Set the maximum key-value cache size",
|
help="Set the maximum key-value cache size",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-tokens",
|
||||||
|
"-m",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_MAX_TOKENS,
|
||||||
|
help="Maximum number of tokens to generate",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -70,6 +78,7 @@ def main():
|
|||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
prompt,
|
prompt,
|
||||||
|
args.max_tokens,
|
||||||
temp=args.temp,
|
temp=args.temp,
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
prompt_cache=prompt_cache,
|
prompt_cache=prompt_cache,
|
||||||
|
@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs):
|
|||||||
use_conv_bias: bool
|
use_conv_bias: bool
|
||||||
time_step_rank: int
|
time_step_rank: int
|
||||||
tie_word_embeddings: bool = True
|
tie_word_embeddings: bool = True
|
||||||
|
use_bcdt_rms: bool = False
|
||||||
|
mixer_rms_eps: float = 1e-6
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
|
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
|
||||||
@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs):
|
|||||||
|
|
||||||
if self.time_step_rank == "auto":
|
if self.time_step_rank == "auto":
|
||||||
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
||||||
|
if self.model_type == "falcon_mamba":
|
||||||
|
self.use_bcdt_rms = True
|
||||||
|
|
||||||
|
|
||||||
class DepthWiseConv1d(nn.Module):
|
class DepthWiseConv1d(nn.Module):
|
||||||
@ -83,6 +87,11 @@ class MambaBlock(nn.Module):
|
|||||||
self.intermediate_size = args.intermediate_size
|
self.intermediate_size = args.intermediate_size
|
||||||
self.time_step_rank = int(args.time_step_rank)
|
self.time_step_rank = int(args.time_step_rank)
|
||||||
self.use_conv_bias = args.use_conv_bias
|
self.use_conv_bias = args.use_conv_bias
|
||||||
|
self.use_bcdt_rms = args.use_bcdt_rms
|
||||||
|
if self.use_bcdt_rms:
|
||||||
|
self.mixer_norm = lambda x: mx.fast.rms_norm(
|
||||||
|
x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
|
||||||
|
)
|
||||||
|
|
||||||
self.in_proj = nn.Linear(
|
self.in_proj = nn.Linear(
|
||||||
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
|
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
|
||||||
@ -126,6 +135,8 @@ class MambaBlock(nn.Module):
|
|||||||
],
|
],
|
||||||
axis=-1,
|
axis=-1,
|
||||||
)
|
)
|
||||||
|
if self.use_bcdt_rms:
|
||||||
|
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
||||||
delta = nn.softplus(self.dt_proj(delta))
|
delta = nn.softplus(self.dt_proj(delta))
|
||||||
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
|
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
|
||||||
if state is not None:
|
if state is not None:
|
||||||
|
@ -29,6 +29,7 @@ from .tuner.utils import load_adapters
|
|||||||
MODEL_REMAPPING = {
|
MODEL_REMAPPING = {
|
||||||
"mistral": "llama", # mistral is compatible with llama
|
"mistral": "llama", # mistral is compatible with llama
|
||||||
"phi-msft": "phixtral",
|
"phi-msft": "phixtral",
|
||||||
|
"falcon_mamba": "mamba",
|
||||||
}
|
}
|
||||||
|
|
||||||
MAX_FILE_SIZE_GB = 5
|
MAX_FILE_SIZE_GB = 5
|
||||||
|
@ -25,7 +25,7 @@ pip install mlx-whisper
|
|||||||
|
|
||||||
At its simplest:
|
At its simplest:
|
||||||
|
|
||||||
```
|
```sh
|
||||||
mlx_whisper audio_file.mp3
|
mlx_whisper audio_file.mp3
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -35,6 +35,15 @@ Use `-f` to specify the output format and `--model` to specify the model. There
|
|||||||
are many other supported command line options. To see them all, run
|
are many other supported command line options. To see them all, run
|
||||||
`mlx_whisper -h`.
|
`mlx_whisper -h`.
|
||||||
|
|
||||||
|
You can also pipe the audio content of other programs via stdin:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
some-process | mlx_whisper -
|
||||||
|
```
|
||||||
|
|
||||||
|
The default output file name will be `content.*`. You can specify the name with
|
||||||
|
the `--output-name` flag.
|
||||||
|
|
||||||
#### API
|
#### API
|
||||||
|
|
||||||
Transcribe audio with:
|
Transcribe audio with:
|
||||||
@ -103,7 +112,7 @@ python convert.py --help
|
|||||||
```
|
```
|
||||||
|
|
||||||
By default, the conversion script will make the directory `mlx_models`
|
By default, the conversion script will make the directory `mlx_models`
|
||||||
and save the converted `weights.npz` and `config.json` there.
|
and save the converted `weights.npz` and `config.json` there.
|
||||||
|
|
||||||
Each time it is run, `convert.py` will overwrite any model in the provided
|
Each time it is run, `convert.py` will overwrite any model in the provided
|
||||||
path. To save different models, make sure to set `--mlx-path` to a unique
|
path. To save different models, make sure to set `--mlx-path` to a unique
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
import os
|
import os
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from subprocess import CalledProcessError, run
|
from subprocess import CalledProcessError, run
|
||||||
from typing import Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame
|
|||||||
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
|
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token
|
||||||
|
|
||||||
|
|
||||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_stdin=False):
|
||||||
"""
|
"""
|
||||||
Open an audio file and read as mono waveform, resampling as necessary
|
Open an audio file and read as mono waveform, resampling as necessary
|
||||||
|
|
||||||
@ -39,19 +39,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# This launches a subprocess to decode audio while down-mixing
|
# This launches a subprocess to decode audio while down-mixing
|
||||||
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
||||||
|
if from_stdin:
|
||||||
|
cmd = ["ffmpeg", "-i", "pipe:0"]
|
||||||
|
else:
|
||||||
|
cmd = ["ffmpeg", "-nostdin", "-i", file]
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
cmd = [
|
cmd.extend([
|
||||||
"ffmpeg",
|
|
||||||
"-nostdin",
|
|
||||||
"-threads", "0",
|
"-threads", "0",
|
||||||
"-i", file,
|
|
||||||
"-f", "s16le",
|
"-f", "s16le",
|
||||||
"-ac", "1",
|
"-ac", "1",
|
||||||
"-acodec", "pcm_s16le",
|
"-acodec", "pcm_s16le",
|
||||||
"-ar", str(sr),
|
"-ar", str(sr),
|
||||||
"-"
|
"-"
|
||||||
]
|
])
|
||||||
# fmt: on
|
# fmt: on
|
||||||
try:
|
try:
|
||||||
out = run(cmd, capture_output=True, check=True).stdout
|
out = run(cmd, capture_output=True, check=True).stdout
|
||||||
|
@ -2,9 +2,11 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import pathlib
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
from . import audio
|
||||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||||||
from .transcribe import transcribe
|
from .transcribe import transcribe
|
||||||
from .writers import get_writer
|
from .writers import get_writer
|
||||||
@ -27,15 +29,24 @@ def build_parser():
|
|||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"audio", nargs="+", type=str, help="Audio file(s) to transcribe"
|
parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
default="mlx-community/whisper-tiny",
|
default="mlx-community/whisper-tiny",
|
||||||
type=str,
|
type=str,
|
||||||
help="The model directory or hugging face repo",
|
help="The model directory or hugging face repo",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-name",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"The name of transcription/translation output files before "
|
||||||
|
"--output-format extensions"
|
||||||
|
),
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-dir",
|
"--output-dir",
|
||||||
"-o",
|
"-o",
|
||||||
@ -200,6 +211,7 @@ def main():
|
|||||||
path_or_hf_repo: str = args.pop("model")
|
path_or_hf_repo: str = args.pop("model")
|
||||||
output_dir: str = args.pop("output_dir")
|
output_dir: str = args.pop("output_dir")
|
||||||
output_format: str = args.pop("output_format")
|
output_format: str = args.pop("output_format")
|
||||||
|
output_name: str = args.pop("output_name")
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
writer = get_writer(output_format, output_dir)
|
writer = get_writer(output_format, output_dir)
|
||||||
@ -219,17 +231,25 @@ def main():
|
|||||||
warnings.warn("--max-line-count has no effect without --max-line-width")
|
warnings.warn("--max-line-count has no effect without --max-line-width")
|
||||||
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
|
if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
|
||||||
warnings.warn("--max-words-per-line has no effect with --max-line-width")
|
warnings.warn("--max-words-per-line has no effect with --max-line-width")
|
||||||
for audio_path in args.pop("audio"):
|
|
||||||
|
for audio_obj in args.pop("audio"):
|
||||||
|
if audio_obj == "-":
|
||||||
|
# receive the contents from stdin rather than read a file
|
||||||
|
audio_obj = audio.load_audio(from_stdin=True)
|
||||||
|
|
||||||
|
output_name = output_name or "content"
|
||||||
|
else:
|
||||||
|
output_name = output_name or pathlib.Path(audio_obj).stem
|
||||||
try:
|
try:
|
||||||
result = transcribe(
|
result = transcribe(
|
||||||
audio_path,
|
audio_obj,
|
||||||
path_or_hf_repo=path_or_hf_repo,
|
path_or_hf_repo=path_or_hf_repo,
|
||||||
**args,
|
**args,
|
||||||
)
|
)
|
||||||
writer(result, audio_path, **writer_args)
|
writer(result, output_name, **writer_args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
|
print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import pathlib
|
||||||
import re
|
import re
|
||||||
import sys
|
|
||||||
import zlib
|
|
||||||
from typing import Callable, List, Optional, TextIO
|
from typing import Callable, List, Optional, TextIO
|
||||||
|
|
||||||
|
|
||||||
@ -43,15 +41,13 @@ class ResultWriter:
|
|||||||
self.output_dir = output_dir
|
self.output_dir = output_dir
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs
|
||||||
):
|
):
|
||||||
audio_basename = os.path.basename(audio_path)
|
output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix(
|
||||||
audio_basename = os.path.splitext(audio_basename)[0]
|
f".{self.extension}"
|
||||||
output_path = os.path.join(
|
|
||||||
self.output_dir, audio_basename + "." + self.extension
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
with output_path.open("wt", encoding="utf-8") as f:
|
||||||
self.write_result(result, file=f, options=options, **kwargs)
|
self.write_result(result, file=f, options=options, **kwargs)
|
||||||
|
|
||||||
def write_result(
|
def write_result(
|
||||||
|
Loading…
Reference in New Issue
Block a user