Fix temperature-based sampling

This commit is contained in:
Awni Hannun 2025-08-25 09:43:54 -07:00
parent 4b2a0df237
commit 04a7c07a9e
2 changed files with 5 additions and 6 deletions

View File

@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
__version__ = "0.4.1" __version__ = "0.4.2"

View File

@ -265,7 +265,7 @@ class GreedyDecoder(TokenDecoder):
else: else:
next_tokens = categorical(logits, self.temperature) next_tokens = categorical(logits, self.temperature)
logprobs = logits - mx.logsumexp(logits, axis=-1) logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens] current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot) sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
@ -380,7 +380,7 @@ class ApplyTimestampRules(LogitFilter):
# if sum of probability over timestamps is above any other token, sample timestamp # if sum of probability over timestamps is above any other token, sample timestamp
mask = mx.array(mask) mask = mx.array(mask)
logprobs = logits - mx.logsumexp(logits, axis=-1) logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp( timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp(
axis=-1, keepdims=True axis=-1, keepdims=True
) )
@ -603,6 +603,7 @@ class DecodingTask:
inputs = tokens[:, -1:] inputs = tokens[:, -1:]
if tokens.shape[-1] > self.n_ctx: if tokens.shape[-1] > self.n_ctx:
break break
next_tokens, next_completed, next_sum_logprobs, _ = _step( next_tokens, next_completed, next_sum_logprobs, _ = _step(
inputs, audio_features, tokens, sum_logprobs inputs, audio_features, tokens, sum_logprobs
) )
@ -643,9 +644,7 @@ class DecodingTask:
tokens = mx.broadcast_to( tokens = mx.broadcast_to(
tokens, [n_audio, self.n_group, len(self.initial_tokens)] tokens, [n_audio, self.n_group, len(self.initial_tokens)]
) )
tokens = tokens.reshape( tokens = tokens.reshape((n_audio * self.n_group, len(self.initial_tokens)))
tokens, (n_audio * self.n_group, len(self.initial_tokens))
)
# call the main sampling loop # call the main sampling loop
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens) tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)