mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-05 00:01:13 +08:00
fix (#1082)
This commit is contained in:
parent
0f799947d0
commit
29c954f4cb
@ -1,3 +1,3 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
__version__ = "0.4.0"
|
||||
__version__ = "0.4.1"
|
||||
|
@ -589,7 +589,6 @@ class DecodingTask:
|
||||
)
|
||||
return tokens, completed, sum_logprobs, pre_logits
|
||||
|
||||
try:
|
||||
tokens, completed, sum_logprobs, pre_logits = _step(
|
||||
tokens, audio_features, tokens, sum_logprobs
|
||||
)
|
||||
@ -602,22 +601,22 @@ class DecodingTask:
|
||||
|
||||
for i in range(1, self.sample_len):
|
||||
inputs = tokens[:, -1:]
|
||||
if tokens.shape[-1] > self.n_ctx:
|
||||
break
|
||||
next_tokens, next_completed, next_sum_logprobs, _ = _step(
|
||||
inputs, audio_features, tokens, sum_logprobs
|
||||
)
|
||||
mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
|
||||
if completed or tokens.shape[-1] > self.n_ctx:
|
||||
if completed:
|
||||
break
|
||||
tokens = next_tokens
|
||||
completed = next_completed
|
||||
sum_logprobs = next_sum_logprobs
|
||||
|
||||
finally:
|
||||
self.inference.reset()
|
||||
|
||||
return tokens, sum_logprobs, no_speech_probs
|
||||
|
||||
def run(self, mel: mx.array) -> List[DecodingResult]:
|
||||
self.inference.reset()
|
||||
self.decoder.reset()
|
||||
tokenizer: Tokenizer = self.tokenizer
|
||||
n_audio: int = mel.shape[0]
|
||||
|
Loading…
Reference in New Issue
Block a user