mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-02 23:01:15 +08:00
fix (#1082)
This commit is contained in:
parent
0f799947d0
commit
29c954f4cb
@ -1,3 +1,3 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
__version__ = "0.4.0"
|
||||
__version__ = "0.4.1"
|
||||
|
@ -589,35 +589,34 @@ class DecodingTask:
|
||||
)
|
||||
return tokens, completed, sum_logprobs, pre_logits
|
||||
|
||||
try:
|
||||
tokens, completed, sum_logprobs, pre_logits = _step(
|
||||
tokens, audio_features, tokens, sum_logprobs
|
||||
tokens, completed, sum_logprobs, pre_logits = _step(
|
||||
tokens, audio_features, tokens, sum_logprobs
|
||||
)
|
||||
if self.tokenizer.no_speech is not None: # compute no_speech_probs
|
||||
probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
|
||||
else:
|
||||
no_speech_probs = mx.full(n_batch, mx.nan)
|
||||
mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)
|
||||
|
||||
for i in range(1, self.sample_len):
|
||||
inputs = tokens[:, -1:]
|
||||
if tokens.shape[-1] > self.n_ctx:
|
||||
break
|
||||
next_tokens, next_completed, next_sum_logprobs, _ = _step(
|
||||
inputs, audio_features, tokens, sum_logprobs
|
||||
)
|
||||
if self.tokenizer.no_speech is not None: # compute no_speech_probs
|
||||
probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
|
||||
else:
|
||||
no_speech_probs = mx.full(n_batch, mx.nan)
|
||||
mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)
|
||||
|
||||
for i in range(1, self.sample_len):
|
||||
inputs = tokens[:, -1:]
|
||||
next_tokens, next_completed, next_sum_logprobs, _ = _step(
|
||||
inputs, audio_features, tokens, sum_logprobs
|
||||
)
|
||||
mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
|
||||
if completed or tokens.shape[-1] > self.n_ctx:
|
||||
break
|
||||
tokens = next_tokens
|
||||
completed = next_completed
|
||||
sum_logprobs = next_sum_logprobs
|
||||
|
||||
finally:
|
||||
self.inference.reset()
|
||||
mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
|
||||
if completed:
|
||||
break
|
||||
tokens = next_tokens
|
||||
completed = next_completed
|
||||
sum_logprobs = next_sum_logprobs
|
||||
|
||||
return tokens, sum_logprobs, no_speech_probs
|
||||
|
||||
def run(self, mel: mx.array) -> List[DecodingResult]:
|
||||
self.inference.reset()
|
||||
self.decoder.reset()
|
||||
tokenizer: Tokenizer = self.tokenizer
|
||||
n_audio: int = mel.shape[0]
|
||||
|
Loading…
Reference in New Issue
Block a user