[Whisper] Add word timestamps and confidence scores (#201)

* Add word timestamps and confidence scores

* Create a separate forward_with_cross_qk function

* Move multiple ops from np to mlx, clean comments

* Save alignment_heads

* Cast qk to fp32

* Add test for word-level timestamps and confidence scores

* format + readme

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
bofeng huang
2024-01-07 19:01:29 +01:00
committed by GitHub
parent 25ebd36112
commit bf9926489e
7 changed files with 398 additions and 111 deletions

View File

@@ -311,6 +311,137 @@ class TestWhisper(unittest.TestCase):
check_segment(result["segments"][5], expected_5)
check_segment(result["segments"][73], expected_73)
def test_transcribe_word_level_timestamps_confidence_scores(self):
result = whisper.transcribe(
# TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, word_timestamps=True, fp16=False
TEST_AUDIO,
model_path=MLX_FP16_MODEL_PATH,
word_timestamps=True,
)
# result predicted with openai-whisper
expected_0 = [
{
"word": " Then",
"start": 0.0,
"end": 0.94,
"probability": 0.855542778968811,
},
{
"word": " the",
"start": 0.94,
"end": 1.12,
"probability": 0.6500106453895569,
},
{
"word": " good",
"start": 1.12,
"end": 1.32,
"probability": 0.5503873825073242,
},
{
"word": " soul",
"start": 1.32,
"end": 1.56,
"probability": 0.46757155656814575,
},
{
"word": " openly",
"start": 1.56,
"end": 2.0,
"probability": 0.9840946793556213,
},
{
"word": " sorted",
"start": 2.0,
"end": 2.38,
"probability": 0.24167272448539734,
},
{
"word": " the",
"start": 2.38,
"end": 2.58,
"probability": 0.9875414967536926,
},
{
"word": " boat",
"start": 2.58,
"end": 2.8,
"probability": 0.5856029391288757,
},
{
"word": " and",
"start": 2.8,
"end": 2.98,
"probability": 0.913351833820343,
},
{
"word": " she",
"start": 2.98,
"end": 3.1,
"probability": 0.9913808703422546,
},
{
"word": " had",
"start": 3.1,
"end": 3.32,
"probability": 0.9952940344810486,
},
{
"word": " buoyed",
"start": 3.32,
"end": 3.58,
"probability": 0.6411589980125427,
},
{
"word": " so",
"start": 3.58,
"end": 3.8,
"probability": 0.9682658314704895,
},
{
"word": " long",
"start": 3.8,
"end": 4.06,
"probability": 0.9953522682189941,
},
{
"word": " in",
"start": 4.06,
"end": 4.26,
"probability": 0.6745936870574951,
},
{
"word": " secret",
"start": 4.26,
"end": 4.56,
"probability": 0.9905064702033997,
},
{
"word": " and",
"start": 4.56,
"end": 4.9,
"probability": 0.856008768081665,
},
{
"word": " bravely",
"start": 4.9,
"end": 5.28,
"probability": 0.8477402329444885,
},
]
def check_words(words, expected_words):
for word, expected_word in zip(words, expected_words):
for k, v in expected_word.items():
if isinstance(v, float):
self.assertAlmostEqual(word[k], v, places=1)
else:
self.assertEqual(word[k], v)
# Randomly check a couple of segments
check_words(result["segments"][0]["words"], expected_0)
class TestAudio(unittest.TestCase):
def test_load(self):