mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
[Whisper] Add word timestamps and confidence scores (#201)
* Add word timestamps and confidence scores * Create a separate forward_with_cross_qk function * Move multiple ops from np to mlx, clean comments * Save alignment_heads * Cast qk to fp32 * Add test for word-level timestamps and confidence scores * format + readme * nit --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
131
whisper/test.py
131
whisper/test.py
@@ -311,6 +311,137 @@ class TestWhisper(unittest.TestCase):
|
||||
check_segment(result["segments"][5], expected_5)
|
||||
check_segment(result["segments"][73], expected_73)
|
||||
|
||||
def test_transcribe_word_level_timestamps_confidence_scores(self):
|
||||
result = whisper.transcribe(
|
||||
# TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, word_timestamps=True, fp16=False
|
||||
TEST_AUDIO,
|
||||
model_path=MLX_FP16_MODEL_PATH,
|
||||
word_timestamps=True,
|
||||
)
|
||||
|
||||
# result predicted with openai-whisper
|
||||
expected_0 = [
|
||||
{
|
||||
"word": " Then",
|
||||
"start": 0.0,
|
||||
"end": 0.94,
|
||||
"probability": 0.855542778968811,
|
||||
},
|
||||
{
|
||||
"word": " the",
|
||||
"start": 0.94,
|
||||
"end": 1.12,
|
||||
"probability": 0.6500106453895569,
|
||||
},
|
||||
{
|
||||
"word": " good",
|
||||
"start": 1.12,
|
||||
"end": 1.32,
|
||||
"probability": 0.5503873825073242,
|
||||
},
|
||||
{
|
||||
"word": " soul",
|
||||
"start": 1.32,
|
||||
"end": 1.56,
|
||||
"probability": 0.46757155656814575,
|
||||
},
|
||||
{
|
||||
"word": " openly",
|
||||
"start": 1.56,
|
||||
"end": 2.0,
|
||||
"probability": 0.9840946793556213,
|
||||
},
|
||||
{
|
||||
"word": " sorted",
|
||||
"start": 2.0,
|
||||
"end": 2.38,
|
||||
"probability": 0.24167272448539734,
|
||||
},
|
||||
{
|
||||
"word": " the",
|
||||
"start": 2.38,
|
||||
"end": 2.58,
|
||||
"probability": 0.9875414967536926,
|
||||
},
|
||||
{
|
||||
"word": " boat",
|
||||
"start": 2.58,
|
||||
"end": 2.8,
|
||||
"probability": 0.5856029391288757,
|
||||
},
|
||||
{
|
||||
"word": " and",
|
||||
"start": 2.8,
|
||||
"end": 2.98,
|
||||
"probability": 0.913351833820343,
|
||||
},
|
||||
{
|
||||
"word": " she",
|
||||
"start": 2.98,
|
||||
"end": 3.1,
|
||||
"probability": 0.9913808703422546,
|
||||
},
|
||||
{
|
||||
"word": " had",
|
||||
"start": 3.1,
|
||||
"end": 3.32,
|
||||
"probability": 0.9952940344810486,
|
||||
},
|
||||
{
|
||||
"word": " buoyed",
|
||||
"start": 3.32,
|
||||
"end": 3.58,
|
||||
"probability": 0.6411589980125427,
|
||||
},
|
||||
{
|
||||
"word": " so",
|
||||
"start": 3.58,
|
||||
"end": 3.8,
|
||||
"probability": 0.9682658314704895,
|
||||
},
|
||||
{
|
||||
"word": " long",
|
||||
"start": 3.8,
|
||||
"end": 4.06,
|
||||
"probability": 0.9953522682189941,
|
||||
},
|
||||
{
|
||||
"word": " in",
|
||||
"start": 4.06,
|
||||
"end": 4.26,
|
||||
"probability": 0.6745936870574951,
|
||||
},
|
||||
{
|
||||
"word": " secret",
|
||||
"start": 4.26,
|
||||
"end": 4.56,
|
||||
"probability": 0.9905064702033997,
|
||||
},
|
||||
{
|
||||
"word": " and",
|
||||
"start": 4.56,
|
||||
"end": 4.9,
|
||||
"probability": 0.856008768081665,
|
||||
},
|
||||
{
|
||||
"word": " bravely",
|
||||
"start": 4.9,
|
||||
"end": 5.28,
|
||||
"probability": 0.8477402329444885,
|
||||
},
|
||||
]
|
||||
|
||||
def check_words(words, expected_words):
|
||||
for word, expected_word in zip(words, expected_words):
|
||||
for k, v in expected_word.items():
|
||||
if isinstance(v, float):
|
||||
self.assertAlmostEqual(word[k], v, places=1)
|
||||
else:
|
||||
self.assertEqual(word[k], v)
|
||||
|
||||
# Randomly check a couple of segments
|
||||
check_words(result["segments"][0]["words"], expected_0)
|
||||
|
||||
|
||||
class TestAudio(unittest.TestCase):
|
||||
def test_load(self):
|
||||
|
Reference in New Issue
Block a user