Whisper: Add pip distribution configuration to support pip installations. (#739)

* Whisper: rename whisper to mlx_whisper

* Whisper: add setup.py config for publish

* Whisper: add assets data to setup config

* Whisper: pre-commit for setup.py

* Whisper: Update README.md

* Whisper: Update README.md

* nits

* fix package data

* nit in readme

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
madroid
2024-05-02 00:00:02 +08:00
committed by GitHub
parent 4bf2eb17f2
commit 6775d6cb3f
23 changed files with 116 additions and 74 deletions

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import json
import os
@@ -7,21 +7,20 @@ from dataclasses import asdict
from pathlib import Path
import mlx.core as mx
import mlx_whisper
import mlx_whisper.audio as audio
import mlx_whisper.decoding as decoding
import mlx_whisper.load_models as load_models
import numpy as np
import torch
from convert import load_torch_model, quantize, torch_to_mlx
from mlx.utils import tree_flatten
import whisper
import whisper.audio as audio
import whisper.decoding as decoding
import whisper.load_models as load_models
MODEL_NAME = "tiny"
MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
MLX_4BITS_MODEL_PATH = "mlx_models/tiny_quantized_4bits"
TEST_AUDIO = "whisper/assets/ls_test.flac"
TEST_AUDIO = "mlx_whisper/assets/ls_test.flac"
def _save_model(save_dir, weights, config):
@@ -187,7 +186,7 @@ class TestWhisper(unittest.TestCase):
self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)
def test_transcribe(self):
result = whisper.transcribe(
result = mlx_whisper.transcribe(
TEST_AUDIO, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
)
self.assertEqual(
@@ -208,7 +207,7 @@ class TestWhisper(unittest.TestCase):
print("bash path_to_whisper_repo/whisper/assets/download_alice.sh")
return
result = whisper.transcribe(
result = mlx_whisper.transcribe(
audio_file, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
)
self.assertEqual(len(result["text"]), 10920)
@@ -311,7 +310,7 @@ class TestWhisper(unittest.TestCase):
check_segment(result["segments"][73], expected_73)
def test_transcribe_word_level_timestamps_confidence_scores(self):
result = whisper.transcribe(
result = mlx_whisper.transcribe(
TEST_AUDIO,
path_or_hf_repo=MLX_FP16_MODEL_PATH,
word_timestamps=True,