mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 12:26:07 +08:00
Fix temperature-based sampling
This commit is contained in:
parent
4b2a0df237
commit
04a7c07a9e
@@ -1,3 +1,3 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
__version__ = "0.4.1"
|
__version__ = "0.4.2"
|
||||||
|
@@ -265,7 +265,7 @@ class GreedyDecoder(TokenDecoder):
|
|||||||
else:
|
else:
|
||||||
next_tokens = categorical(logits, self.temperature)
|
next_tokens = categorical(logits, self.temperature)
|
||||||
|
|
||||||
logprobs = logits - mx.logsumexp(logits, axis=-1)
|
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||||
|
|
||||||
current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
|
current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
|
||||||
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
|
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
|
||||||
@@ -380,7 +380,7 @@ class ApplyTimestampRules(LogitFilter):
|
|||||||
|
|
||||||
# if sum of probability over timestamps is above any other token, sample timestamp
|
# if sum of probability over timestamps is above any other token, sample timestamp
|
||||||
mask = mx.array(mask)
|
mask = mx.array(mask)
|
||||||
logprobs = logits - mx.logsumexp(logits, axis=-1)
|
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||||
timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp(
|
timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp(
|
||||||
axis=-1, keepdims=True
|
axis=-1, keepdims=True
|
||||||
)
|
)
|
||||||
@@ -603,6 +603,7 @@ class DecodingTask:
|
|||||||
inputs = tokens[:, -1:]
|
inputs = tokens[:, -1:]
|
||||||
if tokens.shape[-1] > self.n_ctx:
|
if tokens.shape[-1] > self.n_ctx:
|
||||||
break
|
break
|
||||||
|
|
||||||
next_tokens, next_completed, next_sum_logprobs, _ = _step(
|
next_tokens, next_completed, next_sum_logprobs, _ = _step(
|
||||||
inputs, audio_features, tokens, sum_logprobs
|
inputs, audio_features, tokens, sum_logprobs
|
||||||
)
|
)
|
||||||
@@ -643,9 +644,7 @@ class DecodingTask:
|
|||||||
tokens = mx.broadcast_to(
|
tokens = mx.broadcast_to(
|
||||||
tokens, [n_audio, self.n_group, len(self.initial_tokens)]
|
tokens, [n_audio, self.n_group, len(self.initial_tokens)]
|
||||||
)
|
)
|
||||||
tokens = tokens.reshape(
|
tokens = tokens.reshape((n_audio * self.n_group, len(self.initial_tokens)))
|
||||||
tokens, (n_audio * self.n_group, len(self.initial_tokens))
|
|
||||||
)
|
|
||||||
|
|
||||||
# call the main sampling loop
|
# call the main sampling loop
|
||||||
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
|
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
|
||||||
|
Loading…
Reference in New Issue
Block a user