Merge branch 'ml-explore:main' into main

Authored by Chime Ogbuji on 2024-01-07 17:06:42 -05:00; committed by GitHub.
9 changed files with 400 additions and 113 deletions

View File

@@ -41,7 +41,7 @@ models](https://github.com/ml-explore/mlx-examples/issues/155).
We are grateful for all of [our
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX Examples and wish to be acknowledged, please add your name to to the list in your
to MLX Examples and wish to be acknowledged, please add your name to the list in your
pull request.
## Citing MLX Examples

View File

@@ -87,7 +87,7 @@ if __name__ == "__main__":
# Copy the tokenizer
tokenizer_path = torch_path / "tokenizer.model"
if not tokenizer_path.exists():
print(f"Make sure there is a file tokenizer.model in {args.torch-path}")
print(f"Make sure there is a file tokenizer.model in {args.torch_path}")
exit(0)
shutil.copyfile(
str(tokenizer_path),
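
As an aside on the fix above: `argparse` maps a `--torch-path` flag to the attribute `args.torch_path`, whereas the original `{args.torch-path}` inside the f-string is evaluated as the expression `args.torch - path` and raises an `AttributeError` at runtime. A minimal sketch (the flag name is inferred from the diff):

```python
import argparse

# argparse converts the hyphen in "--torch-path" into an underscore,
# so the parsed value lives at args.torch_path.
parser = argparse.ArgumentParser()
parser.add_argument("--torch-path", type=str, help="Path to the PyTorch checkpoint")
args = parser.parse_args(["--torch-path", "llama-weights"])

print(f"Make sure there is a file tokenizer.model in {args.torch_path}")
```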

View File

@@ -2,7 +2,7 @@
Speech recognition with Whisper in MLX. Whisper is a set of open source speech
recognition models from OpenAI, ranging from 39 million to 1.5 billion
parameters[^1].
parameters.[^1]
### Setup
@@ -19,7 +19,8 @@ Install [`ffmpeg`](https://ffmpeg.org/):
brew install ffmpeg
```
Next, download the Whisper PyTorch checkpoint and convert the weights to the MLX format. For example, to convert the `tiny` model use:
Next, download the Whisper PyTorch checkpoint and convert the weights to the
MLX format. For example, to convert the `tiny` model use:
```
python convert.py --torch-name-or-path tiny --mlx-path mlx_models/tiny
@@ -45,10 +46,24 @@ the converted `weights.npz` and `config.json` there.
Transcribe audio with:
```
```python
import whisper
text = whisper.transcribe(speech_file)["text"]
```
The `transcribe` function also supports word-level timestamps. You can generate
these with:
```python
output = whisper.transcribe(speech_file, word_timestamps=True)
print(output["segments"][0]["words"])
```
To see more transcription options use:
```
>>> help(whisper.transcribe)
```
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2212.04356), [blog post](https://openai.com/research/whisper), and [code](https://github.com/openai/whisper) for more details.
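
A small follow-up sketch on consuming the word-level output (assuming `speech_file` is defined as above; the per-word keys `word`, `start`, `end`, and `probability` match the expectations in the tests added by this PR):

```python
import whisper

output = whisper.transcribe(speech_file, word_timestamps=True)
for segment in output["segments"]:
    for word in segment["words"]:
        # Each entry carries the word text, its time span, and a confidence score.
        print(f'{word["start"]:6.2f}s -> {word["end"]:6.2f}s '
              f'{word["word"]}  (p={word["probability"]:.2f})')
```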

View File

@@ -199,6 +199,10 @@ def torch_to_mlx(
mlx_model = Whisper(torch_model.dims, dtype)
params = tree_map(lambda p: p.astype(dtype), params)
mlx_model.update(params)
if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None:
mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy())
return mlx_model
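
For context on the snippet above (hedged, since the PyTorch side is not shown in this diff): `alignment_heads` on the OpenAI checkpoints is a sparse boolean mask over (layer, head), so `.indices().T` yields an (N, 2) array of (layer, head) index pairs, which the new `set_alignment_heads` accepts directly. A minimal sketch with made-up indices:

```python
import numpy as np

# Hypothetical head selection: an (N, 2) array of (layer, head) pairs,
# the same layout that alignment_heads.indices().T.numpy() produces above.
custom_heads = np.array([[2, 0], [2, 3], [3, 1]])
mlx_model.set_alignment_heads(custom_heads)  # mlx_model as returned by torch_to_mlx
```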

View File

@@ -311,6 +311,137 @@ class TestWhisper(unittest.TestCase):
check_segment(result["segments"][5], expected_5)
check_segment(result["segments"][73], expected_73)
def test_transcribe_word_level_timestamps_confidence_scores(self):
result = whisper.transcribe(
# TEST_AUDIO, model_path=MLX_FP32_MODEL_PATH, word_timestamps=True, fp16=False
TEST_AUDIO,
model_path=MLX_FP16_MODEL_PATH,
word_timestamps=True,
)
# result predicted with openai-whisper
expected_0 = [
{
"word": " Then",
"start": 0.0,
"end": 0.94,
"probability": 0.855542778968811,
},
{
"word": " the",
"start": 0.94,
"end": 1.12,
"probability": 0.6500106453895569,
},
{
"word": " good",
"start": 1.12,
"end": 1.32,
"probability": 0.5503873825073242,
},
{
"word": " soul",
"start": 1.32,
"end": 1.56,
"probability": 0.46757155656814575,
},
{
"word": " openly",
"start": 1.56,
"end": 2.0,
"probability": 0.9840946793556213,
},
{
"word": " sorted",
"start": 2.0,
"end": 2.38,
"probability": 0.24167272448539734,
},
{
"word": " the",
"start": 2.38,
"end": 2.58,
"probability": 0.9875414967536926,
},
{
"word": " boat",
"start": 2.58,
"end": 2.8,
"probability": 0.5856029391288757,
},
{
"word": " and",
"start": 2.8,
"end": 2.98,
"probability": 0.913351833820343,
},
{
"word": " she",
"start": 2.98,
"end": 3.1,
"probability": 0.9913808703422546,
},
{
"word": " had",
"start": 3.1,
"end": 3.32,
"probability": 0.9952940344810486,
},
{
"word": " buoyed",
"start": 3.32,
"end": 3.58,
"probability": 0.6411589980125427,
},
{
"word": " so",
"start": 3.58,
"end": 3.8,
"probability": 0.9682658314704895,
},
{
"word": " long",
"start": 3.8,
"end": 4.06,
"probability": 0.9953522682189941,
},
{
"word": " in",
"start": 4.06,
"end": 4.26,
"probability": 0.6745936870574951,
},
{
"word": " secret",
"start": 4.26,
"end": 4.56,
"probability": 0.9905064702033997,
},
{
"word": " and",
"start": 4.56,
"end": 4.9,
"probability": 0.856008768081665,
},
{
"word": " bravely",
"start": 4.9,
"end": 5.28,
"probability": 0.8477402329444885,
},
]
def check_words(words, expected_words):
for word, expected_word in zip(words, expected_words):
for k, v in expected_word.items():
if isinstance(v, float):
self.assertAlmostEqual(word[k], v, places=1)
else:
self.assertEqual(word[k], v)
# Randomly check a couple of segments
check_words(result["segments"][0]["words"], expected_0)
class TestAudio(unittest.TestCase):
def test_load(self):

View File

@@ -141,7 +141,7 @@ class Inference:
# only need to use the last token except in the first forward pass
tokens = tokens[:, -1:]
logits, self.kv_cache = self.model.decoder(
logits, self.kv_cache, _ = self.model.decoder(
tokens, audio_features, kv_cache=self.kv_cache
)
return logits.astype(mx.float32)

View File

@@ -1,15 +1,13 @@
# Copyright © 2023 Apple Inc.
import itertools
import subprocess
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, List
import mlx.core as mx
import numba
import numpy as np
import torch
import torch.nn.functional as F
from scipy import signal
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer
@@ -18,7 +16,7 @@ if TYPE_CHECKING:
from .model import Whisper
def median_filter(x: torch.Tensor, filter_width: int):
def median_filter(x: np.ndarray, filter_width: int):
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
pad_width = filter_width // 2
if x.shape[-1] <= pad_width:
@@ -33,22 +31,12 @@ def median_filter(x: torch.Tensor, filter_width: int):
filter_width > 0 and filter_width % 2 == 1
), "`filter_width` should be an odd number"
result = None
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
if x.is_cuda:
try:
from .triton_ops import median_filter_cuda
x = np.pad(x, ((0, 0), (0, 0), (pad_width, pad_width)), mode="reflect")
result = median_filter_cuda(x, filter_width)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
"falling back to a slower median kernel implementation..."
)
if result is None:
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
# todo: more efficient version in mlx
result = signal.medfilt(x.astype(np.float32), kernel_size=(1, 1, filter_width))[
..., pad_width:-pad_width
]
if ndim <= 2:
result = result[0, 0]
@@ -107,50 +95,9 @@ def dtw_cpu(x: np.ndarray):
return backtrace(trace)
def dtw_cuda(x, BLOCK_SIZE=1024):
from .triton_ops import dtw_kernel
M, N = x.shape
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
x_skew = (
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
)
x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0
cost = cost.cuda()
trace = torch.zeros_like(cost, dtype=torch.int32)
dtw_kernel[(1,)](
cost,
trace,
x_skew,
x_skew.stride(0),
cost.stride(0),
trace.stride(0),
N,
M,
BLOCK_SIZE=BLOCK_SIZE,
)
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
:, : N + 1
]
return backtrace(trace.cpu().numpy())
def dtw(x: torch.Tensor) -> np.ndarray:
if x.is_cuda:
try:
return dtw_cuda(x)
except (RuntimeError, subprocess.CalledProcessError):
warnings.warn(
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
"falling back to a slower DTW implementation..."
)
return dtw_cpu(x.double().cpu().numpy())
def dtw(x: np.ndarray) -> np.ndarray:
# todo: more efficient version in mlx
return dtw_cpu(x)
@dataclass
@@ -166,7 +113,7 @@ def find_alignment(
model: "Whisper",
tokenizer: Tokenizer,
text_tokens: List[int],
mel: torch.Tensor,
mel: mx.array,
num_frames: int,
*,
medfilt_width: int = 7,
@@ -175,41 +122,36 @@ def find_alignment(
if len(text_tokens) == 0:
return []
tokens = torch.tensor(
tokens = mx.array(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*text_tokens,
tokenizer.eot,
]
).to(model.device)
)
# install hooks on the cross attention layers to retrieve the attention weights
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist()
for hook in hooks:
hook.remove()
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
# consider only the logits associated with predicting text
sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
sampled_logits.dtype
)
text_token_probs = mx.take_along_axis(
token_probs, mx.array(text_tokens)[:, None], axis=1
).squeeze(1)
text_token_probs = np.array(text_token_probs)
# heads * tokens * frames
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
weights = mx.stack(
[cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
)
weights = weights[:, :, : num_frames // 2]
weights = (weights * qk_scale).softmax(dim=-1)
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = mx.softmax(weights * qk_scale, axis=-1)
mean = mx.mean(weights, axis=-2, keepdims=True)
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
weights = (weights - mean) / std
weights = median_filter(weights, medfilt_width)
weights = median_filter(np.array(weights), medfilt_width)
matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence) : -1]
@@ -281,7 +223,7 @@ def add_word_timestamps(
segments: List[dict],
model: "Whisper",
tokenizer: Tokenizer,
mel: torch.Tensor,
mel: mx.array,
num_frames: int,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
@@ -301,6 +243,7 @@ def add_word_timestamps(
word_durations = np.array([t.end - t.start for t in alignment])
word_durations = word_durations[word_durations.nonzero()]
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
median_duration = min(0.7, float(median_duration))
max_duration = median_duration * 2
# hack: truncate long words at sentence boundaries.
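
Since the diff above drops the Triton/CUDA paths, alignment now runs through `dtw_cpu` on a NumPy cost matrix. As an illustration of the dynamic programming involved (a plain-NumPy sketch, not the repository's implementation): given a (tokens × frames) cost matrix, monotonic DTW finds the minimum-cost path whose row/column jumps are later converted into word boundaries.

```python
import numpy as np

def dtw_sketch(x: np.ndarray) -> np.ndarray:
    """Minimum-cost monotonic path through an (M, N) cost matrix.

    Returns a (2, path_len) array: token indices in row 0, frame indices in row 1.
    """
    M, N = x.shape
    cost = np.full((M + 1, N + 1), np.inf)
    cost[0, 0] = 0.0
    trace = np.zeros((M + 1, N + 1), dtype=np.int32)
    for i in range(1, M + 1):
        for j in range(1, N + 1):
            c_diag, c_up, c_left = cost[i - 1, j - 1], cost[i - 1, j], cost[i, j - 1]
            best = min(c_diag, c_up, c_left)
            cost[i, j] = x[i - 1, j - 1] + best
            trace[i, j] = 0 if best == c_diag else (1 if best == c_up else 2)
    # Walk back from the bottom-right corner to recover the path.
    i, j, rows, cols = M, N, [], []
    while i > 0 and j > 0:
        rows.append(i - 1)
        cols.append(j - 1)
        step = trace[i, j]
        if step == 0:
            i, j = i - 1, j - 1
        elif step == 1:
            i -= 1
        else:
            j -= 1
    return np.array([rows[::-1], cols[::-1]])
```

In `find_alignment`, the cost matrix is derived from the averaged cross-attention weights of the alignment heads, as shown in the diff above.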

View File

@@ -1,7 +1,8 @@
# Copyright © 2023 Apple Inc.
import sys
from typing import Optional, Tuple, Union
import warnings
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import numpy as np
@@ -18,7 +19,8 @@ from .audio import (
)
from .decoding import DecodingOptions, DecodingResult
from .load_models import load_model
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, get_tokenizer
def _format_timestamp(seconds: float):
@@ -38,6 +40,13 @@ def _format_timestamp(seconds: float):
return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
def _get_end(segments: List[dict]) -> Optional[float]:
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
)
class ModelHolder:
model = None
model_path = None
@@ -61,8 +70,11 @@ def transcribe(
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
**decode_options,
):
"""
@@ -99,6 +111,16 @@ def transcribe(
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
word_timestamps: bool
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
and include the timestamps for each word in each segment.
prepend_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the next word
append_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the previous word
initial_prompt: Optional[str]
Optional text to provide as a prompt for the first window. This can be used to provide, or
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
@@ -107,6 +129,14 @@ def transcribe(
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
clip_timestamps: Union[str, List[float]]
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
The last end timestamp defaults to the end of the file.
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
when a possible hallucination is detected
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
@@ -119,6 +149,7 @@ def transcribe(
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES)
content_frames = mel.shape[-2] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
if verbose:
system_encoding = sys.getdefaultencoding()
@@ -155,6 +186,22 @@ def transcribe(
task=task,
)
if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
]
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: mx.array) -> DecodingResult:
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
@@ -195,7 +242,8 @@ def transcribe(
return decode_result
seek = 0
clip_idx = 0
seek = seek_clips[clip_idx][0]
input_stride = N_FRAMES // model.dims.n_audio_ctx # mel frames per output token: 2
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
@@ -232,10 +280,23 @@ def transcribe(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
last_speech_timestamp = 0.0
while seek < content_frames:
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
# while seek < seek_clip_end
while clip_idx < len(seek_clips):
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
if seek < seek_clip_start:
seek = seek_clip_start
if seek >= seek_clip_end:
clip_idx += 1
if clip_idx < len(seek_clips):
seek = seek_clips[clip_idx][0]
continue
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
mel_segment = mel[seek : seek + N_FRAMES]
segment_size = min(N_FRAMES, content_frames - seek)
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
mel_segment = mel[seek : seek + segment_size]
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype)
@@ -260,6 +321,30 @@ def transcribe(
previous_seek = seek
current_segments = []
# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
duration = word["end"] - word["start"]
score = 0.0
if probability < 0.15:
score += 1.0
if duration < 0.133:
score += (0.133 - duration) * 15
if duration > 2.0:
score += duration - 2.0
return score
def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [w for w in segment["words"] if w["word"] not in punctuation]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)
def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)
timestamp_tokens = tokens >= tokenizer.timestamp_begin
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
@@ -324,6 +409,83 @@ def transcribe(
)
seek += segment_size
if word_timestamps:
add_word_timestamps(
segments=current_segments,
model=model,
tokenizer=tokenizer,
mel=mel_segment,
num_frames=segment_size,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp,
)
if not single_timestamp_ending:
last_word_end = _get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
seek = round(last_word_end * FRAMES_PER_SECOND)
# skip silence before possible hallucinations
if hallucination_silence_threshold is not None:
threshold = hallucination_silence_threshold
if not single_timestamp_ending:
last_word_end = _get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
remaining_duration = window_end_time - last_word_end
if remaining_duration > threshold:
seek = round(last_word_end * FRAMES_PER_SECOND)
else:
seek = previous_seek + segment_size
# if first segment might be a hallucination, skip leading silence
first_segment = next_words_segment(current_segments)
if first_segment is not None and is_segment_anomaly(first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
continue
# skip silence before any possible hallucination that is surrounded
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
for si in range(len(current_segments)):
segment = current_segments[si]
if not segment["words"]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1 :]
)
if next_segment is not None:
hal_next_start = next_segment["words"][0]["start"]
else:
hal_next_start = time_offset + segment_duration
silence_before = (
segment["start"] - hal_last_end > threshold
or segment["start"] < threshold
or segment["start"] - time_offset < 2.0
)
silence_after = (
hal_next_start - segment["end"] > threshold
or is_segment_anomaly(next_segment)
or window_end_time - segment["end"] < 2.0
)
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"])
* FRAMES_PER_SECOND
)
if content_duration - segment["end"] < threshold:
seek = content_frames
current_segments[si:] = []
break
hal_last_end = segment["end"]
last_word_end = _get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end
if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment["end"], segment["text"]
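
Putting the new `transcribe` options together, a hedged usage sketch (the audio path is a placeholder and the threshold value is illustrative, not a recommended default):

```python
import whisper

# Process only the clips 0-30 s and 60-90 s, collect word-level timestamps,
# and skip long silences around suspected hallucinations.
result = whisper.transcribe(
    "speech.mp3",                         # placeholder path
    word_timestamps=True,
    clip_timestamps="0,30,60,90",         # start,end,start,end,... in seconds
    hallucination_silence_threshold=2.0,  # seconds
)
for segment in result["segments"]:
    print(f'[{segment["start"]:7.2f} -> {segment["end"]:7.2f}] {segment["text"]}')
```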

View File

@@ -4,7 +4,7 @@ import base64
import gzip
import math
from dataclasses import dataclass
from typing import Dict, Iterable, Optional
from typing import Union
import mlx.core as mx
import mlx.nn as nn
@@ -72,8 +72,8 @@ class MultiHeadAttention(nn.Module):
else:
k, v = kv_cache
wv = self.qkv_attention(q, k, v, mask)
return self.out(wv), (k, v)
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), (k, v), qk
def qkv_attention(self, q, k, v, mask=None):
n_batch, n_ctx, n_state = q.shape
@@ -85,12 +85,12 @@ class MultiHeadAttention(nn.Module):
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.astype(mx.float32)
w = mx.softmax(qk, axis=-1).astype(q.dtype)
out = (w @ v).transpose(0, 2, 1, 3)
out = out.reshape(n_batch, n_ctx, n_state)
return out
return out, qk
class ResidualAttentionBlock(nn.Module):
@@ -112,13 +112,16 @@ class ResidualAttentionBlock(nn.Module):
def __call__(self, x, xa=None, mask=None, kv_cache=None):
kv, cross_kv = kv_cache if kv_cache else (None, None)
y, kv = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
y, kv, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv)
x += y
cross_qk = None
if self.cross_attn:
y, cross_kv = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv)
y, cross_kv, cross_qk = self.cross_attn(
self.cross_attn_ln(x), xa, kv_cache=cross_kv
)
x += y
x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
return x, (kv, cross_kv)
return x, (kv, cross_kv), cross_qk
class AudioEncoder(nn.Module):
@@ -146,7 +149,7 @@ class AudioEncoder(nn.Module):
x = x + self._positional_embedding
for block in self.blocks:
x, _ = block(x)
x, _, _ = block(x)
x = self.ln_post(x)
return x
@@ -191,11 +194,14 @@ class TextDecoder(nn.Module):
if kv_cache is None:
kv_cache = [None] * len(self.blocks)
cross_qk = [None] * len(self.blocks)
for e, block in enumerate(self.blocks):
x, kv_cache[e] = block(x, xa, mask=self._mask, kv_cache=kv_cache[e])
x, kv_cache[e], cross_qk[e] = block(
x, xa, mask=self._mask, kv_cache=kv_cache[e]
)
x = self.ln(x)
return x @ self.token_embedding.weight.T, kv_cache
return x @ self.token_embedding.weight.T, kv_cache, cross_qk
class Whisper(nn.Module):
@@ -218,6 +224,28 @@ class Whisper(nn.Module):
self.dims.n_text_layer,
dtype,
)
# use the last half among the decoder layers for time alignment by default;
# to use a specific set of heads, see `set_alignment_heads()` below.
all_heads = np.zeros(
(self.dims.n_text_layer, self.dims.n_text_head), dtype=bool
)
all_heads[self.dims.n_text_layer // 2 :] = True
self.alignment_heads = mx.array(np.asarray(all_heads.nonzero()).T)
def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
if isinstance(dump, np.ndarray):
self.alignment_heads = mx.array(dump)
elif isinstance(dump, bytes):
array = np.frombuffer(
gzip.decompress(base64.b85decode(dump)), dtype=bool
).copy()
mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head)
self.alignment_heads = mx.array(np.asarray(mask.nonzero()).T)
else:
raise ValueError(
f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing"
" alignment_head information"
)
def embed_audio(self, mel):
return self.encoder(mel)
@@ -225,6 +253,10 @@ class Whisper(nn.Module):
def logits(self, tokens, audio_features):
return self.decoder(tokens, audio_features)[0]
def forward_with_cross_qk(self, mel, tokens):
logits, _, cross_qk = self.decoder(tokens, self.encoder(mel))
return logits, cross_qk
def __call__(self, mel, tokens):
return self.decoder(tokens, self.encoder(mel))[0]
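
To make the new return-value plumbing concrete, here is a short sketch of how `forward_with_cross_qk` is consumed (mirroring `find_alignment` in `timing.py`; `model`, `mel`, and `tokens` are assumed to already exist with the shapes used there):

```python
import mlx.core as mx

# cross_qk is a per-layer list; each entry has shape
# (batch, n_heads, n_tokens, n_audio_ctx). Keep only the heads selected as
# alignment heads, which are stored on the model as (layer, head) index pairs.
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
weights = mx.stack(
    [cross_qk[layer.item()][0, head.item()] for layer, head in model.alignment_heads]
)
print(weights.shape)  # (n_alignment_heads, n_tokens, n_audio_ctx)
```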