cleanup whisper a little (#639)

Awni Hannun 2024-03-30 13:13:58 -07:00 committed by GitHub
parent f6283ef7ce
commit 78c431dc25
6 changed files with 237 additions and 221 deletions

View File

@@ -239,12 +239,13 @@ def generate(
         ),
         range(max_tokens),
     ):
-        if token == tokenizer.eos_token_id:
-            break
+        token = token.item()
         if n == 0:
             prompt_time = time.perf_counter() - tic
             tic = time.perf_counter()
-        tokens.append(token.item())
+        if token == tokenizer.eos_token_id:
+            break
+        tokens.append(token)

         if verbose:
             s = tokenizer.decode(tokens)
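The point of the reorder is that `token` arrives from `generate_step` as a 0-d `mx.array`: calling `.item()` once up front forces evaluation and lets both the EOS comparison and the `tokens` list work with a plain Python int. A minimal sketch of the same pattern; the `fake_generate_step` generator and the `EOS` value are made up for illustration and are not part of the repo:

import mlx.core as mx

EOS = 2  # hypothetical end-of-sequence id

def fake_generate_step():
    # stand-in for generate_step: yields token ids as 0-d mx.array values
    for t in [5, 17, 9, EOS, 4]:
        yield mx.array(t)

tokens = []
for token in fake_generate_step():
    token = token.item()  # force evaluation once; work with a plain int from here on
    if token == EOS:      # int-to-int comparison, as in the reordered loop
        break
    tokens.append(token)

print(tokens)  # [5, 17, 9]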

View File

@@ -91,7 +91,8 @@ def _download(url: str, root: str) -> str:
                 output.write(buffer)
                 loop.update(len(buffer))

-    model_bytes = open(download_target, "rb").read()
+    with open(download_target, "rb") as fid:
+        model_bytes = fid.read()
     if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
         raise RuntimeError(
             "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."

View File

@@ -297,7 +297,7 @@ class TestWhisper(unittest.TestCase):
             "temperature": 0.0,
             "avg_logprob": -0.1350895343440594,
             "compression_ratio": 1.6208333333333333,
-            "no_speech_prob": 0.002246702555567026,
+            "no_speech_prob": 0.009053784422576427,
         }

         def check_segment(seg, expected):
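The fixture pins the expected decoding statistics for the bundled audio, so this value has to track whatever the current model and ops produce. One plausible shape for such a comparison, with float fields checked to a tolerance rather than exactly; this is a sketch and not necessarily the repo's actual `check_segment`:

import math

def check_segment(seg: dict, expected: dict, rel_tol: float = 1e-6) -> bool:
    # compare float fields approximately, everything else exactly
    for key, want in expected.items():
        got = seg[key]
        if isinstance(want, float):
            if not math.isclose(got, want, rel_tol=rel_tol):
                return False
        elif got != want:
            return False
    return True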

View File

@@ -58,7 +58,7 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
     except CalledProcessError as e:
         raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

-    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+    return mx.array(np.frombuffer(out, np.int16)).flatten().astype(mx.float32) / 32768.0


 def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
@@ -73,8 +73,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
     if array.shape[axis] < length:
         pad_widths = [(0, 0)] * array.ndim
         pad_widths[axis] = (0, length - array.shape[axis])
-        pad_fn = mx.pad if isinstance(array, mx.array) else np.pad
-        array = pad_fn(array, pad_widths)
+        array = mx.pad(array, pad_widths)

     return array

@@ -154,9 +153,9 @@ def log_mel_spectrogram(
     """
     device = mx.default_device()
     mx.set_default_device(mx.cpu)
-    if not isinstance(audio, mx.array):
-        if isinstance(audio, str):
-            audio = load_audio(audio)
+    if isinstance(audio, str):
+        audio = load_audio(audio)
+    elif not isinstance(audio, mx.array):
         audio = mx.array(audio)

     if padding > 0:
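After these changes the audio path stays in `mx.array` end to end: `load_audio` returns an `mx.array`, `pad_or_trim` always pads with `mx.pad`, and `log_mel_spectrogram` accepts a file path, an `mx.array`, or anything convertible such as a NumPy buffer. A small usage sketch of the padding step under those assumptions; the constants below restate the standard Whisper values for the sake of the sketch:

import mlx.core as mx
import numpy as np

SAMPLE_RATE = 16000
N_SAMPLES = 30 * SAMPLE_RATE  # one 30-second context window

# five seconds of silence standing in for real decoded audio
audio = mx.array(np.zeros(5 * SAMPLE_RATE, dtype=np.float32))

# the same call pad_or_trim now makes for a too-short input
pad_widths = [(0, 0)] * audio.ndim
pad_widths[-1] = (0, N_SAMPLES - audio.shape[-1])
padded = mx.pad(audio, pad_widths)
print(padded.shape)  # (480000,)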

View File

@@ -280,243 +280,258 @@ def transcribe(
         total=content_frames, unit="frames", disable=verbose is not False
     ) as pbar:
         last_speech_timestamp = 0.0
-        # NOTE: This loop is obscurely flattened to make the diff readable.
-        # A later commit should turn this into a simpler nested loop.
-        # for seek_clip_start, seek_clip_end in seek_clips:
-        #     while seek < seek_clip_end
-        while clip_idx < len(seek_clips):
-            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
-            if seek < seek_clip_start:
-                seek = seek_clip_start
-            if seek >= seek_clip_end:
-                clip_idx += 1
-                if clip_idx < len(seek_clips):
-                    seek = seek_clips[clip_idx][0]
-                continue
-            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
-            segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
-            mel_segment = mel[seek : seek + segment_size]
-            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
-            mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype)
+        for seek_clip_start, seek_clip_end in seek_clips:
+            while seek < seek_clip_end:
+                time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
+                window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
+                segment_size = min(
+                    N_FRAMES, content_frames - seek, seek_clip_end - seek
+                )
+                mel_segment = mel[seek : seek + segment_size]
+                segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
+                mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype)

                 decode_options["prompt"] = all_tokens[prompt_reset_since:]
                 result: DecodingResult = decode_with_fallback(mel_segment)
                 tokens = np.array(result.tokens)

                 if no_speech_threshold is not None:
                     # no voice activity check
                     should_skip = result.no_speech_prob > no_speech_threshold
                     if (
                         logprob_threshold is not None
                         and result.avg_logprob > logprob_threshold
                     ):
                         # don't skip if the logprob is high enough, despite the no_speech_prob
                         should_skip = False

                     if should_skip:
-                        seek += segment_size  # fast-forward to the next segment boundary
-                        continue
+                        seek += (
+                            segment_size  # fast-forward to the next segment boundary
+                        )
+                        continue

                 previous_seek = seek
                 current_segments = []

                 # anomalous words are very long/short/improbable
                 def word_anomaly_score(word: dict) -> float:
                     probability = word.get("probability", 0.0)
                     duration = word["end"] - word["start"]
                     score = 0.0
                     if probability < 0.15:
                         score += 1.0
                     if duration < 0.133:
                         score += (0.133 - duration) * 15
                     if duration > 2.0:
                         score += duration - 2.0
                     return score

                 def is_segment_anomaly(segment: Optional[dict]) -> bool:
                     if segment is None or not segment["words"]:
                         return False
-                    words = [w for w in segment["words"] if w["word"] not in punctuation]
-                    words = words[:8]
-                    score = sum(word_anomaly_score(w) for w in words)
-                    return score >= 3 or score + 0.01 >= len(words)
+                    words = [
+                        w for w in segment["words"] if w["word"] not in punctuation
+                    ]
+                    words = words[:8]
+                    score = sum(word_anomaly_score(w) for w in words)
+                    return score >= 3 or score + 0.01 >= len(words)

                 def next_words_segment(segments: List[dict]) -> Optional[dict]:
                     return next((s for s in segments if s["words"]), None)

                 timestamp_tokens = tokens >= tokenizer.timestamp_begin
-                single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
+                single_timestamp_ending = timestamp_tokens[-2:].tolist() == [
+                    False,
+                    True,
+                ]

                 consecutive = np.where(
                     np.logical_and(timestamp_tokens[:-1], timestamp_tokens[1:])
                 )[0]
                 consecutive += 1
                 if len(consecutive) > 0:
                     # if the output contains two consecutive timestamp tokens
                     slices = consecutive.tolist()
                     if single_timestamp_ending:
                         slices.append(len(tokens))

                     last_slice = 0
                     for current_slice in slices:
                         sliced_tokens = tokens[last_slice:current_slice]
                         start_timestamp_pos = (
                             sliced_tokens[0].item() - tokenizer.timestamp_begin
                         )
                         end_timestamp_pos = (
                             sliced_tokens[-1].item() - tokenizer.timestamp_begin
                         )
                         current_segments.append(
                             new_segment(
-                                start=time_offset + start_timestamp_pos * time_precision,
+                                start=time_offset
+                                + start_timestamp_pos * time_precision,
                                 end=time_offset + end_timestamp_pos * time_precision,
                                 tokens=sliced_tokens,
                                 result=result,
                             )
                         )
                         last_slice = current_slice
                     if single_timestamp_ending:
                         # single timestamp at the end means no speech after the last timestamp.
                         seek += segment_size
                     else:
                         # otherwise, ignore the unfinished segment and seek to the last timestamp
                         last_timestamp_pos = (
                             tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                         )
                         seek += last_timestamp_pos * input_stride
                 else:
                     duration = segment_duration
                     timestamps = tokens[timestamp_tokens.nonzero()[0]]
                     if (
                         len(timestamps) > 0
                         and timestamps[-1].item() != tokenizer.timestamp_begin
                     ):
                         # no consecutive timestamps but it has a timestamp; use the last one.
                         last_timestamp_pos = (
                             timestamps[-1].item() - tokenizer.timestamp_begin
                         )
                         duration = last_timestamp_pos * time_precision

                     current_segments.append(
                         new_segment(
                             start=time_offset,
                             end=time_offset + duration,
                             tokens=tokens,
                             result=result,
                         )
                     )
                     seek += segment_size

                 if word_timestamps:
                     add_word_timestamps(
                         segments=current_segments,
                         model=model,
                         tokenizer=tokenizer,
                         mel=mel_segment,
                         num_frames=segment_size,
                         prepend_punctuations=prepend_punctuations,
                         append_punctuations=append_punctuations,
                         last_speech_timestamp=last_speech_timestamp,
                     )
                     if not single_timestamp_ending:
                         last_word_end = _get_end(current_segments)
                         if last_word_end is not None and last_word_end > time_offset:
                             seek = round(last_word_end * FRAMES_PER_SECOND)

                     # skip silence before possible hallucinations
                     if hallucination_silence_threshold is not None:
                         threshold = hallucination_silence_threshold
                         if not single_timestamp_ending:
                             last_word_end = _get_end(current_segments)
-                            if last_word_end is not None and last_word_end > time_offset:
+                            if (
+                                last_word_end is not None
+                                and last_word_end > time_offset
+                            ):
                                 remaining_duration = window_end_time - last_word_end
                                 if remaining_duration > threshold:
                                     seek = round(last_word_end * FRAMES_PER_SECOND)
                                 else:
                                     seek = previous_seek + segment_size

                         # if first segment might be a hallucination, skip leading silence
                         first_segment = next_words_segment(current_segments)
-                        if first_segment is not None and is_segment_anomaly(first_segment):
+                        if first_segment is not None and is_segment_anomaly(
+                            first_segment
+                        ):
                             gap = first_segment["start"] - time_offset
                             if gap > threshold:
                                 seek = previous_seek + round(gap * FRAMES_PER_SECOND)
                                 continue

                         # skip silence before any possible hallucination that is surrounded
                         # by silence or more hallucinations
                         hal_last_end = last_speech_timestamp
                         for si in range(len(current_segments)):
                             segment = current_segments[si]
                             if not segment["words"]:
                                 continue
                             if is_segment_anomaly(segment):
                                 next_segment = next_words_segment(
                                     current_segments[si + 1 :]
                                 )
                                 if next_segment is not None:
                                     hal_next_start = next_segment["words"][0]["start"]
                                 else:
                                     hal_next_start = time_offset + segment_duration
                                 silence_before = (
                                     segment["start"] - hal_last_end > threshold
                                     or segment["start"] < threshold
                                     or segment["start"] - time_offset < 2.0
                                 )
                                 silence_after = (
                                     hal_next_start - segment["end"] > threshold
                                     or is_segment_anomaly(next_segment)
                                     or window_end_time - segment["end"] < 2.0
                                 )
                                 if silence_before and silence_after:
                                     seek = round(
                                         max(time_offset + 1, segment["start"])
                                         * FRAMES_PER_SECOND
                                     )
                                     if content_duration - segment["end"] < threshold:
                                         seek = content_frames
                                     current_segments[si:] = []
                                     break
                                 hal_last_end = segment["end"]

                     last_word_end = _get_end(current_segments)
                     if last_word_end is not None:
                         last_speech_timestamp = last_word_end

                 if verbose:
                     for segment in current_segments:
-                        start, end, text = segment["start"], segment["end"], segment["text"]
+                        start, end, text = (
+                            segment["start"],
+                            segment["end"],
+                            segment["text"],
+                        )
                         line = f"[{_format_timestamp(start)} --> {_format_timestamp(end)}] {text}"
                         print(make_safe(line))

                 # if a segment is instantaneous or does not contain text, clear it
                 for i, segment in enumerate(current_segments):
-                    if segment["start"] == segment["end"] or segment["text"].strip() == "":
+                    if (
+                        segment["start"] == segment["end"]
+                        or segment["text"].strip() == ""
+                    ):
                         segment["text"] = ""
                         segment["tokens"] = []
                         segment["words"] = []

                 all_segments.extend(
                     [
                         {"id": i, **segment}
                         for i, segment in enumerate(
                             current_segments, start=len(all_segments)
                         )
                     ]
                 )
                 all_tokens.extend(
-                    [token for segment in current_segments for token in segment["tokens"]]
+                    [
+                        token
+                        for segment in current_segments
+                        for token in segment["tokens"]
+                    ]
                 )

                 if not condition_on_previous_text or result.temperature > 0.5:
                     # do not feed the prompt tokens if a high temperature was used
                     prompt_reset_since = len(all_tokens)

                 # update progress bar
                 pbar.update(min(content_frames, seek) - previous_seek)

     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
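The restructured loop simply walks each (start, end) clip and advances `seek` inside it, rather than stepping `clip_idx` by hand and fast-forwarding at the top of every iteration. A stripped-down sketch of that control flow, where the frame counts and the `process_window` stand-in are invented for illustration:

N_FRAMES = 3000
content_frames = 7500
seek_clips = [(0, 4000), (5000, 7500)]  # (start, end) frame ranges to transcribe

def process_window(start, size):
    # stand-in for the decode/segment body, which may also move seek around
    return size

seek = seek_clips[0][0]
for seek_clip_start, seek_clip_end in seek_clips:
    seek = max(seek, seek_clip_start)  # clamp into the current clip (sketch only)
    while seek < seek_clip_end:
        segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
        seek += process_window(seek, segment_size)

print(seek)  # 7500 once both clips are consumed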

View File

@@ -115,7 +115,7 @@ class ResidualAttentionBlock(nn.Module):
                 self.cross_attn_ln(x), xa, kv_cache=cross_kv
             )
             x += y

-        x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
+        x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))))
         return x, (kv, cross_kv), cross_qk

@@ -138,8 +138,8 @@ class AudioEncoder(nn.Module):
         self.ln_post = nn.LayerNorm(n_state)

     def __call__(self, x):
-        x = nn.gelu(self.conv1(x)).astype(x.dtype)
-        x = nn.gelu(self.conv2(x)).astype(x.dtype)
+        x = nn.gelu(self.conv1(x))
+        x = nn.gelu(self.conv2(x))
         assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
         x = x + self._positional_embedding
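Dropping the `.astype(x.dtype)` calls assumes `nn.gelu` already returns its result in the input dtype, so the cast back was a no-op. A quick check of that assumption (shapes are arbitrary):

import mlx.core as mx
import mlx.nn as nn

x = mx.ones((2, 4), dtype=mx.float16)
y = nn.gelu(x)
print(x.dtype, y.dtype)  # both should report float16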