mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-23 22:18:06 +08:00
Whisper: Add pip distribution configuration to support pip installations. (#739)
* Whisper: rename whisper to mlx_whisper
* Whisper: add setup.py config for publish
* Whisper: add assets data to setup config
* Whisper: pre-commit for setup.py
* Whisper: Update README.md
* Whisper: Update README.md
* nits
* fix package data
* nit in readme

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import json
|
||||
import os
|
||||
@@ -7,21 +7,20 @@ from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_whisper
|
||||
import mlx_whisper.audio as audio
|
||||
import mlx_whisper.decoding as decoding
|
||||
import mlx_whisper.load_models as load_models
|
||||
import numpy as np
|
||||
import torch
|
||||
from convert import load_torch_model, quantize, torch_to_mlx
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
import whisper
|
||||
import whisper.audio as audio
|
||||
import whisper.decoding as decoding
|
||||
import whisper.load_models as load_models
|
||||
|
||||
MODEL_NAME = "tiny"
|
||||
MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
|
||||
MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
|
||||
MLX_4BITS_MODEL_PATH = "mlx_models/tiny_quantized_4bits"
|
||||
TEST_AUDIO = "whisper/assets/ls_test.flac"
|
||||
TEST_AUDIO = "mlx_whisper/assets/ls_test.flac"
|
||||
|
||||
|
||||
def _save_model(save_dir, weights, config):
|
||||
@@ -187,7 +186,7 @@ class TestWhisper(unittest.TestCase):
|
||||
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
|
||||
|
||||
def test_transcribe(self):
|
||||
result = whisper.transcribe(
|
||||
result = mlx_whisper.transcribe(
|
||||
TEST_AUDIO, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
|
||||
)
|
||||
self.assertEqual(
|
||||
@@ -208,7 +207,7 @@ class TestWhisper(unittest.TestCase):
|
||||
print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
|
||||
return
|
||||
|
||||
result = whisper.transcribe(
|
||||
result = mlx_whisper.transcribe(
|
||||
audio_file, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
|
||||
)
|
||||
self.assertEqual(len(result["text"]), 10920)
|
||||
@@ -311,7 +310,7 @@ class TestWhisper(unittest.TestCase):
|
||||
check_segment(result["segments"][73], expected_73)
|
||||
|
||||
def test_transcribe_word_level_timestamps_confidence_scores(self):
|
||||
result = whisper.transcribe(
|
||||
result = mlx_whisper.transcribe(
|
||||
TEST_AUDIO,
|
||||
path_or_hf_repo=MLX_FP16_MODEL_PATH,
|
||||
word_timestamps=True,
|
||||
|
Reference in New Issue
Block a user