mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00

* Whisper: Add CLI command * Whisper: Prevent precision loss when converting to words dictionary * Whisper: disable json ensure_ascii * Whisper: add cli setup config * Whisper: pre-commit * Whisper: Adjust the _ in the command line arguments to - * nits * version + readme * nit --------- Co-authored-by: Awni Hannun <awni@apple.com>
331 lines
11 KiB
Python
331 lines
11 KiB
Python
# Copyright © 2023 Apple Inc.
|
||
|
||
import itertools
|
||
from dataclasses import dataclass
|
||
from typing import TYPE_CHECKING, List
|
||
|
||
import mlx.core as mx
|
||
import numba
|
||
import numpy as np
|
||
from scipy import signal
|
||
|
||
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
|
||
from .tokenizer import Tokenizer
|
||
|
||
if TYPE_CHECKING:
|
||
from .model import Whisper
|
||
|
||
|
||
def median_filter(x: np.ndarray, filter_width: int):
    """Apply a median filter of width `filter_width` along the last dimension of `x`.

    The last axis is reflect-padded by ``filter_width // 2`` on both sides
    before filtering and cropped afterwards, so the output has the same shape
    as the input. Leading axes are never mixed: the kernel has size 1 along
    every axis except the last.

    Args:
        x: Array with at least one dimension. Any number of leading
            dimensions is accepted.
        filter_width: Size of the median window; must be a positive odd number.

    Returns:
        A ``float32`` array shaped like `x`. If the last dimension of `x` is
        not longer than ``filter_width // 2`` the input is returned unchanged
        (no filtering and no dtype conversion).

    Raises:
        AssertionError: If `filter_width` is not a positive odd number.
    """
    pad_width = filter_width // 2
    if x.shape[-1] <= pad_width:
        # Inputs no longer than the one-sided padding are passed through as-is.
        return x

    assert (
        filter_width > 0 and filter_width % 2 == 1
    ), "`filter_width` should be an odd number"

    # Build per-axis specs from the actual rank instead of assuming 3-D input,
    # so 1-D, 2-D, 3-D and higher-rank arrays all take the same path.
    leading_axes = x.ndim - 1
    pad_spec = [(0, 0)] * leading_axes + [(pad_width, pad_width)]
    kernel_size = (1,) * leading_axes + (filter_width,)

    padded = np.pad(x, pad_spec, mode="reflect")

    # todo: more efficient version in mlx
    filtered = signal.medfilt(padded.astype(np.float32), kernel_size=kernel_size)

    # Use an explicit stop index: `-pad_width` would be `-0` for
    # `filter_width == 1` and silently produce an empty slice.
    return filtered[..., pad_width : filtered.shape[-1] - pad_width]
|
||
|
||
|
||
@numba.jit(nopython=True)
def backtrace(trace: np.ndarray):
    """Walk a DTW move table from its bottom-right cell back to the origin.

    `trace` holds one move code per cell: 0 steps diagonally (row and column
    both decrease), 1 steps up (row decreases), 2 steps left (column
    decreases). The first row and first column are overwritten in place so the
    walk always terminates at ``(0, 0)``.

    Returns a ``(2, path_length)`` array: row 0 holds the visited row indices
    and row 1 the visited column indices, both shifted by -1 and ordered from
    the start of the path to its end.

    Raises ValueError if a visited cell contains a code other than 0, 1 or 2.
    """
    row = trace.shape[0] - 1
    col = trace.shape[1] - 1

    # Along the top edge only "left" is possible, along the left edge only "up".
    trace[0, :] = 2
    trace[:, 0] = 1

    path = []
    while row > 0 or col > 0:
        path.append((row - 1, col - 1))

        step = trace[row, col]
        if step == 0:
            row -= 1
            col -= 1
        elif step == 1:
            row -= 1
        elif step == 2:
            col -= 1
        else:
            raise ValueError("Unexpected trace[i, j]")

    # The walk collected cells end-to-start; flip it and split into two rows.
    visited = np.array(path)
    return visited[::-1, :].T
|
||
|
||
|
||
@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
    """Run dynamic time warping over the 2-D cost matrix `x`.

    Fills an accumulated-cost table and a move table (0 = diagonal, 1 = up,
    2 = left), both one row and one column larger than `x`, then hands the
    move table to `backtrace` and returns its result.

    The comparisons are strict, so whenever neither the diagonal nor the upper
    predecessor is strictly smaller than both others, the left predecessor is
    used.
    """
    n_rows, n_cols = x.shape
    cost = np.ones((n_rows + 1, n_cols + 1), dtype=np.float32) * np.inf
    trace = -np.ones((n_rows + 1, n_cols + 1), dtype=np.float32)

    cost[0, 0] = 0
    for col in range(1, n_cols + 1):
        for row in range(1, n_rows + 1):
            diagonal = cost[row - 1, col - 1]
            above = cost[row - 1, col]
            left = cost[row, col - 1]

            if diagonal < above and diagonal < left:
                best = diagonal
                trace[row, col] = 0
            elif above < diagonal and above < left:
                best = above
                trace[row, col] = 1
            else:
                best = left
                trace[row, col] = 2

            cost[row, col] = x[row - 1, col - 1] + best

    return backtrace(trace)
|
||
|
||
|
||
def dtw(x: np.ndarray) -> np.ndarray:
    """Return the DTW alignment path for cost matrix `x`.

    Currently a direct pass-through to the CPU implementation `dtw_cpu`.
    """
    # todo: more efficient version in mlx
    alignment_path = dtw_cpu(x)
    return alignment_path
|
||
|
||
|
||
@dataclass
class WordTiming:
    """One aligned word together with its tokens, time span and confidence."""

    # Text of the word (may be blanked out when merged into a neighbour).
    word: str
    # Token ids that make up the word.
    tokens: List[int]
    # Start of the word; produced as a frame index divided by TOKENS_PER_SECOND.
    start: float
    # End of the word, on the same time scale as `start`.
    end: float
    # Mean probability of the word's tokens.
    probability: float
|
||
|
||
|
||
def find_alignment(
    model: "Whisper",
    tokenizer: Tokenizer,
    text_tokens: List[int],
    mel: mx.array,
    num_frames: int,
    *,
    medfilt_width: int = 7,
    qk_scale: float = 1.0,
) -> List[WordTiming]:
    """Align `text_tokens` to audio frames using the model's cross-attention.

    The tokens are wrapped as ``sot_sequence + [no_timestamps] + text_tokens +
    [eot]`` and run through ``model.forward_with_cross_qk``. The cross-attention
    scores of the heads listed in ``model.alignment_heads`` are soft-maxed,
    standardised along the token axis, median-filtered and averaged over heads;
    DTW over the negated matrix then yields a token-to-frame path from which
    per-word start/end times are read.

    Args:
        model: Model exposing ``forward_with_cross_qk`` and ``alignment_heads``.
        tokenizer: Tokenizer providing ``sot_sequence``, ``no_timestamps``,
            ``eot`` and ``split_to_word_tokens``.
        text_tokens: Text token ids to align.
        mel: Mel spectrogram passed to the model (a batch axis is added here).
        num_frames: Number of valid mel frames; only the first
            ``num_frames // 2`` attention frames are used.
        medfilt_width: Width of the median filter applied to the attention
            weights.
        qk_scale: Multiplier applied to the attention scores before softmax.

    Returns:
        One `WordTiming` per word, or an empty list when `text_tokens` is empty
        or the tokenizer yields at most one word group (i.e. only eot).
    """
    if len(text_tokens) == 0:
        return []

    tokens = mx.array(
        [
            *tokenizer.sot_sequence,
            tokenizer.no_timestamps,
            *text_tokens,
            tokenizer.eot,
        ]
    )

    logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
    # consider only the logits associated with predicting text
    sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
    token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
        sampled_logits.dtype
    )
    text_token_probs = mx.take_along_axis(
        token_probs, mx.array(text_tokens)[:, None], axis=1
    ).squeeze(1)
    text_token_probs = np.array(text_token_probs)

    # heads * tokens * frames
    weights = mx.stack(
        [cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
    )
    weights = weights[:, :, : num_frames // 2]
    # Upcast before normalising, matching the token-probability softmax above:
    # softmax, mean, variance and the z-score are otherwise evaluated in the
    # model's compute dtype. The median filter converts to float32 anyway, so
    # the dtype seen downstream is unchanged.
    weights = mx.softmax(weights.astype(mx.float32) * qk_scale, axis=-1)
    mean = mx.mean(weights, axis=-2, keepdims=True)
    std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
    weights = (weights - mean) / std
    weights = median_filter(np.array(weights), medfilt_width)

    matrix = weights.mean(axis=0)
    # Drop the rows of the sot prefix and the final row so that only the
    # no_timestamps/text token rows take part in the alignment.
    matrix = matrix[len(tokenizer.sot_sequence) : -1]
    text_indices, time_indices = dtw(-matrix)

    words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
    if len(word_tokens) <= 1:
        # return on eot only
        # >>> np.pad([], (1, 0))
        # array([0.])
        # This results in crashes when we lookup jump_times with float, like
        # IndexError: arrays used as indices must be of integer (or boolean) type
        return []
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))

    # A "jump" is a path step where the token index advances; the frame index
    # at each jump marks the time at which that token begins.
    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
    jump_times = time_indices[jumps] / TOKENS_PER_SECOND
    start_times = jump_times[word_boundaries[:-1]]
    end_times = jump_times[word_boundaries[1:]]
    word_probabilities = [
        np.mean(text_token_probs[i:j])
        for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
    ]

    return [
        WordTiming(word, tokens, start, end, probability)
        for word, tokens, start, end, probability in zip(
            words, word_tokens, start_times, end_times, word_probabilities
        )
    ]
|
||
|
||
|
||
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
    """Fold punctuation entries into their neighbouring words, in place.

    Entries that get merged are kept in the list but left with an empty
    ``word`` and an empty ``tokens`` list; nothing is returned.

    Args:
        alignment: Word timings to update.
        prepended: Characters that attach to the word that follows them.
        appended: Characters that attach to the word that precedes them.
    """
    count = len(alignment)

    # Pass 1, right to left: an entry that starts with a space and whose
    # stripped text occurs in `prepended` is glued onto the nearest entry to
    # its right that was not itself merged away.
    receiver_idx = count - 1
    for idx in range(count - 2, -1, -1):
        candidate = alignment[idx]
        receiver = alignment[receiver_idx]
        if candidate.word.startswith(" ") and candidate.word.strip() in prepended:
            receiver.word = candidate.word + receiver.word
            receiver.tokens = candidate.tokens + receiver.tokens
            candidate.word = ""
            candidate.tokens = []
        else:
            receiver_idx = idx

    # Pass 2, left to right: an entry whose text occurs in `appended` is glued
    # onto the nearest surviving entry to its left, unless that entry ends
    # with a space.
    receiver_idx = 0
    for idx in range(1, count):
        receiver = alignment[receiver_idx]
        candidate = alignment[idx]
        if not receiver.word.endswith(" ") and candidate.word in appended:
            receiver.word = receiver.word + candidate.word
            receiver.tokens = receiver.tokens + candidate.tokens
            candidate.word = ""
            candidate.tokens = []
        else:
            receiver_idx = idx
|
||
|
||
|
||
def add_word_timestamps(
    *,
    segments: List[dict],
    model: "Whisper",
    tokenizer: Tokenizer,
    mel: mx.array,
    num_frames: int,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,,!!??::”)]}、",
    last_speech_timestamp: float,
    **kwargs,
):
    """Attach per-word timings to each segment, modifying `segments` in place.

    All text tokens (ids below ``tokenizer.eot``) of the segments are aligned
    in one `find_alignment` call, punctuation is merged into neighbouring
    words, and the resulting words are dealt back out to the segments by
    token count. Every segment receives a ``"words"`` list of dicts with keys
    ``word``, ``start``, ``end`` and ``probability``; a segment's ``"start"``
    and ``"end"`` may also be overwritten from its first and last word.

    Args:
        segments: Segment dicts; ``"tokens"``, ``"seek"``, ``"start"`` and
            ``"end"`` are read here.
        model, tokenizer, mel, num_frames: Forwarded to `find_alignment`.
        prepend_punctuations: Characters merged into the following word.
        append_punctuations: Characters merged into the preceding word.
        last_speech_timestamp: End time of the previous speech, used to detect
            a long gap before the first word of a segment.
        **kwargs: Extra keyword arguments forwarded to `find_alignment`.
    """
    if not segments:
        return

    eot = tokenizer.eot
    text_tokens_per_segment = [
        [tok for tok in seg["tokens"] if tok < eot] for seg in segments
    ]
    all_text_tokens = [
        tok for seg_tokens in text_tokens_per_segment for tok in seg_tokens
    ]

    alignment = find_alignment(
        model, tokenizer, all_text_tokens, mel, num_frames, **kwargs
    )

    durations = np.array([timing.end - timing.start for timing in alignment])
    durations = durations[durations.nonzero()]
    has_durations = len(durations) > 0
    if has_durations:
        median_duration = np.median(durations)
    else:
        median_duration = 0.0
    median_duration = min(0.7, float(median_duration))
    max_duration = median_duration * 2

    # hack: truncate long words at sentence boundaries.
    # a better segmentation algorithm based on VAD should be able to replace this.
    if has_durations:
        sentence_end_marks = ".。!!??"
        # ensure words at sentence boundaries are not longer than twice the median word duration.
        for before, current in zip(alignment, alignment[1:]):
            if current.end - current.start > max_duration:
                if current.word in sentence_end_marks:
                    current.end = current.start + max_duration
                elif before.word in sentence_end_marks:
                    current.start = current.end - max_duration

    merge_punctuations(alignment, prepend_punctuations, append_punctuations)

    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
    total_words = len(alignment)
    word_index = 0

    for segment, segment_tokens in zip(segments, text_tokens_per_segment):
        token_budget = len(segment_tokens)
        consumed = 0
        words = []

        # Hand out aligned words until this segment's tokens are accounted for.
        while word_index < total_words and consumed < token_budget:
            timing = alignment[word_index]

            # Entries emptied by punctuation merging carry no word of their own.
            if timing.word:
                words.append(
                    {
                        "word": timing.word,
                        "start": round(time_offset + timing.start, 2),
                        "end": round(time_offset + timing.end, 2),
                        "probability": float(timing.probability),
                    }
                )

            consumed += len(timing.tokens)
            word_index += 1

        # hack: truncate long words at segment boundaries.
        # a better segmentation algorithm based on VAD should be able to replace this.
        if words:
            first = words[0]
            second = words[1] if len(words) > 1 else None

            # ensure the first and second word after a pause is not longer than
            # twice the median word duration.
            after_long_gap = (
                first["end"] - last_speech_timestamp > median_duration * 4
            )
            first_too_long = first["end"] - first["start"] > max_duration
            pair_too_long = (
                second is not None
                and second["end"] - first["start"] > max_duration * 2
            )
            if after_long_gap and (first_too_long or pair_too_long):
                if (
                    second is not None
                    and second["end"] - second["start"] > max_duration
                ):
                    boundary = max(second["end"] / 2, second["end"] - max_duration)
                    first["end"] = boundary
                    second["start"] = boundary
                first["start"] = max(0, first["end"] - max_duration)

            # prefer the segment-level start timestamp if the first word is too long.
            if (
                segment["start"] < first["end"]
                and segment["start"] - 0.5 > first["start"]
            ):
                first["start"] = max(
                    0, min(first["end"] - median_duration, segment["start"])
                )
            else:
                segment["start"] = first["start"]

            # prefer the segment-level end timestamp if the last word is too long.
            last = words[-1]
            if (
                segment["end"] > last["start"]
                and segment["end"] + 0.5 < last["end"]
            ):
                last["end"] = max(last["start"] + median_duration, segment["end"])
            else:
                segment["end"] = last["end"]

            last_speech_timestamp = segment["end"]

        segment["words"] = words
|