4 Commits

Author SHA1 Message Date
Anthony
e52c128d11 Use model.safetensors with Whisper (#1399)
Some checks failed
Test / check_lint (push) Has been cancelled
2025-12-15 06:17:08 -08:00
Awni Hannun
7ddca42f4d switch to github actions (#1394)
Some checks failed
Test / check_lint (push) Has been cancelled
2025-11-20 09:57:43 -08:00
Armin Stross-Radschinski
21a4d4cdab Update whisper command line help mentioning --word-timestamps (#1390) 2025-10-07 11:19:46 -07:00
Awni Hannun
8e4391ca21 whisper nits (#1388) 2025-09-03 13:18:50 -07:00
7 changed files with 38 additions and 56 deletions

View File

@@ -1,40 +0,0 @@
version: 2.1
orbs:
apple: ml-explore/pr-approval@0.1.0
jobs:
linux_build_and_test:
docker:
- image: cimg/python:3.9
steps:
- checkout
- run:
name: Run style checks
command: |
pip install pre-commit
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
workflows:
build_and_test:
when:
matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
jobs:
- linux_build_and_test
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- linux_build_and_test:
requires: [ hold ]

25
.github/workflows/pull_request.yml vendored Normal file
View File

@@ -0,0 +1,25 @@
name: Test
on:
push:
branches: ["main"]
pull_request:
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/head/main' }}
jobs:
check_lint:
if: github.repository == 'ml-explore/mlx-examples'
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: actions/setup-python@v6
with:
python-version: "3.10"
- uses: pre-commit/action@v3.0.1

View File

@@ -11,12 +11,6 @@ audio_file = "mlx_whisper/assets/ls_test.flac"
def parse_arguments():
parser = argparse.ArgumentParser(description="Benchmark script.")
parser.add_argument(
"--mlx-dir",
type=str,
default="mlx_models",
help="The folder of MLX models",
)
parser.add_argument(
"--all",
action="store_true",

View File

@@ -382,7 +382,7 @@ if __name__ == "__main__":
# Save weights
print("[INFO] Saving")
mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
mx.save_safetensors(str(mlx_path / "model.safetensors"), weights)
# Save config.json with model_type
with open(str(mlx_path / "config.json"), "w") as f:

View File

@@ -156,42 +156,42 @@ def build_parser():
"--prepend-punctuations",
type=str,
default="\"'“¿([{-",
help="If word-timestamps is True, merge these punctuation symbols with the next word",
help="If --word-timestamps is True, merge these punctuation symbols with the next word",
)
parser.add_argument(
"--append-punctuations",
type=str,
default="\"'.。,!?::”)]}、",
help="If word_timestamps is True, merge these punctuation symbols with the previous word",
help="If --word-timestamps is True, merge these punctuation symbols with the previous word",
)
parser.add_argument(
"--highlight-words",
type=str2bool,
default=False,
help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt",
help="(requires --word-timestamps True) underline each word as it is spoken in srt and vtt",
)
parser.add_argument(
"--max-line-width",
type=int,
default=None,
help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line",
help="(requires --word-timestampss True) the maximum number of characters in a line before breaking the line",
)
parser.add_argument(
"--max-line-count",
type=int,
default=None,
help="(requires --word_timestamps True) the maximum number of lines in a segment",
help="(requires --word-timestamps True) the maximum number of lines in a segment",
)
parser.add_argument(
"--max-words-per-line",
type=int,
default=None,
help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment",
help="(requires --word-timestamps True, no effect with --max-line-width) the maximum number of words in a segment",
)
parser.add_argument(
"--hallucination-silence-threshold",
type=optional_float,
help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected",
help="(requires --word-timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected",
)
parser.add_argument(
"--clip-timestamps",

View File

@@ -26,7 +26,10 @@ def load_model(
model_args = whisper.ModelDimensions(**config)
wf = model_path / "weights.safetensors"
# Prefer model.safetensors, fall back to weights.safetensors, then weights.npz
wf = model_path / "model.safetensors"
if not wf.exists():
wf = model_path / "weights.safetensors"
if not wf.exists():
wf = model_path / "weights.npz"
weights = mx.load(str(wf))

View File

@@ -62,7 +62,7 @@ class ModelHolder:
def transcribe(
audio: Union[str, np.ndarray, mx.array],
*,
path_or_hf_repo: str = "mlx-community/whisper-tiny",
path_or_hf_repo: str = "mlx-community/whisper-turbo",
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,