[Whisper] Add word timestamps and confidence scores (#201)

* Add word timestamps and confidence scores

* Create a separate forward_with_cross_qk function

* Move multiple ops from np to mlx, clean comments

* Save alignment_heads

* Cast qk to fp32

* Add test for word-level timestamps and confidence scores

* format + readme

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
bofeng huang 2024-01-07 19:01:29 +01:00 committed by GitHub
parent 25ebd36112
commit bf9926489e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 398 additions and 111 deletions

View File

@ -2,7 +2,7 @@
Speech recognition with Whisper in MLX. Whisper is a set of open source speech
recognition models from OpenAI, ranging from 39 million to 1.5 billion
parameters[^1].
parameters.[^1]
### Setup
@ -19,7 +19,8 @@ Install [`ffmpeg`](https://ffmpeg.org/):
brew install ffmpeg
```
Next, download the Whisper PyTorch checkpoint and convert the weights to the MLX format. For example, to convert the `tiny` model use:
Next, download the Whisper PyTorch checkpoint and convert the weights to the
MLX format. For example, to convert the `tiny` model use:
```
python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny
@ -45,10 +46,24 @@ the converted `weights.npz` and `config.json` there.
Transcribe audio with:
```
```python
import whisper
text = whisper.transcribe(speech_file)["text"]
```
The `transcribe` function also supports word-level timestamps. You can generate
these with:
```python
output = whisper.transcribe(speech_file, word_timestamps=True)
print(output["segments"][0]["words"])
```
To see more transcription options use:
```
>>> help(whisper.transcribe)
```
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2212.04356), [blog post](https://openai.com/research/whisper), and [code](https://github.com/openai/whisper) for more details.

View File

@ -199,6 +199,10 @@ def torch_to_mlx(
mlx_model = Whisper(torch_model.dims, dtype)
params = tree_map(lambda p: p.astype(dtype), params)
mlx_model.update(params)
if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None:
mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy())
return mlx_model

View File

@ -311,6 +311,137 @@ class TestWhisper(unittest.TestCase):
check_segment(result["segments"][5], expected_5)
check_segment(result["segments"][73], expected_73)
def test_transcribe_word_level_timestamps_confidence_scores(self):
result = whisper.transcribe(
# TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, word_timestamps=True, fp16=False
TEST_AUDIO,
model_path=MLX_FP16_MODEL_PATH,
word_timestamps=True,
)
# result predicted with openai-whisper
expected_0 = [
{
"word": " Then",
"start": 0.0,
"end": 0.94,
"probability": 0.855542778968811,
},
{
"word": " the",
"start": 0.94,
"end": 1.12,
"probability": 0.6500106453895569,
},
{
"word": " good",
"start": 1.12,
"end": 1.32,
"probability": 0.5503873825073242,
},
{
"word": " soul",
"start": 1.32,
"end": 1.56,
"probability": 0.46757155656814575,
},
{
"word": " openly",
"start": 1.56,
"end": 2.0,
"probability": 0.9840946793556213,
},
{
"word": " sorted",
"start": 2.0,
"end": 2.38,
"probability": 0.24167272448539734,
},
{
"word": " the",
"start": 2.38,
"end": 2.58,
"probability": 0.9875414967536926,
},
{
"word": " boat",
"start": 2.58,
"end": 2.8,
"probability": 0.5856029391288757,
},
{
"word": " and",
"start": 2.8,
"end": 2.98,
"probability": 0.913351833820343,
},
{
"word": " she",
"start": 2.98,
"end": 3.1,
"probability": 0.9913808703422546,
},
{
"word": " had",
"start": 3.1,
"end": 3.32,
"probability": 0.9952940344810486,
},
{
"word": " buoyed",
"start": 3.32,
"end": 3.58,
"probability": 0.6411589980125427,
},
{
"word": " so",
"start": 3.58,
"end": 3.8,
"probability": 0.9682658314704895,
},
{
"word": " long",
"start": 3.8,
"end": 4.06,
"probability": 0.9953522682189941,
},
{
"word": " in",
"start": 4.06,
"end": 4.26,
"probability": 0.6745936870574951,
},
{
"word": " secret",
"start": 4.26,
"end": 4.56,
"probability": 0.9905064702033997,
},
{
"word": " and",
"start": 4.56,
"end": 4.9,
"probability": 0.856008768081665,
},
{
"word": " bravely",
"start": 4.9,
"end": 5.28,
"probability": 0.8477402329444885,
},
]
def check_words(words, expected_words):
for word, expected_word in zip(words, expected_words):
for k, v in expected_word.items():
if isinstance(v, float):
self.assertAlmostEqual(word[k], v, places=1)
else:
self.assertEqual(word[k], v)
# Randomly check a couple of segments
check_words(result["segments"][0]["words"], expected_0)
class TestAudio(unittest.TestCase):
def test_load(self):

View File

@ -141,7 +141,7 @@ class Inference:
# only need to use the last token except in the first forward pass
tokens = tokens[:, -1:]
logits, self.kv_cache = self.model.decoder(
logits, self.kv_cache, _ = self.model.decoder(
tokens, audio_features, kv_cache=self.kv_cache
)
return logits.astype(mx.float32)

View File

@ -1,15 +1,13 @@
# Copyright © 2023 Apple Inc.
import itertools
import subprocess
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, List
import mlx.core as mx
import numba
import numpy as np
import torch
import torch.nn.functional as F
from scipy import signal
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer
@ -18,7 +16,7 @@ if TYPE_CHECKING:
from .model import Whisper
def median_filter(x: torch.Tensor, filter_width: int):
def median_filter(x: np.ndarray, filter_width: int):
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
pad_width = filter_width // 2
if x.shape[-1] <= pad_width:
@ -33,22 +31,12 @@ def median_filter(x: torch.Tensor, filter_width: int):
filter_width > 0 and filter_width % 2 == 1
), "`filter_width` should be an odd number"
result = None
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
if x.is_cuda:
try:
from .triton_ops import median_filter_cuda
x = np.pad(x, ((0, 0), (0, 0), (pad_width, pad_width)), mode="reflect")
result = median_filter_cuda(x, filter_width)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
"falling back to a slower median kernel implementation..."
)
if result is None:
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
# todo: more efficient version in mlx
result = signal.medfilt(x.astype(np.float32), kernel_size=(1, 1, filter_width))[
..., pad_width:-pad_width
]
if ndim <= 2:
result = result[0, 0]
@ -107,50 +95,9 @@ def dtw_cpu(x: np.ndarray):
return backtrace(trace)
def dtw_cuda(x, BLOCK_SIZE=1024):
from .triton_ops import dtw_kernel
M, N = x.shape
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
x_skew = (
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
)
x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0
cost = cost.cuda()
trace = torch.zeros_like(cost, dtype=torch.int32)
dtw_kernel[(1,)](
cost,
trace,
x_skew,
x_skew.stride(0),
cost.stride(0),
trace.stride(0),
N,
M,
BLOCK_SIZE=BLOCK_SIZE,
)
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
:, : N + 1
]
return backtrace(trace.cpu().numpy())
def dtw(x: torch.Tensor) -> np.ndarray:
if x.is_cuda:
try:
return dtw_cuda(x)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
"falling back to a slower DTW implementation..."
)
return dtw_cpu(x.double().cpu().numpy())
def dtw(x: np.ndarray) -> np.ndarray:
# todo: more efficient version in mlx
return dtw_cpu(x)
@dataclass
@ -166,7 +113,7 @@ def find_alignment(
model: "Whisper",
tokenizer: Tokenizer,
text_tokens: List[int],
mel: torch.Tensor,
mel: mx.array,
num_frames: int,
*,
medfilt_width: int = 7,
@ -175,41 +122,36 @@ def find_alignment(
if len(text_tokens) == 0:
return []
tokens = torch.tensor(
tokens = mx.array(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*text_tokens,
tokenizer.eot,
]
).to(model.device)
)
# install hooks on the cross attention layers to retrieve the attention weights
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist()
for hook in hooks:
hook.remove()
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
# consider only the logits associated with predicting text
sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
sampled_logits.dtype
)
text_token_probs = mx.take_along_axis(
token_probs, mx.array(text_tokens)[:, None], axis=1
).squeeze(1)
text_token_probs = np.array(text_token_probs)
# heads * tokens * frames
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
weights = mx.stack(
[cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
)
weights = weights[:, :, : num_frames // 2]
weights = (weights * qk_scale).softmax(dim=-1)
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = mx.softmax(weights * qk_scale, axis=-1)
mean = mx.mean(weights, axis=-2, keepdims=True)
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
weights = (weights - mean) / std
weights = median_filter(weights, medfilt_width)
weights = median_filter(np.array(weights), medfilt_width)
matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence) : -1]
@ -281,7 +223,7 @@ def add_word_timestamps(
segments: List[dict],
model: "Whisper",
tokenizer: Tokenizer,
mel: torch.Tensor,
mel: mx.array,
num_frames: int,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
@ -301,6 +243,7 @@ def add_word_timestamps(
word_durations = np.array([t.end - t.start for t in alignment])
word_durations = word_durations[word_durations.nonzero()]
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
median_duration = min(0.7, float(median_duration))
max_duration = median_duration * 2
# hack: truncate long words at sentence boundaries.

View File

@ -1,7 +1,8 @@
# Copyright © 2023 Apple Inc.
import sys
from typing import Optional, Tuple, Union
import warnings
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import numpy as np
@ -18,7 +19,8 @@ from .audio import (
)
from .decoding import DecodingOptions, DecodingResult
from .load_models import load_model
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, get_tokenizer
def _format_timestamp(seconds: float):
@ -38,6 +40,13 @@ def _format_timestamp(seconds: float):
return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
def _get_end(segments: List[dict]) -> Optional[float]:
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
)
class ModelHolder:
model = None
model_path = None
@ -61,8 +70,11 @@ def transcribe(
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
**decode_options,
):
"""
@ -99,6 +111,16 @@ def transcribe(
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
word_timestamps: bool
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
and include the timestamps for each word in each segment.
prepend_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the next word
append_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the previous word
initial_prompt: Optional[str]
Optional text to provide as a prompt for the first window. This can be used to provide, or
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
@ -107,6 +129,14 @@ def transcribe(
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
clip_timestamps: Union[str, List[float]]
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
The last end timestamp defaults to the end of the file.
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
when a possible hallucination is detected
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@ -119,6 +149,7 @@ def transcribe(
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
content_frames = mel.shape[-2] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
if verbose:
system_encoding = sys.getdefaultencoding()
@ -155,6 +186,22 @@ def transcribe(
task=task,
)
if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
]
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: mx.array) -> DecodingResult:
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
@ -195,7 +242,8 @@ def transcribe(
return decode_result
seek = 0
clip_idx = 0
seek = seek_clips[clip_idx][0]
input_stride = N_FRAMES // model.dims.n_audio_ctx # mel frames per output token: 2
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
@ -232,10 +280,23 @@ def transcribe(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
last_speech_timestamp = 0.0
while seek < content_frames:
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
# while seek < seek_clip_end
while clip_idx < len(seek_clips):
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
if seek < seek_clip_start:
seek = seek_clip_start
if seek >= seek_clip_end:
clip_idx += 1
if clip_idx < len(seek_clips):
seek = seek_clips[clip_idx][0]
continue
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[seek : seek + N_FRAMES]
segment_size = min(N_FRAMES, content_frames - seek)
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
mel_segment = mel[seek : seek + segment_size]
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype)
@ -260,6 +321,30 @@ def transcribe(
previous_seek = seek
current_segments = []
# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
duration = word["end"] - word["start"]
score = 0.0
if probability < 0.15:
score += 1.0
if duration < 0.133:
score += (0.133 - duration) * 15
if duration > 2.0:
score += duration - 2.0
return score
def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [w for w in segment["words"] if w["word"] not in punctuation]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)
def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)
timestamp_tokens = tokens >= tokenizer.timestamp_begin
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
@ -324,6 +409,83 @@ def transcribe(
)
seek += segment_size
if word_timestamps:
add_word_timestamps(
segments=current_segments,
model=model,
tokenizer=tokenizer,
mel=mel_segment,
num_frames=segment_size,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp,
)
if not single_timestamp_ending:
last_word_end = _get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
seek = round(last_word_end * FRAMES_PER_SECOND)
# skip silence before possible hallucinations
if hallucination_silence_threshold is not None:
threshold = hallucination_silence_threshold
if not single_timestamp_ending:
last_word_end = _get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
remaining_duration = window_end_time - last_word_end
if remaining_duration > threshold:
seek = round(last_word_end * FRAMES_PER_SECOND)
else:
seek = previous_seek + segment_size
# if first segment might be a hallucination, skip leading silence
first_segment = next_words_segment(current_segments)
if first_segment is not None and is_segment_anomaly(first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
continue
# skip silence before any possible hallucination that is surrounded
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
for si in range(len(current_segments)):
segment = current_segments[si]
if not segment["words"]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1 :]
)
if next_segment is not None:
hal_next_start = next_segment["words"][0]["start"]
else:
hal_next_start = time_offset + segment_duration
silence_before = (
segment["start"] - hal_last_end > threshold
or segment["start"] < threshold
or segment["start"] - time_offset < 2.0
)
silence_after = (
hal_next_start - segment["end"] > threshold
or is_segment_anomaly(next_segment)
or window_end_time - segment["end"] < 2.0
)
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"])
* FRAMES_PER_SECOND
)
if content_duration - segment["end"] < threshold:
seek = content_frames
current_segments[si:] = []
break
hal_last_end = segment["end"]
last_word_end = _get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end
if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment["end"], segment["text"]

View File

@ -4,7 +4,7 @@ import base64
import gzip
import math
from dataclasses import dataclass
from typing import Dict, Iterable, Optional
from typing import Union
import mlx.core as mx
import mlx.nn as nn
@ -72,8 +72,8 @@ class MultiHeadAttention(nn.Module):
else:
k, v = kv_cache
wv = self.qkv_attention(q, k, v, mask)
return self.out(wv), (k, v)
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), (k, v), qk
def qkv_attention(self, q, k, v, mask=None):
n_batch, n_ctx, n_state = q.shape
@ -85,12 +85,12 @@ class MultiHeadAttention(nn.Module):
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.astype(mx.float32)
w = mx.softmax(qk, axis=-1).astype(q.dtype)
out = (w @ v).transpose(0, 2, 1, 3)
out = out.reshape(n_batch, n_ctx, n_state)
return out
return out, qk
class ResidualAttentionBlock(nn.Module):
@ -112,13 +112,16 @@ class ResidualAttentionBlock(nn.Module):
def __call__(self, x, xa=None, mask=None, kv_cache=None):
kv, cross_kv = kv_cache if kv_cache else (None, None)
y, kv = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
y, kv, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
x += y
cross_qk = None
if self.cross_attn:
y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
y, cross_kv, cross_qk = self.cross_attn(
self.cross_attn_ln(x), xa, kv_cache=cross_kv
)
x += y
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
return x, (kv, cross_kv)
return x, (kv, cross_kv), cross_qk
class AudioEncoder(nn.Module):
@ -146,7 +149,7 @@ class AudioEncoder(nn.Module):
x = x + self._positional_embedding
for block in self.blocks:
x, _ = block(x)
x, _, _ = block(x)
x = self.ln_post(x)
return x
@ -191,11 +194,14 @@ class TextDecoder(nn.Module):
if kv_cache is None:
kv_cache = [None] * len(self.blocks)
cross_qk = [None] * len(self.blocks)
for e, block in enumerate(self.blocks):
x, kv_cache[e] = block(x, xa, mask=self._mask, kv_cache=kv_cache[e])
x, kv_cache[e], cross_qk[e] = block(
x, xa, mask=self._mask, kv_cache=kv_cache[e]
)
x = self.ln(x)
return x @ self.token_embedding.weight.T, kv_cache
return x @ self.token_embedding.weight.T, kv_cache, cross_qk
class Whisper(nn.Module):
@ -218,6 +224,28 @@ class Whisper(nn.Module):
self.dims.n_text_layer,
dtype,
)
# use the last half among the decoder layers for time alignment by default;
# to use a specific set of heads, see `set_alignment_heads()` below.
all_heads = np.zeros(
(self.dims.n_text_layer, self.dims.n_text_head), dtype=bool
)
all_heads[self.dims.n_text_layer // 2 :] = True
self.alignment_heads = mx.array(np.asarray(all_heads.nonzero()).T)
def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
if isinstance(dump, np.ndarray):
self.alignment_heads = mx.array(dump)
elif isinstance(dump, bytes):
array = np.frombuffer(
gzip.decompress(base64.b85decode(dump)), dtype=bool
).copy()
mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head)
self.alignment_heads = mx.array(np.asarray(mask.nonzero()).T)
else:
raise ValueError(
f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"
" alignment_head information"
)
def embed_audio(self, mel):
return self.encoder(mel)
@ -225,6 +253,10 @@ class Whisper(nn.Module):
def logits(self, tokens, audio_features):
return self.decoder(tokens, audio_features)[0]
def forward_with_cross_qk(self, mel, tokens):
logits, _, cross_qk = self.decoder(tokens, self.encoder(mel))
return logits, cross_qk
def __call__(self, mel, tokens):
return self.decoder(tokens, self.encoder(mel))[0]