cleanup whisper a little (#639)

Awni Hannun 2024-03-30 13:13:58 -07:00 committed by GitHub
parent f6283ef7ce
commit 78c431dc25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 237 additions and 221 deletions

View File

@@ -239,12 +239,13 @@ def generate(
         ),
         range(max_tokens),
     ):
-        if token == tokenizer.eos_token_id:
-            break
+        token = token.item()
         if n == 0:
             prompt_time = time.perf_counter() - tic
             tic = time.perf_counter()
-        tokens.append(token.item())
+        if token == tokenizer.eos_token_id:
+            break
+        tokens.append(token)

         if verbose:
             s = tokenizer.decode(tokens)
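A minimal standalone sketch of the reordered loop above, with a stubbed token stream standing in for the real token generator (the stub and the EOS id are illustrative, not the repo's API): the scalar is converted to a Python int once with .item(), the EOS check runs after the timing bookkeeping, and only plain ints are appended.

import time
import mlx.core as mx

EOS_TOKEN_ID = 2  # hypothetical EOS id, for illustration only
stream = [mx.array(5), mx.array(7), mx.array(EOS_TOKEN_ID)]  # stand-in token stream

tokens = []
tic = time.perf_counter()
for n, token in enumerate(stream):
    token = token.item()          # convert the mx.array scalar to a plain int once
    if n == 0:
        prompt_time = time.perf_counter() - tic
        tic = time.perf_counter()
    if token == EOS_TOKEN_ID:     # EOS check now happens after the timing bookkeeping
        break
    tokens.append(token)          # only plain ints are stored
print(tokens)  # [5, 7]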

View File

@@ -91,7 +91,8 @@ def _download(url: str, root: str) -> str:
                 output.write(buffer)
                 loop.update(len(buffer))

-    model_bytes = open(download_target, "rb").read()
+    with open(download_target, "rb") as fid:
+        model_bytes = fid.read()
     if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
         raise RuntimeError(
             "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."

View File

@@ -297,7 +297,7 @@ class TestWhisper(unittest.TestCase):
             "temperature": 0.0,
             "avg_logprob": -0.1350895343440594,
             "compression_ratio": 1.6208333333333333,
-            "no_speech_prob": 0.002246702555567026,
+            "no_speech_prob": 0.009053784422576427,
         }

         def check_segment(seg, expected):

View File

@@ -58,7 +58,7 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
     except CalledProcessError as e:
         raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

-    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+    return mx.array(np.frombuffer(out, np.int16)).flatten().astype(mx.float32) / 32768.0


 def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
@@ -73,8 +73,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):

     if array.shape[axis] < length:
         pad_widths = [(0, 0)] * array.ndim
         pad_widths[axis] = (0, length - array.shape[axis])
-        pad_fn = mx.pad if isinstance(array, mx.array) else np.pad
-        array = pad_fn(array, pad_widths)
+        array = mx.pad(array, pad_widths)

     return array
@@ -154,9 +153,9 @@ def log_mel_spectrogram(
     """
     device = mx.default_device()
     mx.set_default_device(mx.cpu)
-    if not isinstance(audio, mx.array):
-        if isinstance(audio, str):
-            audio = load_audio(audio)
+    if isinstance(audio, str):
+        audio = load_audio(audio)
+    elif not isinstance(audio, mx.array):
         audio = mx.array(audio)

     if padding > 0:
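Taken together, these hunks make mx.array the only array type flowing through the audio pipeline: load_audio now returns an mx.array, pad_or_trim pads with mx.pad unconditionally, and log_mel_spectrogram converts strings via load_audio and anything else via mx.array. A small sketch of the padding step under that assumption (the input length is illustrative; N_SAMPLES is whisper's fixed 30-second window at 16 kHz):

import mlx.core as mx
import numpy as np

N_SAMPLES = 480000  # 30 seconds at 16 kHz

audio = mx.array(np.zeros(16000, dtype=np.float32))   # the type load_audio now returns
pad_widths = [(0, 0)] * audio.ndim
pad_widths[-1] = (0, N_SAMPLES - audio.shape[-1])
padded = mx.pad(audio, pad_widths)                    # mx.pad is used unconditionally
print(padded.shape)                                   # (480000,)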

View File

@@ -280,22 +280,13 @@ def transcribe(
         total=content_frames, unit="frames", disable=verbose is not False
     ) as pbar:
         last_speech_timestamp = 0.0
-        # NOTE: This loop is obscurely flattened to make the diff readable.
-        # A later commit should turn this into a simpler nested loop.
-        # for seek_clip_start, seek_clip_end in seek_clips:
-        #     while seek < seek_clip_end
-        while clip_idx < len(seek_clips):
-            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
-            if seek < seek_clip_start:
-                seek = seek_clip_start
-            if seek >= seek_clip_end:
-                clip_idx += 1
-                if clip_idx < len(seek_clips):
-                    seek = seek_clips[clip_idx][0]
-                continue
+        for seek_clip_start, seek_clip_end in seek_clips:
+            while seek < seek_clip_end:
                 time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
                 window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
-                segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
+                segment_size = min(
+                    N_FRAMES, content_frames - seek, seek_clip_end - seek
+                )
                 mel_segment = mel[seek : seek + segment_size]
                 segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
                 mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype)
@@ -315,7 +306,9 @@ def transcribe(
                         should_skip = False

                 if should_skip:
-                    seek += segment_size  # fast-forward to the next segment boundary
+                    seek += (
+                        segment_size  # fast-forward to the next segment boundary
+                    )
                     continue

                 previous_seek = seek
@@ -337,7 +330,9 @@ def transcribe(
                 def is_segment_anomaly(segment: Optional[dict]) -> bool:
                     if segment is None or not segment["words"]:
                         return False
-                    words = [w for w in segment["words"] if w["word"] not in punctuation]
+                    words = [
+                        w for w in segment["words"] if w["word"] not in punctuation
+                    ]
                     words = words[:8]
                     score = sum(word_anomaly_score(w) for w in words)
                     return score >= 3 or score + 0.01 >= len(words)
@@ -346,7 +341,10 @@ def transcribe(
                     return next((s for s in segments if s["words"]), None)

                 timestamp_tokens = tokens >= tokenizer.timestamp_begin
-                single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
+                single_timestamp_ending = timestamp_tokens[-2:].tolist() == [
+                    False,
+                    True,
+                ]

                 consecutive = np.where(
                     np.logical_and(timestamp_tokens[:-1], timestamp_tokens[1:])
@@ -369,7 +367,8 @@ def transcribe(
                         )
                         current_segments.append(
                             new_segment(
-                                start=time_offset + start_timestamp_pos * time_precision,
+                                start=time_offset
+                                + start_timestamp_pos * time_precision,
                                 end=time_offset + end_timestamp_pos * time_precision,
                                 tokens=sliced_tokens,
                                 result=result,
@@ -431,7 +430,10 @@ def transcribe(
                        threshold = hallucination_silence_threshold
                        if not single_timestamp_ending:
                            last_word_end = _get_end(current_segments)
-                            if last_word_end is not None and last_word_end > time_offset:
+                            if (
+                                last_word_end is not None
+                                and last_word_end > time_offset
+                            ):
                                remaining_duration = window_end_time - last_word_end
                                if remaining_duration > threshold:
                                    seek = round(last_word_end * FRAMES_PER_SECOND)
@@ -440,7 +442,9 @@ def transcribe(

                        # if first segment might be a hallucination, skip leading silence
                        first_segment = next_words_segment(current_segments)
-                        if first_segment is not None and is_segment_anomaly(first_segment):
+                        if first_segment is not None and is_segment_anomaly(
+                            first_segment
+                        ):
                            gap = first_segment["start"] - time_offset
                            if gap > threshold:
                                seek = previous_seek + round(gap * FRAMES_PER_SECOND)
@@ -488,13 +492,20 @@ def transcribe(

                 if verbose:
                     for segment in current_segments:
-                        start, end, text = segment["start"], segment["end"], segment["text"]
+                        start, end, text = (
+                            segment["start"],
+                            segment["end"],
+                            segment["text"],
+                        )
                         line = f"[{_format_timestamp(start)} --> {_format_timestamp(end)}] {text}"
                         print(make_safe(line))

                 # if a segment is instantaneous or does not contain text, clear it
                 for i, segment in enumerate(current_segments):
-                    if segment["start"] == segment["end"] or segment["text"].strip() == "":
+                    if (
+                        segment["start"] == segment["end"]
+                        or segment["text"].strip() == ""
+                    ):
                         segment["text"] = ""
                         segment["tokens"] = []
                         segment["words"] = []
@@ -508,7 +519,11 @@ def transcribe(
                     ]
                 )
                 all_tokens.extend(
-                    [token for segment in current_segments for token in segment["tokens"]]
+                    [
+                        token
+                        for segment in current_segments
+                        for token in segment["tokens"]
+                    ]
                 )

                 if not condition_on_previous_text or result.temperature > 0.5:
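The control-flow change in the first transcribe hunk is the substantive one; the remaining hunks are line re-wraps caused by the extra indentation level of the nested loop. A stripped-down sketch of the new seek loop (the frame counts and clip ranges are made up; only the for/while nesting and the segment_size computation mirror the diff):

N_FRAMES = 3000                          # mel frames in one 30-second window
content_frames = 7200                    # hypothetical total number of mel frames
seek_clips = [(0, 4000), (4000, 7200)]   # hypothetical (start, end) frame ranges
seek = seek_clips[0][0]

for seek_clip_start, seek_clip_end in seek_clips:
    while seek < seek_clip_end:
        segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
        # ... decode the window mel[seek : seek + segment_size] here ...
        seek += segment_size             # advance past the decoded window
print(seek)  # 7200: every frame in every clip is visited exactly once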

View File

@@ -115,7 +115,7 @@ class ResidualAttentionBlock(nn.Module):
                 self.cross_attn_ln(x), xa, kv_cache=cross_kv
             )
             x += y

-        x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype))
+        x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))))
         return x, (kv, cross_kv), cross_qk
@@ -138,8 +138,8 @@ class AudioEncoder(nn.Module):
         self.ln_post = nn.LayerNorm(n_state)

     def __call__(self, x):
-        x = nn.gelu(self.conv1(x)).astype(x.dtype)
-        x = nn.gelu(self.conv2(x)).astype(x.dtype)
+        x = nn.gelu(self.conv1(x))
+        x = nn.gelu(self.conv2(x))
         assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape"
         x = x + self._positional_embedding
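Dropping the .astype(x.dtype) casts suggests nn.gelu now preserves the input dtype on its own; that is an inference from the diff, not something it states. A quick standalone check of that assumption:

import mlx.core as mx
import mlx.nn as nn

x = mx.ones((2, 4), dtype=mx.float16)
y = nn.gelu(x)
print(y.dtype)  # expected to remain float16 if the removed casts are indeed redundant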