From 04a7c07a9e832be80d9f2c359d9e50bc4feda38e Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 25 Aug 2025 09:43:54 -0700 Subject: [PATCH] fix temperature based sampling --- whisper/mlx_whisper/_version.py | 2 +- whisper/mlx_whisper/decoding.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/whisper/mlx_whisper/_version.py b/whisper/mlx_whisper/_version.py index 8280e038..e09326dd 100644 --- a/whisper/mlx_whisper/_version.py +++ b/whisper/mlx_whisper/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.4.1" +__version__ = "0.4.2" diff --git a/whisper/mlx_whisper/decoding.py b/whisper/mlx_whisper/decoding.py index 4e060cd5..814dc95c 100644 --- a/whisper/mlx_whisper/decoding.py +++ b/whisper/mlx_whisper/decoding.py @@ -265,7 +265,7 @@ class GreedyDecoder(TokenDecoder): else: next_tokens = categorical(logits, self.temperature) - logprobs = logits - mx.logsumexp(logits, axis=-1) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens] sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot) @@ -380,7 +380,7 @@ class ApplyTimestampRules(LogitFilter): # if sum of probability over timestamps is above any other token, sample timestamp mask = mx.array(mask) - logprobs = logits - mx.logsumexp(logits, axis=-1) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp( axis=-1, keepdims=True ) @@ -603,6 +603,7 @@ class DecodingTask: inputs = tokens[:, -1:] if tokens.shape[-1] > self.n_ctx: break + next_tokens, next_completed, next_sum_logprobs, _ = _step( inputs, audio_features, tokens, sum_logprobs ) @@ -643,9 +644,7 @@ class DecodingTask: tokens = mx.broadcast_to( tokens, [n_audio, self.n_group, len(self.initial_tokens)] ) - tokens = tokens.reshape( - tokens, (n_audio * self.n_group, len(self.initial_tokens)) - ) + tokens = tokens.reshape((n_audio * self.n_group, len(self.initial_tokens))) # call the main sampling loop tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)