diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 4f5f8b15..bf42a5d1 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -239,12 +239,13 @@ def generate( ), range(max_tokens), ): - if token == tokenizer.eos_token_id: - break + token = token.item() if n == 0: prompt_time = time.perf_counter() - tic tic = time.perf_counter() - tokens.append(token.item()) + if token == tokenizer.eos_token_id: + break + tokens.append(token) if verbose: s = tokenizer.decode(tokens) diff --git a/whisper/convert.py b/whisper/convert.py index c15623d1..824b0986 100644 --- a/whisper/convert.py +++ b/whisper/convert.py @@ -91,7 +91,8 @@ def _download(url: str, root: str) -> str: output.write(buffer) loop.update(len(buffer)) - model_bytes = open(download_target, "rb").read() + with open(download_target, "rb") as fid: + model_bytes = fid.read() if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: raise RuntimeError( "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." diff --git a/whisper/test.py b/whisper/test.py index 9fc3f0d5..dc097492 100644 --- a/whisper/test.py +++ b/whisper/test.py @@ -297,7 +297,7 @@ class TestWhisper(unittest.TestCase): "temperature": 0.0, "avg_logprob": -0.1350895343440594, "compression_ratio": 1.6208333333333333, - "no_speech_prob": 0.002246702555567026, + "no_speech_prob": 0.009053784422576427, } def check_segment(seg, expected): diff --git a/whisper/whisper/audio.py b/whisper/whisper/audio.py index 5e63fb7f..b7e4217e 100644 --- a/whisper/whisper/audio.py +++ b/whisper/whisper/audio.py @@ -58,7 +58,7 @@ def load_audio(file: str, sr: int = SAMPLE_RATE): except CalledProcessError as e: raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e - return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 + return mx.array(np.frombuffer(out, np.int16)).flatten().astype(mx.float32) / 32768.0 def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): @@ -73,8 +73,7 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): if array.shape[axis] < length: pad_widths = [(0, 0)] * array.ndim pad_widths[axis] = (0, length - array.shape[axis]) - pad_fn = mx.pad if isinstance(array, mx.array) else np.pad - array = pad_fn(array, pad_widths) + array = mx.pad(array, pad_widths) return array @@ -154,9 +153,9 @@ def log_mel_spectrogram( """ device = mx.default_device() mx.set_default_device(mx.cpu) - if not isinstance(audio, mx.array): - if isinstance(audio, str): - audio = load_audio(audio) + if isinstance(audio, str): + audio = load_audio(audio) + elif not isinstance(audio, mx.array): audio = mx.array(audio) if padding > 0: diff --git a/whisper/whisper/transcribe.py b/whisper/whisper/transcribe.py index 43f07802..786b4232 100644 --- a/whisper/whisper/transcribe.py +++ b/whisper/whisper/transcribe.py @@ -280,243 +280,258 @@ def transcribe( total=content_frames, unit="frames", disable=verbose is not False ) as pbar: last_speech_timestamp = 0.0 - # NOTE: This loop is obscurely flattened to make the diff readable. - # A later commit should turn this into a simpler nested loop. 
- # for seek_clip_start, seek_clip_end in seek_clips: - # while seek < seek_clip_end - while clip_idx < len(seek_clips): - seek_clip_start, seek_clip_end = seek_clips[clip_idx] - if seek < seek_clip_start: - seek = seek_clip_start - if seek >= seek_clip_end: - clip_idx += 1 - if clip_idx < len(seek_clips): - seek = seek_clips[clip_idx][0] - continue - time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) - window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE) - segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek) - mel_segment = mel[seek : seek + segment_size] - segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE - mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype) + for seek_clip_start, seek_clip_end in seek_clips: + while seek < seek_clip_end: + time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) + window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE) + segment_size = min( + N_FRAMES, content_frames - seek, seek_clip_end - seek + ) + mel_segment = mel[seek : seek + segment_size] + segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE + mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype) - decode_options["prompt"] = all_tokens[prompt_reset_since:] - result: DecodingResult = decode_with_fallback(mel_segment) - tokens = np.array(result.tokens) + decode_options["prompt"] = all_tokens[prompt_reset_since:] + result: DecodingResult = decode_with_fallback(mel_segment) + tokens = np.array(result.tokens) - if no_speech_threshold is not None: - # no voice activity check - should_skip = result.no_speech_prob > no_speech_threshold - if ( - logprob_threshold is not None - and result.avg_logprob > logprob_threshold - ): - # don't skip if the logprob is high enough, despite the no_speech_prob - should_skip = False + if no_speech_threshold is not None: + # no voice activity check + should_skip = result.no_speech_prob > no_speech_threshold + if ( + logprob_threshold is not None + and result.avg_logprob > logprob_threshold + ): + # don't skip if the logprob is high enough, despite the no_speech_prob + should_skip = False - if should_skip: - seek += segment_size # fast-forward to the next segment boundary - continue + if should_skip: + seek += ( + segment_size # fast-forward to the next segment boundary + ) + continue - previous_seek = seek - current_segments = [] + previous_seek = seek + current_segments = [] - # anomalous words are very long/short/improbable - def word_anomaly_score(word: dict) -> float: - probability = word.get("probability", 0.0) - duration = word["end"] - word["start"] - score = 0.0 - if probability < 0.15: - score += 1.0 - if duration < 0.133: - score += (0.133 - duration) * 15 - if duration > 2.0: - score += duration - 2.0 - return score + # anomalous words are very long/short/improbable + def word_anomaly_score(word: dict) -> float: + probability = word.get("probability", 0.0) + duration = word["end"] - word["start"] + score = 0.0 + if probability < 0.15: + score += 1.0 + if duration < 0.133: + score += (0.133 - duration) * 15 + if duration > 2.0: + score += duration - 2.0 + return score - def is_segment_anomaly(segment: Optional[dict]) -> bool: - if segment is None or not segment["words"]: - return False - words = [w for w in segment["words"] if w["word"] not in punctuation] - words = words[:8] - score = sum(word_anomaly_score(w) for w in words) - return score >= 3 or score + 0.01 >= len(words) + def is_segment_anomaly(segment: Optional[dict]) -> bool: + if segment is None or 
not segment["words"]: + return False + words = [ + w for w in segment["words"] if w["word"] not in punctuation + ] + words = words[:8] + score = sum(word_anomaly_score(w) for w in words) + return score >= 3 or score + 0.01 >= len(words) - def next_words_segment(segments: List[dict]) -> Optional[dict]: - return next((s for s in segments if s["words"]), None) + def next_words_segment(segments: List[dict]) -> Optional[dict]: + return next((s for s in segments if s["words"]), None) - timestamp_tokens = tokens >= tokenizer.timestamp_begin - single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] + timestamp_tokens = tokens >= tokenizer.timestamp_begin + single_timestamp_ending = timestamp_tokens[-2:].tolist() == [ + False, + True, + ] - consecutive = np.where( - np.logical_and(timestamp_tokens[:-1], timestamp_tokens[1:]) - )[0] - consecutive += 1 - if len(consecutive) > 0: - # if the output contains two consecutive timestamp tokens - slices = consecutive.tolist() - if single_timestamp_ending: - slices.append(len(tokens)) + consecutive = np.where( + np.logical_and(timestamp_tokens[:-1], timestamp_tokens[1:]) + )[0] + consecutive += 1 + if len(consecutive) > 0: + # if the output contains two consecutive timestamp tokens + slices = consecutive.tolist() + if single_timestamp_ending: + slices.append(len(tokens)) + + last_slice = 0 + for current_slice in slices: + sliced_tokens = tokens[last_slice:current_slice] + start_timestamp_pos = ( + sliced_tokens[0].item() - tokenizer.timestamp_begin + ) + end_timestamp_pos = ( + sliced_tokens[-1].item() - tokenizer.timestamp_begin + ) + current_segments.append( + new_segment( + start=time_offset + + start_timestamp_pos * time_precision, + end=time_offset + end_timestamp_pos * time_precision, + tokens=sliced_tokens, + result=result, + ) + ) + last_slice = current_slice + + if single_timestamp_ending: + # single timestamp at the end means no speech after the last timestamp. + seek += segment_size + else: + # otherwise, ignore the unfinished segment and seek to the last timestamp + last_timestamp_pos = ( + tokens[last_slice - 1].item() - tokenizer.timestamp_begin + ) + seek += last_timestamp_pos * input_stride + else: + duration = segment_duration + timestamps = tokens[timestamp_tokens.nonzero()[0]] + if ( + len(timestamps) > 0 + and timestamps[-1].item() != tokenizer.timestamp_begin + ): + # no consecutive timestamps but it has a timestamp; use the last one. + last_timestamp_pos = ( + timestamps[-1].item() - tokenizer.timestamp_begin + ) + duration = last_timestamp_pos * time_precision - last_slice = 0 - for current_slice in slices: - sliced_tokens = tokens[last_slice:current_slice] - start_timestamp_pos = ( - sliced_tokens[0].item() - tokenizer.timestamp_begin - ) - end_timestamp_pos = ( - sliced_tokens[-1].item() - tokenizer.timestamp_begin - ) current_segments.append( new_segment( - start=time_offset + start_timestamp_pos * time_precision, - end=time_offset + end_timestamp_pos * time_precision, - tokens=sliced_tokens, + start=time_offset, + end=time_offset + duration, + tokens=tokens, result=result, ) ) - last_slice = current_slice - - if single_timestamp_ending: - # single timestamp at the end means no speech after the last timestamp. 
seek += segment_size - else: - # otherwise, ignore the unfinished segment and seek to the last timestamp - last_timestamp_pos = ( - tokens[last_slice - 1].item() - tokenizer.timestamp_begin + + if word_timestamps: + add_word_timestamps( + segments=current_segments, + model=model, + tokenizer=tokenizer, + mel=mel_segment, + num_frames=segment_size, + prepend_punctuations=prepend_punctuations, + append_punctuations=append_punctuations, + last_speech_timestamp=last_speech_timestamp, ) - seek += last_timestamp_pos * input_stride - else: - duration = segment_duration - timestamps = tokens[timestamp_tokens.nonzero()[0]] - if ( - len(timestamps) > 0 - and timestamps[-1].item() != tokenizer.timestamp_begin - ): - # no consecutive timestamps but it has a timestamp; use the last one. - last_timestamp_pos = ( - timestamps[-1].item() - tokenizer.timestamp_begin - ) - duration = last_timestamp_pos * time_precision - current_segments.append( - new_segment( - start=time_offset, - end=time_offset + duration, - tokens=tokens, - result=result, - ) - ) - seek += segment_size - - if word_timestamps: - add_word_timestamps( - segments=current_segments, - model=model, - tokenizer=tokenizer, - mel=mel_segment, - num_frames=segment_size, - prepend_punctuations=prepend_punctuations, - append_punctuations=append_punctuations, - last_speech_timestamp=last_speech_timestamp, - ) - - if not single_timestamp_ending: - last_word_end = _get_end(current_segments) - if last_word_end is not None and last_word_end > time_offset: - seek = round(last_word_end * FRAMES_PER_SECOND) - - # skip silence before possible hallucinations - if hallucination_silence_threshold is not None: - threshold = hallucination_silence_threshold if not single_timestamp_ending: last_word_end = _get_end(current_segments) if last_word_end is not None and last_word_end > time_offset: - remaining_duration = window_end_time - last_word_end - if remaining_duration > threshold: - seek = round(last_word_end * FRAMES_PER_SECOND) - else: - seek = previous_seek + segment_size + seek = round(last_word_end * FRAMES_PER_SECOND) - # if first segment might be a hallucination, skip leading silence - first_segment = next_words_segment(current_segments) - if first_segment is not None and is_segment_anomaly(first_segment): - gap = first_segment["start"] - time_offset - if gap > threshold: - seek = previous_seek + round(gap * FRAMES_PER_SECOND) - continue + # skip silence before possible hallucinations + if hallucination_silence_threshold is not None: + threshold = hallucination_silence_threshold + if not single_timestamp_ending: + last_word_end = _get_end(current_segments) + if ( + last_word_end is not None + and last_word_end > time_offset + ): + remaining_duration = window_end_time - last_word_end + if remaining_duration > threshold: + seek = round(last_word_end * FRAMES_PER_SECOND) + else: + seek = previous_seek + segment_size - # skip silence before any possible hallucination that is surrounded - # by silence or more hallucinations - hal_last_end = last_speech_timestamp - for si in range(len(current_segments)): - segment = current_segments[si] - if not segment["words"]: - continue - if is_segment_anomaly(segment): - next_segment = next_words_segment( - current_segments[si + 1 :] - ) - if next_segment is not None: - hal_next_start = next_segment["words"][0]["start"] - else: - hal_next_start = time_offset + segment_duration - silence_before = ( - segment["start"] - hal_last_end > threshold - or segment["start"] < threshold - or segment["start"] - time_offset < 2.0 - ) - 
silence_after = ( - hal_next_start - segment["end"] > threshold - or is_segment_anomaly(next_segment) - or window_end_time - segment["end"] < 2.0 - ) - if silence_before and silence_after: - seek = round( - max(time_offset + 1, segment["start"]) - * FRAMES_PER_SECOND + # if first segment might be a hallucination, skip leading silence + first_segment = next_words_segment(current_segments) + if first_segment is not None and is_segment_anomaly( + first_segment + ): + gap = first_segment["start"] - time_offset + if gap > threshold: + seek = previous_seek + round(gap * FRAMES_PER_SECOND) + continue + + # skip silence before any possible hallucination that is surrounded + # by silence or more hallucinations + hal_last_end = last_speech_timestamp + for si in range(len(current_segments)): + segment = current_segments[si] + if not segment["words"]: + continue + if is_segment_anomaly(segment): + next_segment = next_words_segment( + current_segments[si + 1 :] ) - if content_duration - segment["end"] < threshold: - seek = content_frames - current_segments[si:] = [] - break - hal_last_end = segment["end"] + if next_segment is not None: + hal_next_start = next_segment["words"][0]["start"] + else: + hal_next_start = time_offset + segment_duration + silence_before = ( + segment["start"] - hal_last_end > threshold + or segment["start"] < threshold + or segment["start"] - time_offset < 2.0 + ) + silence_after = ( + hal_next_start - segment["end"] > threshold + or is_segment_anomaly(next_segment) + or window_end_time - segment["end"] < 2.0 + ) + if silence_before and silence_after: + seek = round( + max(time_offset + 1, segment["start"]) + * FRAMES_PER_SECOND + ) + if content_duration - segment["end"] < threshold: + seek = content_frames + current_segments[si:] = [] + break + hal_last_end = segment["end"] - last_word_end = _get_end(current_segments) - if last_word_end is not None: - last_speech_timestamp = last_word_end + last_word_end = _get_end(current_segments) + if last_word_end is not None: + last_speech_timestamp = last_word_end - if verbose: - for segment in current_segments: - start, end, text = segment["start"], segment["end"], segment["text"] - line = f"[{_format_timestamp(start)} --> {_format_timestamp(end)}] {text}" - print(make_safe(line)) + if verbose: + for segment in current_segments: + start, end, text = ( + segment["start"], + segment["end"], + segment["text"], + ) + line = f"[{_format_timestamp(start)} --> {_format_timestamp(end)}] {text}" + print(make_safe(line)) - # if a segment is instantaneous or does not contain text, clear it - for i, segment in enumerate(current_segments): - if segment["start"] == segment["end"] or segment["text"].strip() == "": - segment["text"] = "" - segment["tokens"] = [] - segment["words"] = [] + # if a segment is instantaneous or does not contain text, clear it + for i, segment in enumerate(current_segments): + if ( + segment["start"] == segment["end"] + or segment["text"].strip() == "" + ): + segment["text"] = "" + segment["tokens"] = [] + segment["words"] = [] - all_segments.extend( - [ - {"id": i, **segment} - for i, segment in enumerate( - current_segments, start=len(all_segments) - ) - ] - ) - all_tokens.extend( - [token for segment in current_segments for token in segment["tokens"]] - ) + all_segments.extend( + [ + {"id": i, **segment} + for i, segment in enumerate( + current_segments, start=len(all_segments) + ) + ] + ) + all_tokens.extend( + [ + token + for segment in current_segments + for token in segment["tokens"] + ] + ) - if not 
condition_on_previous_text or result.temperature > 0.5: - # do not feed the prompt tokens if a high temperature was used - prompt_reset_since = len(all_tokens) + if not condition_on_previous_text or result.temperature > 0.5: + # do not feed the prompt tokens if a high temperature was used + prompt_reset_since = len(all_tokens) - # update progress bar - pbar.update(min(content_frames, seek) - previous_seek) + # update progress bar + pbar.update(min(content_frames, seek) - previous_seek) return dict( text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), diff --git a/whisper/whisper/whisper.py b/whisper/whisper/whisper.py index 37495130..f5cc3888 100644 --- a/whisper/whisper/whisper.py +++ b/whisper/whisper/whisper.py @@ -115,7 +115,7 @@ class ResidualAttentionBlock(nn.Module): self.cross_attn_ln(x), xa, kv_cache=cross_kv ) x += y - x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x))).astype(x.dtype)) + x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x)))) return x, (kv, cross_kv), cross_qk @@ -138,8 +138,8 @@ class AudioEncoder(nn.Module): self.ln_post = nn.LayerNorm(n_state) def __call__(self, x): - x = nn.gelu(self.conv1(x)).astype(x.dtype) - x = nn.gelu(self.conv2(x)).astype(x.dtype) + x = nn.gelu(self.conv1(x)) + x = nn.gelu(self.conv2(x)) assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape" x = x + self._positional_embedding
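
The hunks above leave the public entry points of both packages unchanged, so the README-style calls still exercise the rewritten paths. Two minimal smoke-test sketches, assuming the package-level exports that the repo's READMEs describe; the model id and audio path below are placeholders, not files touched by this diff:

    # Exercises the reworked token/EOS handling in mlx_lm.utils.generate().
    from mlx_lm import load, generate

    model, tokenizer = load("mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")  # placeholder repo id
    print(generate(model, tokenizer, prompt="hello", max_tokens=32, verbose=True))

    # Exercises load_audio -> log_mel_spectrogram (now mx.array end to end) and the
    # re-nested seek loop in transcribe().
    import whisper  # the package under whisper/whisper in this repo

    result = whisper.transcribe("speech.wav", word_timestamps=True)  # placeholder audio path
    print(result["text"])
    for seg in result["segments"]:
        print(f'[{seg["start"]:.2f} --> {seg["end"]:.2f}] {seg["text"]}')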