diff --git a/whisper/mlx_whisper/_version.py b/whisper/mlx_whisper/_version.py
index 45e522d1..8280e038 100644
--- a/whisper/mlx_whisper/_version.py
+++ b/whisper/mlx_whisper/_version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.4.0"
+__version__ = "0.4.1"
diff --git a/whisper/mlx_whisper/decoding.py b/whisper/mlx_whisper/decoding.py
index 6bf975d5..4e060cd5 100644
--- a/whisper/mlx_whisper/decoding.py
+++ b/whisper/mlx_whisper/decoding.py
@@ -589,35 +589,34 @@ class DecodingTask:
             )
             return tokens, completed, sum_logprobs, pre_logits
 
-        try:
-            tokens, completed, sum_logprobs, pre_logits = _step(
-                tokens, audio_features, tokens, sum_logprobs
+        tokens, completed, sum_logprobs, pre_logits = _step(
+            tokens, audio_features, tokens, sum_logprobs
+        )
+        if self.tokenizer.no_speech is not None:  # compute no_speech_probs
+            probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
+            no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
+        else:
+            no_speech_probs = mx.full(n_batch, mx.nan)
+        mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)
+
+        for i in range(1, self.sample_len):
+            inputs = tokens[:, -1:]
+            if tokens.shape[-1] > self.n_ctx:
+                break
+            next_tokens, next_completed, next_sum_logprobs, _ = _step(
+                inputs, audio_features, tokens, sum_logprobs
             )
-            if self.tokenizer.no_speech is not None:  # compute no_speech_probs
-                probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
-                no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
-            else:
-                no_speech_probs = mx.full(n_batch, mx.nan)
-            mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)
-
-            for i in range(1, self.sample_len):
-                inputs = tokens[:, -1:]
-                next_tokens, next_completed, next_sum_logprobs, _ = _step(
-                    inputs, audio_features, tokens, sum_logprobs
-                )
-                mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
-                if completed or tokens.shape[-1] > self.n_ctx:
-                    break
-                tokens = next_tokens
-                completed = next_completed
-                sum_logprobs = next_sum_logprobs
-
-        finally:
-            self.inference.reset()
+            mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
+            if completed:
+                break
+            tokens = next_tokens
+            completed = next_completed
+            sum_logprobs = next_sum_logprobs
 
         return tokens, sum_logprobs, no_speech_probs
 
     def run(self, mel: mx.array) -> List[DecodingResult]:
+        self.inference.reset()
         self.decoder.reset()
         tokenizer: Tokenizer = self.tokenizer
         n_audio: int = mel.shape[0]