mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Compare commits
4 Commits
4477876d2f
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e52c128d11 | ||
|
|
7ddca42f4d | ||
|
|
21a4d4cdab | ||
|
|
8e4391ca21 |
@@ -1,40 +0,0 @@
|
||||
version: 2.1
|
||||
|
||||
orbs:
|
||||
apple: ml-explore/pr-approval@0.1.0
|
||||
|
||||
jobs:
|
||||
linux_build_and_test:
|
||||
docker:
|
||||
- image: cimg/python:3.9
|
||||
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Run style checks
|
||||
command: |
|
||||
pip install pre-commit
|
||||
pre-commit run --all
|
||||
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
|
||||
|
||||
workflows:
|
||||
build_and_test:
|
||||
when:
|
||||
matches:
|
||||
pattern: "^(?!pull/)[-\\w]+$"
|
||||
value: << pipeline.git.branch >>
|
||||
jobs:
|
||||
- linux_build_and_test
|
||||
|
||||
prb:
|
||||
when:
|
||||
matches:
|
||||
pattern: "^pull/\\d+(/head)?$"
|
||||
value: << pipeline.git.branch >>
|
||||
jobs:
|
||||
- hold:
|
||||
type: approval
|
||||
- apple/authenticate:
|
||||
context: pr-approval
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
25
.github/workflows/pull_request.yml
vendored
Normal file
25
.github/workflows/pull_request.yml
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
jobs:
|
||||
check_lint:
|
||||
if: github.repository == 'ml-explore/mlx-examples'
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
|
||||
@@ -11,12 +11,6 @@ audio_file = "mlx_whisper/assets/ls_test.flac"
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="Benchmark script.")
|
||||
parser.add_argument(
|
||||
"--mlx-dir",
|
||||
type=str,
|
||||
default="mlx_models",
|
||||
help="The folder of MLX models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--all",
|
||||
action="store_true",
|
||||
|
||||
@@ -382,7 +382,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Save weights
|
||||
print("[INFO] Saving")
|
||||
mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights)
|
||||
mx.save_safetensors(str(mlx_path / "model.safetensors"), weights)
|
||||
|
||||
# Save config.json with model_type
|
||||
with open(str(mlx_path / "config.json"), "w") as f:
|
||||
|
||||
@@ -156,42 +156,42 @@ def build_parser():
|
||||
"--prepend-punctuations",
|
||||
type=str,
|
||||
default="\"'“¿([{-",
|
||||
help="If word-timestamps is True, merge these punctuation symbols with the next word",
|
||||
help="If --word-timestamps is True, merge these punctuation symbols with the next word",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--append-punctuations",
|
||||
type=str,
|
||||
default="\"'.。,,!!??::”)]}、",
|
||||
help="If word_timestamps is True, merge these punctuation symbols with the previous word",
|
||||
help="If --word-timestamps is True, merge these punctuation symbols with the previous word",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highlight-words",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt",
|
||||
help="(requires --word-timestamps True) underline each word as it is spoken in srt and vtt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-line-width",
|
||||
type=int,
|
||||
default=None,
|
||||
help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line",
|
||||
help="(requires --word-timestamps True) the maximum number of characters in a line before breaking the line",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-line-count",
|
||||
type=int,
|
||||
default=None,
|
||||
help="(requires --word_timestamps True) the maximum number of lines in a segment",
|
||||
help="(requires --word-timestamps True) the maximum number of lines in a segment",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-words-per-line",
|
||||
type=int,
|
||||
default=None,
|
||||
help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment",
|
||||
help="(requires --word-timestamps True, no effect with --max-line-width) the maximum number of words in a segment",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hallucination-silence-threshold",
|
||||
type=optional_float,
|
||||
help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected",
|
||||
help="(requires --word-timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip-timestamps",
|
||||
|
||||
@@ -26,7 +26,10 @@ def load_model(
|
||||
|
||||
model_args = whisper.ModelDimensions(**config)
|
||||
|
||||
wf = model_path / "weights.safetensors"
|
||||
# Prefer model.safetensors, fall back to weights.safetensors, then weights.npz
|
||||
wf = model_path / "model.safetensors"
|
||||
if not wf.exists():
|
||||
wf = model_path / "weights.safetensors"
|
||||
if not wf.exists():
|
||||
wf = model_path / "weights.npz"
|
||||
weights = mx.load(str(wf))
|
||||
|
||||
@@ -62,7 +62,7 @@ class ModelHolder:
|
||||
def transcribe(
|
||||
audio: Union[str, np.ndarray, mx.array],
|
||||
*,
|
||||
path_or_hf_repo: str = "mlx-community/whisper-tiny",
|
||||
path_or_hf_repo: str = "mlx-community/whisper-turbo",
|
||||
verbose: Optional[bool] = None,
|
||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
compression_ratio_threshold: Optional[float] = 2.4,
|
||||
|
||||
Reference in New Issue
Block a user